diff --git a/.compatibility b/.compatibility index c8ac4083d2a2b2985a02b6ec281a0e33ab4f23b2..32da32be5521539f90b6198e72089e14f8d0bd8e 100644 --- a/.compatibility +++ b/.compatibility @@ -1,3 +1,3 @@ 1.12.0-11.3.0 -1.11.0-11.3.0 -1.10.1-11.3.0 +1.13.0-11.6.0 +2.0.0-11.7.0 diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000000000000000000000000000000000000..b065e6eb9b772cc3a0253622c11986ee8e5613eb --- /dev/null +++ b/.coveragerc @@ -0,0 +1,4 @@ +[run] +concurrency = multiprocessing +parallel = true +sigterm = true diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 229856aa4366c092e94eb573dbbd67756974a3ae..0000000000000000000000000000000000000000 --- a/.flake8 +++ /dev/null @@ -1,22 +0,0 @@ -[flake8] -ignore = - ;W503 line break before binary operator - W503, - ;E203 whitespace before ':' - E203, - -; exclude file -exclude = - .tox, - .git, - __pycache__, - build, - dist, - *.pyc, - *.egg-info, - .cache, - .eggs - -max-line-length = 120 - -per-file-ignores = __init__.py:F401 diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index b310fcfefc154e3c8b3e941a363181743c95b5ab..436bdf887c69326dba7b758db008cc1b1e1f6551 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,7 +1,7 @@ blank_issues_enabled: true contact_links: - name: ❓ Simple question - Slack Chat - url: https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w + url: https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack about: This issue tracker is not for technical support. Please use our Slack chat, and ask the community for help. - name: ❓ Simple question - WeChat url: https://github.com/hpcaitech/ColossalAI/blob/main/docs/images/WeChat.png diff --git a/.github/workflows/README.md b/.github/workflows/README.md index a46d8b1c24d05804320acf09eff25fd2fcab3fa9..3fad7e36f14c6473d403197534bbf26a93db8095 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -14,7 +14,7 @@ - [Compatibility Test on Dispatch](#compatibility-test-on-dispatch) - [Release](#release) - [User Friendliness](#user-friendliness) - - [Commmunity](#commmunity) + - [Community](#community) - [Configuration](#configuration) - [Progress Log](#progress-log) @@ -43,10 +43,18 @@ I will provide the details of each workflow below. | Workflow Name | File name | Description | | ---------------------- | -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------- | -| `Build on PR` | `build_on_pr.yml` | This workflow is triggered when the label `Run build and Test` is assigned to a PR. It will run all the unit tests in the repository with 4 GPUs. | +| `Build on PR` | `build_on_pr.yml` | This workflow is triggered when a PR changes essential files and a branch is created/deleted. It will run all the unit tests in the repository with 4 GPUs. | | `Build on Schedule` | `build_on_schedule.yml` | This workflow will run the unit tests everyday with 8 GPUs. The result is sent to Lark. | | `Report test coverage` | `report_test_coverage.yml` | This PR will put up a comment to report the test coverage results when `Build` is done. | +To reduce the average time of the unit test on PR, `Build on PR` workflow manages testmon cache. + +1. When creating a new branch, it copies `cache/main/.testmondata*` to `cache//`. +2. When creating a new PR or change the base branch of a PR, it copies `cache//.testmondata*` to `cache/_pull//`. +3. When running unit tests for each PR, it restores testmon cache from `cache/_pull//`. After the test, it stores the cache back to `cache/_pull//`. +4. When a PR is closed, if it's merged, it copies `cache/_pull//.testmondata*` to `cache//`. Otherwise, it just removes `cache/_pull/`. +5. When a branch is deleted, it removes `cache/`. + ### Example Test | Workflow Name | File name | Description | @@ -97,7 +105,7 @@ This workflow is triggered by manually dispatching the workflow. It has the foll | `Synchronize submodule` | `submodule.yml` | This workflow will check if any git submodule is updated. If so, it will create a PR to update the submodule pointers. | | `Close inactive issues` | `close_inactive.yml` | This workflow will close issues which are stale for 14 days. | -### Commmunity +### Community | Workflow Name | File name | Description | | -------------------------------------------- | -------------------------------- | -------------------------------------------------------------------------------- | diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index e6febeeb4d87256efc3219b18b9b53875e54aa21..e2114d43bcd0a0da0f36c375d93490f20b9a465f 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -2,22 +2,93 @@ name: Build on PR on: pull_request: - types: [synchronize, labeled] + types: [synchronize, opened, reopened, ready_for_review, closed, edited] + branches: + - "main" + - "develop" + - "feature/**" + paths: + - ".github/workflows/build_on_pr.yml" # run command & env variables change + - "colossalai/**" # source code change + - "!colossalai/**.md" # ignore doc change + - "op_builder/**" # cuda extension change + - "!op_builder/**.md" # ignore doc change + - "requirements/**" # requirements change + - "tests/**" # test change + - "!tests/**.md" # ignore doc change + - "pytest.ini" # test config change + - "setup.py" # install command change + create: + delete: jobs: + prepare_cache: + name: Prepare testmon cache + if: | + github.event_name == 'create' && + github.event.ref_type == 'branch' && + github.event.repository.full_name == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --rm + timeout-minutes: 5 + defaults: + run: + shell: bash + steps: + - name: Copy testmon cache + run: | # branch name may contain slash, we need to replace it with space + export REF_BRANCH=$(echo ${{ github.event.ref }} | sed "s/\// /") + if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ]; then + cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}" + fi + env: + MAIN_BRANCH: ${{ github.event.master_branch }} + + prepare_cache_for_pr: + name: Prepare testmon cache for PR + if: | + github.event_name == 'pull_request' && + (github.event.action == 'opened' || github.event.action == 'reopened' || (github.event.action == 'edited' && github.event.changes.base != null)) && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --rm + timeout-minutes: 5 + defaults: + run: + shell: bash + concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-repare-cache + cancel-in-progress: true + steps: + - name: Copy testmon cache + run: | # branch name may contain slash, we need to replace it with space + export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /") + if [ -d "/github/home/testmon_cache/${BASE}" ] && [ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ]; then + mkdir -p /github/home/testmon_cache/_pull/${PR_NUMBER} && cp -p -r "/github/home/testmon_cache/${BASE}"/.testmondata* /github/home/testmon_cache/_pull/${PR_NUMBER} + fi + env: + PR_NUMBER: ${{ github.event.number }} + detect: name: Detect file change if: | + github.event_name == 'pull_request' && + (github.event.action == 'synchronize' || github.event.action == 'opened' || github.event.action == 'reopened' || github.event.action == 'ready_for_review') && github.event.pull_request.draft == false && - github.base_ref == 'main' && - github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && - contains( github.event.pull_request.labels.*.name, 'Run Build and Test') + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' outputs: changedExtenisonFiles: ${{ steps.find-extension-change.outputs.all_changed_files }} anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }} changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }} anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }} runs-on: ubuntu-latest + concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-detect-change + cancel-in-progress: true steps: - uses: actions/checkout@v2 with: @@ -66,14 +137,18 @@ jobs: build: name: Build and Test Colossal-AI needs: detect + if: needs.detect.outputs.anyLibraryFileChanged == 'true' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:1.11.0-11.3.0 - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 - timeout-minutes: 40 + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny + timeout-minutes: 60 defaults: run: shell: bash + concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test + cancel-in-progress: true steps: - name: Checkout TensorNVMe uses: actions/checkout@v2 @@ -84,7 +159,9 @@ jobs: - name: Restore TensorNVMe Cache run: | - [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ] && cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe + if [ -d /github/home/tensornvme_cache ] && [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ]; then + cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe + fi - name: Install TensorNVMe run: | @@ -107,10 +184,11 @@ jobs: if: needs.detect.outputs.anyExtensionFileChanged != 'true' run: | # -p flag is required to preserve the file timestamp to avoid ninja rebuild - [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ + if [ -d /github/home/cuda_ext_cache ] && [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ]; then + cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ + fi - name: Install Colossal-AI - if: needs.detect.outputs.anyLibraryFileChanged == 'true' run: | CUDA_EXT=1 pip install -v -e . pip install -r requirements/requirements-test.txt @@ -120,14 +198,30 @@ jobs: # -p flag is required to preserve the file timestamp to avoid ninja rebuild cp -p -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ + - name: Restore Testmon Cache + run: | + if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ] && [ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ]; then + cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* /__w/ColossalAI/ColossalAI/ + fi + env: + PR_NUMBER: ${{ github.event.number }} + - name: Execute Unit Testing - if: needs.detect.outputs.anyLibraryFileChanged == 'true' run: | - CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/ + CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/ env: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt + LLAMA_PATH: /data/scratch/llama-tiny + + - name: Store Testmon Cache + run: | + mkdir -p /github/home/testmon_cache/_pull/${PR_NUMBER} + cp -p -r /__w/ColossalAI/ColossalAI/.testmondata* /github/home/testmon_cache/_pull/${PR_NUMBER}/ + env: + PR_NUMBER: ${{ github.event.number }} - name: Collate artifact env: @@ -140,7 +234,7 @@ jobs: echo $PR_NUMBER > ./report/pr_number # generate coverage.xml if any - if [ "$anyLibraryFileChanged" == "true" ]; then + if [ "$anyLibraryFileChanged" == "true" ] && [ -e .coverage ]; then allFiles="" for file in $changedLibraryFiles; do if [ "$allFiles" == "" ]; then @@ -165,3 +259,54 @@ jobs: with: name: report path: report/ + + store_cache: + name: Store testmon cache for PR + if: | + github.event_name == 'pull_request' && + github.event.action == 'closed' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --rm + timeout-minutes: 5 + defaults: + run: + shell: bash + steps: + - name: Store testmon cache if possible + if: github.event.pull_request.merged == true + run: | # branch name may contain slash, we need to replace it with space + export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /") + if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ] && [ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ]; then + cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/" + fi + env: + PR_NUMBER: ${{ github.event.pull_request.number }} + + - name: Remove testmon cache + run: | + rm -rf /github/home/testmon_cache/_pull/${PR_NUMBER} + env: + PR_NUMBER: ${{ github.event.pull_request.number }} + + remove_cache: + name: Remove testmon cache + if: | + github.event_name == 'delete' && + github.event.ref_type == 'branch' && + github.event.repository.full_name == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + container: + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --rm + timeout-minutes: 5 + defaults: + run: + shell: bash + steps: + - name: Remove testmon cache + run: | # branch name may contain slash, we need to replace it with space + export BASE=$(echo ${{ github.event.ref }} | sed "s/\// /") + rm -rf "/github/home/testmon_cache/${BASE}" diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index 6afdf581e6ca42e078118ef238e3d4199de9aa8f..6c77377be34f993bf409a7332d8cadb367fdd61e 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -3,7 +3,7 @@ name: Build on Schedule on: schedule: # run at 00:00 of every Sunday - - cron: '0 0 * * *' + - cron: "0 0 * * *" workflow_dispatch: jobs: @@ -12,8 +12,8 @@ jobs: if: github.repository == 'hpcaitech/ColossalAI' runs-on: [self-hosted, 8-gpu] container: - image: hpcaitech/pytorch-cuda:1.11.0-11.3.0 - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 40 steps: - name: Check GPU Availability # ensure all GPUs have enough memory @@ -60,10 +60,11 @@ jobs: - name: Unit Testing if: steps.check-avai.outputs.avai == 'true' run: | - PYTHONPATH=$PWD pytest tests + PYTHONPATH=$PWD pytest --durations=0 tests env: DATA: /data/scratch/cifar-10 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LLAMA_PATH: /data/scratch/llama-tiny - name: Notify Lark id: message-preparation diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 717cf729b3f3fb13d29daa075129c6c653808eb4..5083212993cc40a70bbededf60919a483bdea7ab 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -19,38 +19,38 @@ jobs: outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - - id: set-matrix - env: - TORCH_VERSIONS: ${{ inputs.torch_version }} - CUDA_VERSIONS: ${{ inputs.cuda_version }} - run: | - IFS=',' - DOCKER_IMAGE=() + - id: set-matrix + env: + TORCH_VERSIONS: ${{ inputs.torch_version }} + CUDA_VERSIONS: ${{ inputs.cuda_version }} + run: | + IFS=',' + DOCKER_IMAGE=() - for tv in $TORCH_VERSIONS - do - for cv in $CUDA_VERSIONS - do - DOCKER_IMAGE+=("\"hpcaitech/pytorch-cuda:${tv}-${cv}\"") - done - done + for tv in $TORCH_VERSIONS + do + for cv in $CUDA_VERSIONS + do + DOCKER_IMAGE+=("\"hpcaitech/pytorch-cuda:${tv}-${cv}\"") + done + done - container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" ) - container="[${container}]" - echo "$container" - echo "::set-output name=matrix::{\"container\":$(echo "$container")}" + container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" ) + container="[${container}]" + echo "$container" + echo "::set-output name=matrix::{\"container\":$(echo "$container")}" build: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 120 steps: - name: Install dependencies @@ -64,16 +64,26 @@ jobs: - name: Install tensornvme run: | cd TensorNVMe - conda install cmake + apt update && apt install -y cmake pip install -r requirements.txt pip install -v . - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + - name: Download cub for CUDA 10.2 + run: | + CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') + + # check if it is CUDA 10.2 + # download cub + if [ "$CUDA_VERSION" = "10.2" ]; then + wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip + unzip 1.8.0.zip + cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ + fi - name: Install Colossal-AI run: | - pip install -r requirements/requirements.txt - pip install -v --no-cache-dir . + CUDA_EXT=1 pip install -v . pip install -r requirements/requirements-test.txt - name: Unit Testing run: | @@ -82,3 +92,4 @@ jobs: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LLAMA_PATH: /data/scratch/llama-tiny diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index 2fca67b820a1d7cf411f1ee9b854f93a150c639d..cc17c66f9c3afb57f08233bcbc6ca9e1c8e7df5f 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -3,8 +3,8 @@ name: Compatibility Test on PR on: pull_request: paths: - - 'version.txt' - - '.compatibility' + - "version.txt" + - ".compatibility" jobs: matrix_preparation: @@ -12,6 +12,9 @@ jobs: runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} + concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-prepare-matrix + cancel-in-progress: true steps: - uses: actions/checkout@v3 - id: set-matrix @@ -32,14 +35,17 @@ jobs: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 120 + concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }} + cancel-in-progress: true steps: - name: Install dependencies run: | @@ -52,15 +58,27 @@ jobs: - name: Install tensornvme run: | cd TensorNVMe - conda install cmake + apt update && apt install -y cmake pip install -r requirements.txt pip install -v . - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + - name: Download cub for CUDA 10.2 + run: | + CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') + + # check if it is CUDA 10.2 + # download cub + if [ "$CUDA_VERSION" = "10.2" ]; then + wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip + unzip 1.8.0.zip + cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ + fi + - name: Install Colossal-AI run: | - pip install -v --no-cache-dir . + CUDA_EXT=1 pip install -v . pip install -r requirements/requirements-test.txt - name: Unit Testing run: | @@ -69,3 +87,4 @@ jobs: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LLAMA_PATH: /data/scratch/llama-tiny diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 9802795fad246864db5c7a86d9f4273f2ce99ff9..158fe751bf2efd7d8449d7df694351128dbe8100 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -32,13 +32,13 @@ jobs: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 + options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 120 steps: - name: Install dependencies @@ -54,16 +54,28 @@ jobs: - name: Install tensornvme run: | cd TensorNVMe - conda install cmake + apt update && apt install -y cmake pip install -r requirements.txt pip install -v . - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} + - name: Download cub for CUDA 10.2 + run: | + CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') + + # check if it is CUDA 10.2 + # download cub + if [ "$CUDA_VERSION" = "10.2" ]; then + wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip + unzip 1.8.0.zip + cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ + fi + - name: Install Colossal-AI run: | - pip install -v --no-cache-dir . + CUDA_EXT=1 pip install -v . pip install -r requirements/requirements-test.txt - name: Unit Testing @@ -73,6 +85,7 @@ jobs: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 + LLAMA_PATH: /data/scratch/llama-tiny - name: Notify Lark id: message-preparation diff --git a/.github/workflows/cuda_ext_check_before_merge.yml b/.github/workflows/cuda_ext_check_before_merge.yml index eba5bb98ec07998314708d1b8a159de0d574fcf2..686f0f395c733f6bbab586048195c7fad8db785d 100644 --- a/.github/workflows/cuda_ext_check_before_merge.yml +++ b/.github/workflows/cuda_ext_check_before_merge.yml @@ -37,6 +37,18 @@ jobs: - name: Install PyTorch run: eval ${{ matrix.build.torch_command }} + - name: Download cub for CUDA 10.2 + run: | + CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}') + + # check if it is CUDA 10.2 + # download cub + if [ "$CUDA_VERSION" = "10.2" ]; then + wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip + unzip 1.8.0.zip + cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ + fi + - name: Build run: | CUDA_EXT=1 pip install -v . diff --git a/.github/workflows/doc_build_after_merge.yml b/.github/workflows/doc_build_after_merge.yml deleted file mode 100644 index ede04b336620870e3739458e68bfd31afb73ce8c..0000000000000000000000000000000000000000 --- a/.github/workflows/doc_build_after_merge.yml +++ /dev/null @@ -1,28 +0,0 @@ -name: Build Documentation After Merge - -on: - workflow_dispatch: - pull_request: - paths: - - 'version.txt' - - 'docs/**' - types: - - closed - -jobs: - build-doc: - name: Trigger Documentation Build Workflow - if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI' - runs-on: ubuntu-latest - steps: - - name: trigger workflow in ColossalAI-Documentation - run: | - curl \ - -X POST \ - -H "Accept: application/vnd.github+json" \ - -H "Authorization: Bearer ${GH_TOKEN}"\ - -H "X-GitHub-Api-Version: 2022-11-28" \ - https://api.github.com/repos/hpcaitech/ColossalAI-Documentation/actions/workflows/deploy.yml/dispatches \ - -d '{"ref":"main"}' - env: - GH_TOKEN: ${{secrets.DOC_REPO_TOKEN}} diff --git a/.github/workflows/doc_build_on_schedule_after_release.yml b/.github/workflows/doc_build_on_schedule_after_release.yml new file mode 100644 index 0000000000000000000000000000000000000000..62dfdc67257c7ac03015652fb0decaab6e04d588 --- /dev/null +++ b/.github/workflows/doc_build_on_schedule_after_release.yml @@ -0,0 +1,26 @@ +name: Build Documentation On Schedule & After Release + +on: + workflow_dispatch: + schedule: + - cron: "0 12 * * *" # build doc every day at 8pm Singapore time (12pm UTC time) + release: + types: [published] + +jobs: + build-doc: + name: Trigger Documentation Build Workflow + if: github.repository == 'hpcaitech/ColossalAI' + runs-on: ubuntu-latest + steps: + - name: trigger workflow in ColossalAI-Documentation + run: | + curl \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${GH_TOKEN}"\ + -H "X-GitHub-Api-Version: 2022-11-28" \ + https://api.github.com/repos/hpcaitech/ColossalAI-Documentation/actions/workflows/deploy.yml/dispatches \ + -d '{"ref":"main"}' + env: + GH_TOKEN: ${{secrets.DOC_REPO_TOKEN}} diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml index 2022c957fba837904ba2a2efe678edd1686eece8..ee8a82128dd775a72b8dc71c9326a4c7c922a55a 100644 --- a/.github/workflows/doc_check_on_pr.yml +++ b/.github/workflows/doc_check_on_pr.yml @@ -2,57 +2,68 @@ name: Check Documentation on PR on: pull_request: + branches: + - "main" + - "develop" + - "feature/**" paths: - - 'docs/**' + - "docs/**" jobs: check-i18n: name: Check docs in diff languages if: | - github.event.pull_request.draft == false && - github.base_ref == 'main' && - github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + github.event.pull_request.draft == false && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: ubuntu-latest + concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-check-i18n + cancel-in-progress: true steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: - python-version: '3.8.14' + python-version: "3.8.14" - run: python .github/workflows/scripts/check_doc_i18n.py -d docs/source check-doc-build: name: Test if the docs can be built if: | - github.event.pull_request.draft == false && - github.base_ref == 'main' && - github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' + github.event.pull_request.draft == false && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: ubuntu-latest + concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-check-doc + cancel-in-progress: true steps: - uses: actions/checkout@v2 with: - path: './ColossalAI' + path: "./ColossalAI" fetch-depth: 0 - uses: actions/checkout@v2 with: - path: './ColossalAI-Documentation' - repository: 'hpcaitech/ColossalAI-Documentation' + path: "./ColossalAI-Documentation" + repository: "hpcaitech/ColossalAI-Documentation" - uses: actions/setup-python@v2 with: - python-version: '3.8.14' + python-version: "3.8.14" # we use the versions in the main branch as the guide for versions to display # checkout will give your merged branch # therefore, we need to make the merged branch as the main branch + # there is no main branch, so it's safe to checkout the main branch from the merged branch + # docer will rebase the remote main branch to the merged branch, so we have to config user - name: Make the merged branch main run: | cd ColossalAI - curBranch=$(git rev-parse --abbrev-ref HEAD) - git checkout main - git merge $curBranch # fast-forward master up to the merge + git checkout -b main + git branch -u origin/main + git config user.name 'github-actions' + git config user.email 'github-actions@github.com' - name: Build docs run: | diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml index fbe669582c2088c17678a763ceea9bfe78738108..f1e7a2d0cab057573e04755795a7169b3718b8f2 100644 --- a/.github/workflows/doc_test_on_pr.yml +++ b/.github/workflows/doc_test_on_pr.yml @@ -1,21 +1,27 @@ name: Test Documentation on PR on: pull_request: + branches: + - "main" + - "develop" + - "feature/**" # any change in the examples folder will trigger check for the corresponding example. paths: - - 'docs/source/**.md' + - "docs/source/**.md" jobs: # This is for changed example files detect and output a matrix containing all the corresponding directory name. detect-changed-doc: if: | - github.event.pull_request.draft == false && - github.base_ref == 'main' && - github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' + github.event.pull_request.draft == false && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' runs-on: ubuntu-latest outputs: any_changed: ${{ steps.changed-files.outputs.any_changed }} changed_files: ${{ steps.changed-files.outputs.all_changed_files }} + concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-detect-change + cancel-in-progress: true name: Detect changed example files steps: - uses: actions/checkout@v3 @@ -26,10 +32,10 @@ jobs: - name: Locate base commit id: locate-base-sha run: | - curBranch=$(git rev-parse --abbrev-ref HEAD) - commonCommit=$(git merge-base origin/main $curBranch) - echo $commonCommit - echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT + curBranch=$(git rev-parse --abbrev-ref HEAD) + commonCommit=$(git merge-base origin/main $curBranch) + echo $commonCommit + echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT - name: Get all changed example files id: changed-files @@ -43,10 +49,9 @@ jobs: check-changed-doc: # Add this condition to avoid executing this job if the trigger event is workflow_dispatch. if: | - github.event.pull_request.draft == false && - github.base_ref == 'main' && - github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' && - needs.detect-changed-doc.outputs.any_changed == 'true' + github.event.pull_request.draft == false && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' && + needs.detect-changed-doc.outputs.any_changed == 'true' name: Test the changed Doc needs: detect-changed-doc runs-on: [self-hosted, gpu] @@ -57,12 +62,15 @@ jobs: defaults: run: shell: bash + concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-doctest + cancel-in-progress: true steps: - name: Checkout ColossalAI-Documentation uses: actions/checkout@v2 with: - path: './ColossalAI-Documentation' - repository: 'hpcaitech/ColossalAI-Documentation' + path: "./ColossalAI-Documentation" + repository: "hpcaitech/ColossalAI-Documentation" - name: Install Docer run: | @@ -81,12 +89,12 @@ jobs: - name: Install ColossalAI run: | source activate pytorch - pip install -v . + CUDA_EXT=1 pip install -v . - name: Test the Doc run: | source activate pytorch - for file in ${{ steps.changed-files.outputs.all_changed_files }}; do + for file in ${{ needs.detect-changed-doc.outputs.changed_files }}; do echo "Testing $file now..." docer test -p $file done diff --git a/.github/workflows/doc_test_on_schedule.yml b/.github/workflows/doc_test_on_schedule.yml index 6b4f5d1f908c608a8ba266725c521369a9fb0047..027fbfd0aaeb315d1a013e247bd76a4d2d161ae5 100644 --- a/.github/workflows/doc_test_on_schedule.yml +++ b/.github/workflows/doc_test_on_schedule.yml @@ -32,7 +32,7 @@ jobs: - name: Install ColossalAI run: | - pip install -v . + CUDA_EXT=1 pip install -v . - name: Install Doc Test Requirements run: | diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml index 620d4771af55f53972888013c67347f3f4a392bd..9d3bd9a48235033ae07eb955888e5e3761bc4f3a 100644 --- a/.github/workflows/example_check_on_dispatch.yml +++ b/.github/workflows/example_check_on_dispatch.yml @@ -53,7 +53,7 @@ jobs: uses: actions/checkout@v3 - name: Install Colossal-AI run: | - pip install -v . + CUDA_EXT=1 pip install -v . - name: Test the example run: | dir=${{ matrix.directory }} diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index b22664ee47ccbf8ea786b526964694933a51ab5f..5934704f4102e546d8140c04983e3db58e5da9f8 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -1,22 +1,28 @@ name: Test Example on PR on: pull_request: + branches: + - "main" + - "develop" + - "feature/**" # any change in the examples folder will trigger check for the corresponding example. paths: - - 'examples/**' + - "examples/**" jobs: # This is for changed example files detect and output a matrix containing all the corresponding directory name. detect-changed-example: if: | - github.event.pull_request.draft == false && - github.base_ref == 'main' && - github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' + github.event.pull_request.draft == false && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' runs-on: ubuntu-latest outputs: matrix: ${{ steps.setup-matrix.outputs.matrix }} anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }} name: Detect changed example files + concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-detect-change + cancel-in-progress: true steps: - uses: actions/checkout@v3 with: @@ -26,10 +32,10 @@ jobs: - name: Locate base commit id: locate-base-sha run: | - curBranch=$(git rev-parse --abbrev-ref HEAD) - commonCommit=$(git merge-base origin/main $curBranch) - echo $commonCommit - echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT + curBranch=$(git rev-parse --abbrev-ref HEAD) + commonCommit=$(git merge-base origin/main $curBranch) + echo $commonCommit + echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT - name: Get all changed example files id: changed-files @@ -61,10 +67,9 @@ jobs: check-changed-example: # Add this condition to avoid executing this job if the trigger event is workflow_dispatch. if: | - github.event.pull_request.draft == false && - github.base_ref == 'main' && - github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' && - needs.detect-changed-example.outputs.anyChanged == 'true' + github.event.pull_request.draft == false && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' && + needs.detect-changed-example.outputs.anyChanged == 'true' name: Test the changed example needs: detect-changed-example runs-on: [self-hosted, gpu] @@ -75,12 +80,15 @@ jobs: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ timeout-minutes: 10 + concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }} + cancel-in-progress: true steps: - uses: actions/checkout@v3 - name: Install Colossal-AI run: | - pip install -v . + CUDA_EXT=1 pip install -v . - name: Test the example run: | diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml index bd52ca4321a2b9abd3182df53cf15b77e40766c7..5ed128c3ebc59b834d02e1af9faa832d91527fc4 100644 --- a/.github/workflows/example_check_on_schedule.yml +++ b/.github/workflows/example_check_on_schedule.yml @@ -42,7 +42,7 @@ jobs: - name: Install Colossal-AI run: | - pip install -v . + CUDA_EXT=1 pip install -v . - name: Traverse all files run: | diff --git a/.github/workflows/release_docker_after_merge.yml b/.github/workflows/release_docker_after_merge.yml deleted file mode 100644 index 607c19b05472e024d08fdd60dcf95c5516a4420a..0000000000000000000000000000000000000000 --- a/.github/workflows/release_docker_after_merge.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: Publish Docker Image to DockerHub after Merge - -on: - workflow_dispatch: - pull_request: - paths: - - 'version.txt' - types: - - closed - -jobs: - release: - name: Publish Docker Image to DockerHub - if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] - container: - image: "hpcaitech/docker-in-docker:latest" - options: --gpus all --rm -v /var/run/docker.sock:/var/run/docker.sock - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Build Docker - id: build - run: | - version=$(cat version.txt) - tag=hpcaitech/colossalai:$version - docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t $tag ./docker - echo "tag=${tag}" >> $GITHUB_OUTPUT - - - name: Log in to Docker Hub - uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 - with: - username: ${{ secrets.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_PASSWORD }} - - - name: Push Docker image - id: docker-push - run: | - docker push ${{ steps.build.outputs.tag }} - - notify: - name: Notify Lark via webhook - needs: release - runs-on: ubuntu-latest - if: ${{ always() }} - steps: - - uses: actions/checkout@v2 - - - uses: actions/setup-python@v2 - with: - python-version: '3.8.14' - - - name: Install requests - run: pip install requests - - - name: Notify Lark - id: message-preparation - run: | - url=$SERVER_URL/$REPO/actions/runs/$RUN_ID - if [ "$STATUS" == 'success' ] - then - msg="The Docker image for the latest release has been successfully built and pushed to DockerHub." - else - msg="Failed to build and push the Docker image for the latest release, please visit $url for details." - fi - echo $msg - python .github/workflows/scripts/send_message_to_lark.py -m "$msg" -u $WEBHOOK_URL - env: - SERVER_URL: ${{github.server_url }} - REPO: ${{ github.repository }} - RUN_ID: ${{ github.run_id }} - WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} - STATUS: ${{ needs.release.result }} diff --git a/.github/workflows/release_docker_after_publish.yml b/.github/workflows/release_docker_after_publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..6c8df9730b0df2a4a5c6622d79c60c83da54879f --- /dev/null +++ b/.github/workflows/release_docker_after_publish.yml @@ -0,0 +1,76 @@ +name: Publish Docker Image to DockerHub after Publish + +on: + workflow_dispatch: + release: + types: [published] + +jobs: + release: + name: Publish Docker Image to DockerHub + if: github.repository == 'hpcaitech/ColossalAI' + runs-on: [self-hosted, gpu] + container: + image: "hpcaitech/docker-in-docker:latest" + options: --gpus all --rm -v /var/run/docker.sock:/var/run/docker.sock + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Build Docker + id: build + run: | + version=$(cat version.txt) + tag=hpcaitech/colossalai:$version + latest=hpcaitech/colossalai:latest + docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 --build-arg VERSION=v${version} -t $tag ./docker + docker tag $tag $latest + echo "tag=${tag}" >> $GITHUB_OUTPUT + echo "latest=${latest}" >> $GITHUB_OUTPUT + + - name: Log in to Docker Hub + uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Push Docker image + id: docker-push + run: | + docker push ${{ steps.build.outputs.tag }} + docker push ${{ steps.build.outputs.latest }} + + notify: + name: Notify Lark via webhook + needs: release + runs-on: ubuntu-latest + if: ${{ always() }} + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v2 + with: + python-version: "3.8.14" + + - name: Install requests + run: pip install requests + + - name: Notify Lark + id: message-preparation + run: | + url=$SERVER_URL/$REPO/actions/runs/$RUN_ID + if [ "$STATUS" == 'success' ] + then + msg="The Docker image for the latest release has been successfully built and pushed to DockerHub." + else + msg="Failed to build and push the Docker image for the latest release, please visit $url for details." + fi + echo $msg + python .github/workflows/scripts/send_message_to_lark.py -m "$msg" -u $WEBHOOK_URL + env: + SERVER_URL: ${{github.server_url }} + REPO: ${{ github.repository }} + RUN_ID: ${{ github.run_id }} + WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} + STATUS: ${{ needs.release.result }} diff --git a/.github/workflows/report_test_coverage.yml b/.github/workflows/report_test_coverage.yml index bbada74e685025d15b57f6a4cae855046f99e8ec..c9dc541b8a33fcd29bf6b75d9d70456444c8c421 100644 --- a/.github/workflows/report_test_coverage.yml +++ b/.github/workflows/report_test_coverage.yml @@ -9,8 +9,9 @@ on: jobs: report-test-coverage: runs-on: ubuntu-latest + if: ${{ github.event.workflow_run.conclusion == 'success' }} steps: - - name: 'Download artifact' + - name: "Download artifact" uses: actions/github-script@v6 with: script: | @@ -31,7 +32,7 @@ jobs: let fs = require('fs'); fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/report.zip`, Buffer.from(download.data)); - - name: 'Unzip artifact' + - name: "Unzip artifact" id: unzip run: | unzip report.zip @@ -58,7 +59,7 @@ jobs: echo "" >> coverage_report.txt mv coverage_report.txt coverage.txt - - name: 'Comment on PR' + - name: "Comment on PR" if: steps.unzip.outputs.hasReport == 'true' uses: actions/github-script@v6 with: diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index 9d9d3a007851276cbf5dbd69e5c8b776178e5de4..f9e9f400962ef0453a017d4f5cdefcbaf95ff356 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -4,11 +4,10 @@ on: pull_request: types: [synchronize, opened, reopened] paths: - - 'applications/Chat/coati/**' - - 'applications/Chat/requirements.txt' - - 'applications/Chat/setup.py' - - 'applications/Chat/examples/**' - + - "applications/Chat/coati/**" + - "applications/Chat/requirements.txt" + - "applications/Chat/setup.py" + - "applications/Chat/examples/**" jobs: tests: @@ -20,7 +19,7 @@ jobs: runs-on: [self-hosted, gpu] container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 - options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat + options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat --shm-size=10.24gb timeout-minutes: 30 defaults: run: @@ -29,28 +28,26 @@ jobs: - name: Checkout ColossalAI uses: actions/checkout@v2 - - name: Install ColossalAI and ChatGPT + - name: Install ChatGPT run: | - pip install -e . cd applications/Chat pip install -v . pip install -r examples/requirements.txt - name: Install Transformers run: | - cd applications/Chat - git clone https://github.com/hpcaitech/transformers - cd transformers - pip install -v . + pip install transformers==4.30.2 - name: Execute Examples run: | cd applications/Chat rm -rf ~/.cache/colossalai - ./examples/test_ci.sh + ./tests/test_inference.sh + ./tests/test_benchmarks.sh + ./tests/test_train.sh env: NCCL_SHM_DISABLE: 1 MAX_JOBS: 8 SFT_DATASET: /data/scratch/github_actions/chat/data.json - PROMPT_PATH: /data/scratch/github_actions/chat/prompts_en.jsonl + PROMPT_DATASET: /data/scratch/github_actions/chat/prompts_en.jsonl PRETRAIN_DATASET: /data/scratch/github_actions/chat/alpaca_data.json diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml index 47c80fc9a9fecafa332a0cb0e457052b8d929711..ec5c8ffa319f47d75094d403152c9162cac77c6f 100644 --- a/.github/workflows/run_chatgpt_unit_tests.yml +++ b/.github/workflows/run_chatgpt_unit_tests.yml @@ -30,9 +30,8 @@ jobs: - name: Checkout ColossalAI uses: actions/checkout@v2 - - name: Install ColossalAI and ChatGPT + - name: Install ChatGPT run: | - pip install -e . cd applications/Chat pip install -v . pip install -r requirements-test.txt diff --git a/.github/workflows/scripts/check_doc_i18n.py b/.github/workflows/scripts/check_doc_i18n.py index 1aa7283e9e52f169d89f337d7942cf55f601257d..1e7f0c33a78598e12a67fa47de192c519d42b3c8 100644 --- a/.github/workflows/scripts/check_doc_i18n.py +++ b/.github/workflows/scripts/check_doc_i18n.py @@ -22,13 +22,13 @@ def compare_dirs(dir1, dir2): # If the corresponding item doesn't exist in the second directory, the directories are different if not os.path.exists(item_path2): - print(f'Found mismatch: {item_path1}, {item_path2}') + print(f"Found mismatch: {item_path1}, {item_path2}") return False # If the corresponding item is a directory, we compare the two directories recursively if os.path.isdir(item_path1) and os.path.isdir(item_path2): if not compare_dirs(item_path1, item_path2): - print(f'Found mismatch: {item_path1}, {item_path2}') + print(f"Found mismatch: {item_path1}, {item_path2}") return False # both are files @@ -37,16 +37,16 @@ def compare_dirs(dir1, dir2): # If the corresponding item is not a file or a directory, the directories are different else: - print(f'Found mismatch: {item_path1}, {item_path2}') + print(f"Found mismatch: {item_path1}, {item_path2}") return False # If all items are the same, the directories are the same return True -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-d', '--directory', help="The directory where the multi-language source files are kept.") + parser.add_argument("-d", "--directory", help="The directory where the multi-language source files are kept.") args = parser.parse_args() i18n_folders = os.listdir(args.directory) @@ -56,7 +56,7 @@ if __name__ == '__main__': for i in range(1, len(i18n_folders)): dir1 = i18n_folders[0] dir2 = i18n_folders[i] - print(f'comparing {dir1} vs {dir2}') + print(f"comparing {dir1} vs {dir2}") match = compare_dirs(i18n_folders[0], i18n_folders[i]) if not match: diff --git a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py index 5bec96187e0cc5b0aa5ebd8e6a59f73ac8b6d88d..91778f692cc63360afc3a594a61e5a21c387a1a8 100644 --- a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py +++ b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py @@ -4,7 +4,7 @@ import os def check_inputs(input_list): for path in input_list: - real_path = os.path.join('examples', path) + real_path = os.path.join("examples", path) if not os.path.exists(real_path): return False return True @@ -12,16 +12,16 @@ def check_inputs(input_list): def main(): parser = argparse.ArgumentParser() - parser.add_argument('-f', '--fileNameList', type=str, help="List of file names") + parser.add_argument("-f", "--fileNameList", type=str, help="List of file names") args = parser.parse_args() name_list = args.fileNameList.split(",") is_correct = check_inputs(name_list) if is_correct: - print('success') + print("success") else: - print('failure') + print("failure") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/.github/workflows/scripts/example_checks/check_example_weekly.py b/.github/workflows/scripts/example_checks/check_example_weekly.py index 83eff644e3150dae8fa7ada808dd1e16b571e54a..95a3d24c9a78038095e53a3f1700a8fc8d4d536c 100644 --- a/.github/workflows/scripts/example_checks/check_example_weekly.py +++ b/.github/workflows/scripts/example_checks/check_example_weekly.py @@ -17,21 +17,21 @@ def show_files(path, all_files): def join(input_list, sep=None): - return (sep or ' ').join(input_list) + return (sep or " ").join(input_list) def main(): - contents = show_files('examples/', []) + contents = show_files("examples/", []) all_loc = [] for file_loc in contents: - split_loc = file_loc.split('/') + split_loc = file_loc.split("/") # must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not. if len(split_loc) >= 4: - re_loc = '/'.join(split_loc[1:3]) + re_loc = "/".join(split_loc[1:3]) if re_loc not in all_loc: all_loc.append(re_loc) print(all_loc) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/.github/workflows/scripts/example_checks/detect_changed_example.py b/.github/workflows/scripts/example_checks/detect_changed_example.py index c69d95a552e96bfe425fc7ad06c0de6a30b1d786..95f671dfb32b5ee007fa6287d3b17dae271fad15 100644 --- a/.github/workflows/scripts/example_checks/detect_changed_example.py +++ b/.github/workflows/scripts/example_checks/detect_changed_example.py @@ -3,7 +3,7 @@ import argparse def main(): parser = argparse.ArgumentParser() - parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files") + parser.add_argument("-f", "--fileNameList", type=str, help="The list of changed files") args = parser.parse_args() name_list = args.fileNameList.split(":") folder_need_check = set() @@ -15,10 +15,10 @@ def main(): # - application # - file if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4: - folder_need_check.add('/'.join(loc.split("/")[1:3])) + folder_need_check.add("/".join(loc.split("/")[1:3])) # Output the result using print. Then the shell can get the values. print(list(folder_need_check)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py index 16b8957c1d884aba897fd4734a8ec54efc299c5a..412b14c7b28337e3ab0129714b808b0fee28d687 100644 --- a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py +++ b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py @@ -1,5 +1,4 @@ import os -from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Dict, List @@ -10,8 +9,7 @@ import seaborn from requests_toolbelt import MultipartEncoder -@dataclass -class Contributor: +class Counter(dict): """ Dataclass for a github contributor. @@ -19,8 +17,40 @@ class Contributor: name (str): name of the contributor num_commits_this_week (int): number of commits made within one week """ - name: str - num_commits_this_week: int + + def record(self, item: str): + if item in self: + self[item] += 1 + else: + self[item] = 1 + + def to_sorted_list(self): + data = [(key, value) for key, value in self.items()] + data.sort(key=lambda x: x[1], reverse=True) + return data + + +def get_utc_time_one_week_ago(): + """ + Get the UTC time one week ago. + """ + now = datetime.utcnow() + start_datetime = now - timedelta(days=7) + return start_datetime + + +def datetime2str(dt): + """ + Convert datetime to string in the format of YYYY-MM-DDTHH:MM:SSZ + """ + return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + + +def str2datetime(string): + """ + Convert string in the format of YYYY-MM-DDTHH:MM:SSZ to datetime + """ + return datetime.strptime(string, "%Y-%m-%dT%H:%M:%SZ") def plot_bar_chart(x: List[Any], y: List[Any], xlabel: str, ylabel: str, title: str, output_path: str) -> None: @@ -36,9 +66,30 @@ def plot_bar_chart(x: List[Any], y: List[Any], xlabel: str, ylabel: str, title: plt.savefig(output_path, dpi=1200) -def get_issue_pull_request_comments(github_token: str, since: str) -> Dict[str, int]: +def get_organization_repositories(github_token, organization_name) -> List[str]: + """ + Retrieve the public repositories under the organization. + """ + url = f"https://api.github.com/orgs/{organization_name}/repos?type=public" + + # prepare header + headers = { + "Authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } + + res = requests.get(url, headers=headers).json() + repo_list = [] + + for item in res: + repo_list.append(item["name"]) + return repo_list + + +def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name: str, since: str) -> Dict[str, int]: """ - Retrive the issue/PR comments made by our members in the last 7 days. + Retrieve the issue/PR comments made by our members in the last 7 days. Args: github_token (str): GitHub access token for API calls @@ -46,9 +97,9 @@ def get_issue_pull_request_comments(github_token: str, since: str) -> Dict[str, """ # prepare header headers = { - 'Authorization': f'Bearer {github_token}', - 'Accept': 'application/vnd.github+json', - 'X-GitHub-Api-Version': '2022-11-28' + "Authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", } user_engagement_count = {} @@ -56,28 +107,28 @@ def get_issue_pull_request_comments(github_token: str, since: str) -> Dict[str, # do pagination to the API page = 1 while True: - comment_api = f'https://api.github.com/repos/hpcaitech/ColossalAI/issues/comments?since={since}&page={page}' + comment_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}" comment_response = requests.get(comment_api, headers=headers).json() if len(comment_response) == 0: break else: for item in comment_response: - comment_author_relationship = item['author_association'] - if comment_author_relationship != 'MEMBER': + comment_author_relationship = item["author_association"] + if comment_author_relationship != "MEMBER": # if the comment is not made by our member # we don't count this comment towards user engagement continue - issue_id = item['issue_url'].split('/')[-1] - issue_api = f'https://api.github.com/repos/hpcaitech/ColossalAI/issues/{issue_id}' + issue_id = item["issue_url"].split("/")[-1] + issue_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}" issue_response = requests.get(issue_api, headers=headers).json() - issue_author_relationship = issue_response['author_association'] + issue_author_relationship = issue_response["author_association"] - if issue_author_relationship != 'MEMBER': + if issue_author_relationship != "MEMBER": # this means that the issue/PR is not created by our own people # any comments in this issue/PR by our member will be counted towards the leaderboard - member_name = item['user']['login'] + member_name = item["user"]["login"] if member_name in user_engagement_count: user_engagement_count[member_name] += 1 @@ -87,9 +138,9 @@ def get_issue_pull_request_comments(github_token: str, since: str) -> Dict[str, return user_engagement_count -def get_discussion_comments(github_token, since) -> Dict[str, int]: +def get_discussion_comments(github_token: str, org_name: str, repo_name: str, since: str) -> Dict[str, int]: """ - Retrive the discussion comments made by our members in the last 7 days. + Retrieve the discussion comments made by our members in the last 7 days. This is only available via the GitHub GraphQL API. Args: @@ -102,10 +153,10 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]: if cursor is None: offset_str = "" else: - offset_str = f", after: \"{cursor}\"" + offset_str = f', after: "{cursor}"' query = f""" {{ - repository(owner: "hpcaitech", name: "ColossalAI"){{ + repository(owner: "{org_name}", name: "{repo_name}"){{ discussions(first: {num} {offset_str}){{ edges {{ cursor @@ -131,10 +182,10 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]: if cursor is None: offset_str = "" else: - offset_str = f", before: \"{cursor}\"" + offset_str = f', before: "{cursor}"' query = f""" {{ - repository(owner: "hpcaitech", name: "ColossalAI"){{ + repository(owner: "{org_name}", name: "{repo_name}"){{ discussion(number: {discussion_number}){{ title comments(last: {num} {offset_str}){{ @@ -169,8 +220,8 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]: # a utility function to make call to Github GraphQL API def _call_graphql_api(query): headers = {"Authorization": f"Bearer {github_token}"} - json_data = {'query': query} - response = requests.post('https://api.github.com/graphql', json=json_data, headers=headers) + json_data = {"query": query} + response = requests.post("https://api.github.com/graphql", json=json_data, headers=headers) data = response.json() return data @@ -183,21 +234,21 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]: data = _call_graphql_api(query) found_discussion_out_of_time_range = False - edges = data['data']['repository']['discussions']['edges'] + edges = data["data"]["repository"]["discussions"]["edges"] if len(edges) == 0: break else: # keep the discussion whose author is not a member for edge in edges: # print the discussion title - discussion = edge['node'] + discussion = edge["node"] + discussion_updated_at = str2datetime(discussion["updatedAt"]) - discussion_updated_at = datetime.strptime(discussion['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") # check if the updatedAt is within the last 7 days - # if yes, add it to dicussion_numbers + # if yes, add it to discussion_numbers if discussion_updated_at > since: - if discussion['authorAssociation'] != 'MEMBER': - discussion_numbers.append(discussion['number']) + if discussion["authorAssociation"] != "MEMBER": + discussion_numbers.append(discussion["number"]) else: found_discussion_out_of_time_range = True @@ -205,54 +256,55 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]: break else: # update cursor - cursor = edges[-1]['cursor'] + cursor = edges[-1]["cursor"] - # get the dicussion comments and replies made by our member + # get the discussion comments and replies made by our member user_engagement_count = {} - for dicussion_number in discussion_numbers: + for discussion_number in discussion_numbers: cursor = None num_per_request = 10 while True: - query = _generate_comment_reply_count_for_discussion(dicussion_number, num_per_request, cursor) + query = _generate_comment_reply_count_for_discussion(discussion_number, num_per_request, cursor) data = _call_graphql_api(query) # get the comments - edges = data['data']['repository']['discussion']['comments']['edges'] + edges = data["data"]["repository"]["discussion"]["comments"]["edges"] # update the cursor if len(edges) == 0: break else: # update cursor for pagination - cursor = edges[-1]['cursor'] + cursor = edges[-1]["cursor"] for edge in edges: - comment = edge['node'] - if comment['authorAssociation'] == 'MEMBER': + comment = edge["node"] + if comment["authorAssociation"] == "MEMBER": # check if the updatedAt is within the last 7 days # if yes, add it to user_engagement_count - comment_updated_at = datetime.strptime(comment['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") + comment_updated_at = datetime.strptime(comment["updatedAt"], "%Y-%m-%dT%H:%M:%SZ") if comment_updated_at > since: - member_name = comment['author']['login'] + member_name = comment["author"]["login"] if member_name in user_engagement_count: user_engagement_count[member_name] += 1 else: user_engagement_count[member_name] = 1 # get the replies - reply_edges = comment['replies']['edges'] + reply_edges = comment["replies"]["edges"] if len(reply_edges) == 0: continue else: for reply_edge in reply_edges: - reply = reply_edge['node'] - if reply['authorAssociation'] == 'MEMBER': + reply = reply_edge["node"] + if reply["authorAssociation"] == "MEMBER": # check if the updatedAt is within the last 7 days - # if yes, add it to dicussion_numbers - reply_updated_at = datetime.strptime(reply['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") + # if yes, add it to discussion_numbers + + reply_updated_at = datetime.strptime(reply["updatedAt"], "%Y-%m-%dT%H:%M:%SZ") if reply_updated_at > since: - member_name = reply['author']['login'] + member_name = reply["author"]["login"] if member_name in user_engagement_count: user_engagement_count[member_name] += 1 else: @@ -260,7 +312,9 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]: return user_engagement_count -def generate_user_engagement_leaderboard_image(github_token: str, output_path: str) -> bool: +def generate_user_engagement_leaderboard_image( + github_token: str, org_name: str, repo_list: List[str], output_path: str +) -> bool: """ Generate the user engagement leaderboard image for stats within the last 7 days @@ -270,22 +324,31 @@ def generate_user_engagement_leaderboard_image(github_token: str, output_path: s """ # request to the Github API to get the users who have replied the most in the last 7 days - now = datetime.utcnow() - start_datetime = now - timedelta(days=7) - start_datetime_str = start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ") + start_datetime = get_utc_time_one_week_ago() + start_datetime_str = datetime2str(start_datetime) # get the issue/PR comments and discussion comment count - issue_pr_engagement_count = get_issue_pull_request_comments(github_token=github_token, since=start_datetime_str) - discussion_engagement_count = get_discussion_comments(github_token=github_token, since=start_datetime) total_engagement_count = {} - # update the total engagement count - total_engagement_count.update(issue_pr_engagement_count) - for name, count in discussion_engagement_count.items(): - if name in total_engagement_count: - total_engagement_count[name] += count - else: - total_engagement_count[name] = count + def _update_count(counter): + for name, count in counter.items(): + if name in total_engagement_count: + total_engagement_count[name] += count + else: + total_engagement_count[name] = count + + for repo_name in repo_list: + print(f"Fetching user engagement count for {repo_name}/{repo_name}") + issue_pr_engagement_count = get_issue_pull_request_comments( + github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str + ) + discussion_engagement_count = get_discussion_comments( + github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime + ) + + # update the total engagement count + _update_count(issue_pr_engagement_count) + _update_count(discussion_engagement_count) # prepare the data for plotting x = [] @@ -302,20 +365,17 @@ def generate_user_engagement_leaderboard_image(github_token: str, output_path: s x.append(count) y.append(name) - # use Shanghai time to display on the image - start_datetime_str = datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%dT%H:%M:%SZ") - # plot the leaderboard xlabel = f"Number of Comments made (since {start_datetime_str})" ylabel = "Member" - title = 'Active User Engagement Leaderboard' + title = "Active User Engagement Leaderboard" plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path) return True else: return False -def generate_contributor_leaderboard_image(github_token, output_path) -> bool: +def generate_contributor_leaderboard_image(github_token, org_name, repo_list, output_path) -> bool: """ Generate the contributor leaderboard image for stats within the last 7 days @@ -324,54 +384,81 @@ def generate_contributor_leaderboard_image(github_token, output_path) -> bool: output_path (str): the path to save the image """ # request to the Github API to get the users who have contributed in the last 7 days - URL = 'https://api.github.com/repos/hpcaitech/ColossalAI/stats/contributors' headers = { - 'Authorization': f'Bearer {github_token}', - 'Accept': 'application/vnd.github+json', - 'X-GitHub-Api-Version': '2022-11-28' + "Authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", } - while True: - response = requests.get(URL, headers=headers).json() + counter = Counter() + start_datetime = get_utc_time_one_week_ago() - if len(response) != 0: - # sometimes the Github API returns empty response for unknown reason - # request again if the response is empty - break + def _get_url(org_name, repo_name, page): + return f"https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed" - contributor_list = [] + def _iterate_by_page(org_name, repo_name): + page = 1 + stop = False - # get number of commits for each contributor - start_timestamp = None - for item in response: - num_commits_this_week = item['weeks'][-1]['c'] - name = item['author']['login'] - contributor = Contributor(name=name, num_commits_this_week=num_commits_this_week) - contributor_list.append(contributor) + while not stop: + print(f"Fetching pull request data for {org_name}/{repo_name} - page{page}") + url = _get_url(org_name, repo_name, page) - # update start_timestamp - start_timestamp = item['weeks'][-1]['w'] + while True: + response = requests.get(url, headers=headers).json() + + if isinstance(response, list): + # sometimes the Github API returns nothing + # request again if the response is not a list + break + print("Empty response, request again...") + + if len(response) == 0: + # if the response is empty, stop + stop = True + break + + # count the pull request and author from response + for pr_data in response: + merged_at = pr_data["merged_at"] + author = pr_data["user"]["login"] + + if merged_at is None: + continue + + merge_datetime = str2datetime(merged_at) + + if merge_datetime < start_datetime: + # if we found a pull request that is merged before the start_datetime + # we stop + stop = True + break + else: + # record the author1 + counter.record(author) + + # next page + page += 1 + + for repo_name in repo_list: + _iterate_by_page(org_name, repo_name) # convert unix timestamp to Beijing datetime - start_datetime = datetime.fromtimestamp(start_timestamp, tz=pytz.timezone('Asia/Shanghai')) - start_datetime_str = start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ") + bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone("Asia/Shanghai")) + bj_start_datetime_str = datetime2str(bj_start_datetime) - # sort by number of commits - contributor_list.sort(key=lambda x: x.num_commits_this_week, reverse=True) + contribution_list = counter.to_sorted_list() # remove contributors who has zero commits - contributor_list = [x for x in contributor_list if x.num_commits_this_week > 0] - - # prepare the data for plotting - x = [x.num_commits_this_week for x in contributor_list] - y = [x.name for x in contributor_list] + author_list = [x[0] for x in contribution_list] + num_commit_list = [x[1] for x in contribution_list] # plot - if len(x) > 0: - xlabel = f"Number of Commits (since {start_datetime_str})" + if len(author_list) > 0: + xlabel = f"Number of Pull Requests (since {bj_start_datetime_str})" ylabel = "Contributor" - title = 'Active Contributor Leaderboard' - plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path) + title = "Active Contributor Leaderboard" + plot_bar_chart(num_commit_list, author_list, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path) return True else: return False @@ -386,14 +473,14 @@ def upload_image_to_lark(lark_tenant_token: str, image_path: str) -> str: image_path (str): the path to the image to be uploaded """ url = "https://open.feishu.cn/open-apis/im/v1/images" - form = {'image_type': 'message', 'image': (open(image_path, 'rb'))} # 需要替换具体的path + form = {"image_type": "message", "image": (open(image_path, "rb"))} # 需要替换具体的path multi_form = MultipartEncoder(form) headers = { - 'Authorization': f'Bearer {lark_tenant_token}', ## 获取tenant_access_token, 需要替换为实际的token + "Authorization": f"Bearer {lark_tenant_token}", ## 获取tenant_access_token, 需要替换为实际的token } - headers['Content-Type'] = multi_form.content_type + headers["Content-Type"] = multi_form.content_type response = requests.request("POST", url, headers=headers, data=multi_form).json() - return response['data']['image_key'] + return response["data"]["image_key"] def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str: @@ -404,10 +491,10 @@ def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str: app_id (str): Lark app id app_secret (str): Lark app secret """ - url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal' - data = {'app_id': app_id, 'app_secret': app_secret} + url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" + data = {"app_id": app_id, "app_secret": app_secret} response = requests.post(url, json=data).json() - return response['tenant_access_token'] + return response["tenant_access_token"] def send_image_to_lark(image_key: str, webhook_url: str) -> None: @@ -434,31 +521,37 @@ def send_message_to_lark(message: str, webhook_url: str): requests.post(webhook_url, json=data) -if __name__ == '__main__': - GITHUB_TOKEN = os.environ['GITHUB_TOKEN'] - CONTRIBUTOR_IMAGE_PATH = 'contributor_leaderboard.png' - USER_ENGAGEMENT_IMAGE_PATH = 'engagement_leaderboard.png' +if __name__ == "__main__": + GITHUB_TOKEN = os.environ["GITHUB_TOKEN"] + CONTRIBUTOR_IMAGE_PATH = "contributor_leaderboard.png" + USER_ENGAGEMENT_IMAGE_PATH = "engagement_leaderboard.png" + ORG_NAME = "hpcaitech" + + # get all open source repositories + REPO_LIST = get_organization_repositories(GITHUB_TOKEN, ORG_NAME) # generate images - contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, CONTRIBUTOR_IMAGE_PATH) - engagement_success = generate_user_engagement_leaderboard_image(GITHUB_TOKEN, USER_ENGAGEMENT_IMAGE_PATH) + contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, CONTRIBUTOR_IMAGE_PATH) + engagement_success = generate_user_engagement_leaderboard_image( + GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH + ) # upload images - APP_ID = os.environ['LARK_APP_ID'] - APP_SECRET = os.environ['LARK_APP_SECRET'] + APP_ID = os.environ["LARK_APP_ID"] + APP_SECRET = os.environ["LARK_APP_SECRET"] LARK_TENANT_TOKEN = generate_lark_tenant_access_token(app_id=APP_ID, app_secret=APP_SECRET) contributor_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, CONTRIBUTOR_IMAGE_PATH) user_engagement_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, USER_ENGAGEMENT_IMAGE_PATH) # send message to lark - LARK_WEBHOOK_URL = os.environ['LARK_WEBHOOK_URL'] + LARK_WEBHOOK_URL = os.environ["LARK_WEBHOOK_URL"] message = """本周的社区榜单出炉啦! 1. 开发贡献者榜单 2. 用户互动榜单 注: -- 开发贡献者测评标准为:本周由公司成员提交的commit次数 -- 用户互动榜单测评标准为:本周由公司成员在非成员创建的issue/PR/discussion中回复的次数 +- 开发贡献者测评标准为:本周由公司成员与社区在所有开源仓库提交的Pull Request次数 +- 用户互动榜单测评标准为:本周由公司成员在非成员在所有开源仓库创建的issue/PR/discussion中回复的次数 """ send_message_to_lark(message, LARK_WEBHOOK_URL) @@ -467,7 +560,7 @@ if __name__ == '__main__': if contrib_success: send_image_to_lark(contributor_image_key, LARK_WEBHOOK_URL) else: - send_message_to_lark("本周没有成员贡献commit,无榜单图片生成。", LARK_WEBHOOK_URL) + send_message_to_lark("本周没有成员贡献PR,无榜单图片生成。", LARK_WEBHOOK_URL) # send user engagement image to lark if engagement_success: diff --git a/.github/workflows/scripts/generate_release_draft.py b/.github/workflows/scripts/generate_release_draft.py index dc592e4c977b46b0f0156a9da08832b7624776df..7374481005ef1ba8de1596771a1271cc28425ffa 100644 --- a/.github/workflows/scripts/generate_release_draft.py +++ b/.github/workflows/scripts/generate_release_draft.py @@ -7,27 +7,27 @@ import re import requests -COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits' -TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags' +COMMIT_API = "https://api.github.com/repos/hpcaitech/ColossalAI/commits" +TAGS_API = "https://api.github.com/repos/hpcaitech/ColossalAI/tags" def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--out', type=str, help='output path for the release draft', required=True) - parser.add_argument('--version', type=str, help='current version to release', required=True) + parser.add_argument("--out", type=str, help="output path for the release draft", required=True) + parser.add_argument("--version", type=str, help="current version to release", required=True) return parser.parse_args() def get_latest_tag_commit(headers=None): res = requests.get(url=TAGS_API, headers=headers) data = res.json() - commit_hash = data[0]['commit']['sha'] - version = data[0]['name'] + commit_hash = data[0]["commit"]["sha"] + version = data[0]["name"] return commit_hash, version def get_commit_info(commit_hash, headers=None): - api = f'{COMMIT_API}/{commit_hash}' + api = f"{COMMIT_API}/{commit_hash}" res = requests.get(url=api, headers=headers) return res.json() @@ -37,7 +37,7 @@ def get_all_commit_info(since, headers=None): results = [] while True: - api = f'{COMMIT_API}?since={since}&per_page=100&page={page}' + api = f"{COMMIT_API}?since={since}&per_page=100&page={page}" resp = requests.get(url=api, headers=headers) data = resp.json() @@ -53,21 +53,21 @@ def get_all_commit_info(since, headers=None): def collate_release_info(commit_info_list): results = dict() - pattern = pattern = r'\[.*\]' + pattern = pattern = r"\[.*\]" for commit_info in commit_info_list: - author = commit_info['commit']['author']['name'] + author = commit_info["commit"]["author"]["name"] try: - author_url = commit_info['author']['url'] + author_url = commit_info["author"]["url"] except: # author can be None author_url = None - msg = commit_info['commit']['message'] + msg = commit_info["commit"]["message"] match = re.search(pattern, msg) if match: - tag = match.group().lstrip('[').rstrip(']').capitalize() + tag = match.group().lstrip("[").rstrip("]").capitalize() if tag not in results: results[tag] = [] results[tag].append((msg, author, author_url)) @@ -89,42 +89,43 @@ def generate_release_post_markdown(current_version, last_version, release_info): for msg, author, author_url in v: # only keep the first line - msg = msg.split('\n')[0] + msg = msg.split("\n")[0] if author_url: - item = f'{msg} by [{author}]({author_url})\n' + item = f"{msg} by [{author}]({author_url})\n" else: - item = f'{msg} by {author}\n' - text.append(f'- {item}') + item = f"{msg} by {author}\n" + text.append(f"- {item}") - text.append('\n') + text.append("\n") # add full change log text.append( - f'**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}') + f"**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}" + ) return text -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() - token = os.environ['GITHUB_API_TOKEN'] - headers = {'Authorization': token} + token = os.environ["GITHUB_API_TOKEN"] + headers = {"Authorization": token} # get previous release tag last_release_commit, last_version = get_latest_tag_commit(headers) last_release_commit_info = get_commit_info(last_release_commit, headers=headers) - last_release_date = last_release_commit_info['commit']['author']['date'] + last_release_date = last_release_commit_info["commit"]["author"]["date"] # get the commits since last release commit_info = get_all_commit_info(since=last_release_date, headers=headers) - commit_info = commit_info[:-1] # remove the release commit + commit_info = commit_info[:-1] # remove the release commit # collate into markdown release_info = collate_release_info(commit_info) markdown_text = generate_release_post_markdown(args.version, last_version, release_info) # write into a file - with open(args.out, 'w') as f: + with open(args.out, "w") as f: for line in markdown_text: f.write(line) diff --git a/.github/workflows/scripts/send_message_to_lark.py b/.github/workflows/scripts/send_message_to_lark.py index a113327a786ed1310b6ef8c0ffc784bd6af2e344..bc005d93c3f5f5c8f431c726a850fee2bc3e9269 100644 --- a/.github/workflows/scripts/send_message_to_lark.py +++ b/.github/workflows/scripts/send_message_to_lark.py @@ -5,8 +5,8 @@ import requests def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('-m', '--message', type=str) - parser.add_argument('-u', '--url', type=str) + parser.add_argument("-m", "--message", type=str) + parser.add_argument("-u", "--url", type=str) return parser.parse_args() @@ -15,6 +15,6 @@ def send_message_to_lark(message, webhook_url): requests.post(webhook_url, json=data) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() send_message_to_lark(args.message, args.url) diff --git a/.gitignore b/.gitignore index bf74a753894fcbbf722812700fdf72a38c9e896c..81113fa99dd570b071e08a9c8cbb681d33626bc9 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,7 @@ colossalai/version.py # ignore coverage test file coverage.lcov coverage.xml + +# ignore testmon and coverage files +.coverage +.testmondata* diff --git a/.isort.cfg b/.isort.cfg index 090aa28e39f32da8c0161d5317c710b6c8781641..ccbf575fdbfacd185cf880431ad81462e0ae8fdf 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -3,3 +3,5 @@ line_length = 120 multi_line_output=3 include_trailing_comma = true ignore_comments = true +profile = black +honor_noqa = true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 725d266375ef42e69f6097f7d61924213d46b8ff..9871e1184462f2069071ea8db96495b20059d645 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,31 @@ repos: + - repo: https://github.com/PyCQA/autoflake + rev: v2.2.1 + hooks: + - id: autoflake + name: autoflake (python) + args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports'] + - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort name: sort all imports (python) - - repo: https://github.com/pre-commit/mirrors-yapf - rev: v0.32.0 + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 23.9.1 hooks: - - id: yapf - name: yapf formatter - args: ['--style=.style.yapf', '--parallel', '--in-place'] + - id: black + name: black formatter + args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310'] - repo: https://github.com/pre-commit/mirrors-clang-format rev: v13.0.1 hooks: - id: clang-format name: clang formatter + types_or: [c++, c] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index 05be0dc6a3a598ebd59ff00cec93ce8b809c78b5..0000000000000000000000000000000000000000 --- a/.style.yapf +++ /dev/null @@ -1,5 +0,0 @@ -[style] -based_on_style = google -spaces_before_comment = 4 -split_before_logical_operator = true -column_limit = 120 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 00abcf650158c571411f7b78a0ef4c982365b746..a3dc020f74e9529d08a313fc7242b0e552275db6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -30,6 +30,12 @@ pip install -e . ### Unit Tests We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest by `pip install pytest`. As some of the tests require initialization of the distributed backend, GPUs are needed to execute these tests. +To set up the environment for unit testing, first change your current directory to the root directory of your local ColossalAI repository, then run +```bash +pip install -r requirements/requirements-test.txt +``` +If you encounter an error telling "Could not find a version that satisfies the requirement fbgemm-gpu==0.2.0", please downgrade your python version to 3.8 or 3.9 and try again. + If you only want to run CPU tests, you can run ```bash @@ -138,4 +144,4 @@ You can now create a pull request on the GitHub webpage of your repository. The Do write clearly the description of your pull request and [link the pull request to your target issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue). This will automatically close the issue when the pull request is approved. -In case of code conflict, you should rebase your branch and resolve the conflicts manually. \ No newline at end of file +In case of code conflict, you should rebase your branch and resolve the conflicts manually. diff --git a/LICENSE b/LICENSE index c7a5bb16880e6a2a6364a092fe94dda29399b9e3..59d456c5b8a1a452ce5f478f6416d66572bc13c5 100644 --- a/LICENSE +++ b/LICENSE @@ -396,3 +396,84 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------- LICENSE FOR VLLM TEAM ---------------- + + from VLLM TEAM: + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://github.com/vllm-project/vllm/blob/main/LICENSE + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ---------------- LICENSE FOR LIGHTLLM TEAM ---------------- + + from LIGHTLLM TEAM: + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://github.com/ModelTC/lightllm/blob/main/LICENSE + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ---------------- LICENSE FOR AutoGPTQ ---------------- + + From AutoGPTQ: + + MIT License + + Copyright (c) 2023 潘其威(William) + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + ---------------- LICENSE FOR exllama ---------------- + + From exllama: + + MIT License + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/README.md b/README.md index 79f733122cb3a14ff27b72889caa710ac5e8a5f6..b2efb79104890d58128cce73bb7c334035af6e55 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ [![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest) [![CodeFactor](https://www.codefactor.io/repository/github/hpcaitech/colossalai/badge)](https://www.codefactor.io/repository/github/hpcaitech/colossalai) [![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech) - [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack) [![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png) @@ -25,14 +25,15 @@ ## Latest News +* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) +* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training) +* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth) +* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining) * [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) * [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana) * [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs) * [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) * [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02) -* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) -* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding) -* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the) ## Table of Contents
    @@ -41,6 +42,7 @@
  • Colossal-AI for Real World Applications
      +
    • Colossal-LLaMA-2: One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution
    • ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline
    • AIGC: Acceleration of Stable Diffusion
    • Biomedicine: Acceleration of AlphaFold Protein Structure
    • @@ -49,6 +51,7 @@
    • Parallel Training Demo
        +
      • LLaMA 1/2
      • GPT-3
      • GPT-2
      • BERT
      • @@ -124,15 +127,55 @@ distributed training and inference in a few lines. ## Colossal-AI in the Real World +### Colossal-LLaMA-2 + +- One half-day of training using a few hundred dollars yields similar results to mainstream large models, open-source and commercial-free domain-specific LLM solution. +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) +[[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) +[[model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) + +| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval | +| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: | +| | | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot | +| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | +| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | +| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | +| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | +| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | +| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | +| InternLM-7B | - | 1.6T | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| Qwen-7B | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| | | | | | | | | | +| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | +| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - | +| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - | +| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | +| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - | +| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - | +| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - | +| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - | +| | | | | | | | | | +| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | + ### ColossalChat -[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat): An open-source solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) [[blog]](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) [[demo]](https://chat.colossalai.org) +[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat): An open-source solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) +[[blog]](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) +[[demo]](https://www.youtube.com/watch?v=HcTiHzApHm0) +[[tutorial]](https://www.youtube.com/watch?v=-qFBZFmOJfg) + +

        + +

        + +- Up to 10 times faster for RLHF PPO Stage3 Training

        @@ -205,6 +248,23 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)

        (back to top)

        ## Parallel Training Demo +### LLaMA2 +

        + +

        + +- 70 billion parameter LLaMA2 model training accelerated by 195% +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) +[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training) + +### LLaMA1 +

        + +

        + +- 65-billion-parameter large model pretraining accelerated by 38% +[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama) +[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining) ### GPT-3

        @@ -352,6 +412,22 @@ If you want to install and enable CUDA kernel fusion (compulsory installation wh CUDA_EXT=1 pip install . ``` +For Users with CUDA 10.2, you can still build ColossalAI from source. However, you need to manually download the cub library and copy it to the corresponding directory. + +```bash +# clone the repository +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI + +# download the cub library +wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip +unzip 1.8.0.zip +cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ + +# install +CUDA_EXT=1 pip install . +``` +

        (back to top)

        ## Use Docker @@ -426,6 +502,7 @@ To cite this project, you can use the following BibTeX citation. } ``` -Colossal-AI has been accepted as official tutorial by top conferences [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc. +Colossal-AI has been accepted as official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), +[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) ,etc.

        (back to top)

        diff --git a/applications/Chat/.gitignore b/applications/Chat/.gitignore index 2b9b4f345d0fae7bc3872d3c723d2698d201b8b8..5fa068105e261616c1c50f46ffaefdc7aa629b3c 100644 --- a/applications/Chat/.gitignore +++ b/applications/Chat/.gitignore @@ -145,4 +145,4 @@ docs/.build # wandb log example/wandb/ -examples/awesome-chatgpt-prompts/ \ No newline at end of file +examples/awesome-chatgpt-prompts/ diff --git a/applications/Chat/README.md b/applications/Chat/README.md index 9ba831973b6c912a2d8435c382e4c741dd24b6e4..d5be04ab9f44bf34aa87be1e22d36d64227ddd39 100644 --- a/applications/Chat/README.md +++ b/applications/Chat/README.md @@ -4,7 +4,6 @@ ColossalChat - ## Table of Contents - [Table of Contents](#table-of-contents) @@ -34,7 +33,9 @@ - [Authors](#authors) - [Citations](#citations) - [Licenses](#licenses) + --- + ## What is ColossalChat and Coati ? [ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) is the project to implement LLM with RLHF, powered by the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) project. @@ -42,6 +43,7 @@ Coati stands for `ColossalAI Talking Intelligence`. It is the name for the module implemented in this project and is also the name of the large language model developed by the ColossalChat project. The Coati package provides a unified large language model framework that has implemented the following functions + - Supports comprehensive large-model training acceleration capabilities for ColossalAI, without requiring knowledge of complex distributed training algorithms - Supervised datasets collection - Supervised instructions fine-tuning @@ -56,29 +58,42 @@ The Coati package provides a unified large language model framework that has imp

        - Image source: https://openai.com/blog/chatgpt +Image source: https://openai.com/blog/chatgpt + **As Colossal-AI is undergoing some major updates, this project will be actively maintained to stay in line with the Colossal-AI project.** - More details can be found in the latest news. -* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) -* [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) + +- [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) +- [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) ## Online demo -You can experience the performance of Coati7B on this page. -[chat.colossalai.org](https://chat.colossalai.org/) + + +[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat): An open-source solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) +[[blog]](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) +[[demo]](https://www.youtube.com/watch?v=HcTiHzApHm0) +[[tutorial]](https://www.youtube.com/watch?v=-qFBZFmOJfg) -Due to resource constraints, we will only provide this service from 29th Mar 2023 to 5 April 2023. However, we have provided the inference code in the [inference](./inference/) folder. The WebUI will be open-sourced soon as well. +

        + +

        + +> DeepSpeedChat performance comes from its blog on 2023 April 12, ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs with the following command: `torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --num_collect_steps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32` -> Warning: Due to model and dataset size limitations, Coati is just a baby model, Coati7B may output incorrect information and lack the ability for multi-turn dialogue. There is still significant room for improvement. ## Install ### Install the environment -```shell +```bash conda create -n coati conda activate coati git clone https://github.com/hpcaitech/ColossalAI.git @@ -87,22 +102,20 @@ pip install . ``` ### Install the Transformers -Given Hugging Face hasn't officially supported the LLaMA models, We fork a branch of Transformers that can be compatible with our code -```shell -git clone https://github.com/hpcaitech/transformers -cd transformers -pip install . +```bash +pip install transformers==4.30.2 ``` ## How to use? ### Supervised datasets collection -we collected 104K bilingual datasets of Chinese and English, and you can find the datasets in this repo -[InstructionWild](https://github.com/XueFuzhao/InstructionWild) +We collected 104K bilingual datasets of Chinese and English, and you can find the datasets in this repo +[InstructionWild](https://github.com/XueFuzhao/InstructionWild) and in this [file](https://github.com/XueFuzhao/InstructionWild/blob/main/data/README.md). Here is how we collected the data +

        @@ -112,12 +125,28 @@ Here is how we collected the data Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model. You can run the `examples/train_sft.sh` to start a supervised instructs fine-tuning. +[[Stage1 tutorial video]](https://www.youtube.com/watch?v=-qFBZFmOJfg) + +**Note**: the supervised dataset follows the following format, + +```json +[ + { + "instruction": "Provide a list of the top 10 most popular mobile games in Asia", + "input": "", + "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", + "id": 0 + }, + ... +] +``` ### RLHF Training Stage2 - Training reward model Stage2 trains a reward model, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model You can run the `examples/train_rm.sh` to start a reward model training. +[[Stage2 tutorial video]](https://www.youtube.com/watch?v=gMx2CApKhuo) ### RLHF Training Stage3 - Training model with reinforcement learning by human feedback @@ -128,6 +157,39 @@ Stage3 uses reinforcement learning algorithm, which is the most complex part of

        You can run the `examples/train_prompts.sh` to start training PPO with human feedback. +[[Stage3 tutorial video]](https://www.youtube.com/watch?v=Z8wwSHxPL9g) + +**Note**: the required datasets follow the following format, + +- `pretrain dataset` + + ```json + [ + { + "instruction": "Provide a list of the top 10 most popular mobile games in Asia", + "input": "", + "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", + "id": 0 + }, + ... + ] + ``` + +- `prompt dataset` + + ```json + [ + { + "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", + "id": 0 + }, + { + "instruction": "Write a descriptive paragraph about a memorable vacation you went on", + "id": 1 + }, + ... + ] + ``` For more details, see [`examples/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples). @@ -135,9 +197,9 @@ For more details, see [`examples/`](https://github.com/hpcaitech/ColossalAI/tree We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models. -We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inference. You can -Online inference server scripts can help you deploy your own services. +We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inference. +Online inference server scripts can help you deploy your own services. For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference). ## Coati7B examples @@ -147,6 +209,7 @@ For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tre
        E-mail ![phd](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/Phd.png) +
        coding @@ -180,6 +243,7 @@ For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tre
        ### Open QA +
        Game ![Game](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/game.png) @@ -213,6 +277,7 @@ For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tre You can find more examples in this [repo](https://github.com/XueFuzhao/InstructionWild/blob/main/comparison.md). ### Limitation +
        Limitation for LLaMA-finetuned models - Both Alpaca and ColossalChat are based on LLaMA. It is hard to compensate for the missing knowledge in the pre-training stage. - Lack of counting ability: Cannot count the number of items in a list. @@ -236,7 +301,7 @@ You can find more examples in this [repo](https://github.com/XueFuzhao/Instructi We have integrated the Transformers save and load pipeline, allowing users to freely call Hugging Face's language models and save them in the HF format. -``` +```python from coati.models.llama import LlamaLM from coati.trainer import SFTTrainer @@ -245,20 +310,20 @@ tokenizer = AutoTokenizer.from_pretrained(args.pretrain) (model, optim) = strategy.prepare((model, optim)) trainer = SFTTrainer(model=model, - strategy=strategy, - optim=optim, - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - batch_size=args.batch_size, - max_epochs=args.max_epochs, - accumulation_steps = args.accumulation_steps -) + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + batch_size=args.batch_size, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps + ) trainer.fit() # this saves in pytorch format strategy.save_model(model, args.save_path, only_rank0=True) -# this saves in HF format. ColossalAI strategy with stage-3 doesn't support this method +# this saves in HF format strategy.save_pretrained(model, args.save_path, only_rank0=True, tokenizer=tokenizer) ``` @@ -269,12 +334,13 @@ strategy.save_pretrained(model, args.save_path, only_rank0=True, tokenizer=token Here are some examples that can allow you to train a 7B model on a single or multiple consumer-grade GPUs. If you only have a single 24G GPU, you can use the following script. `batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters to successfully train the model. -``` + +```bash +// [INFO]: MAX GPU MEMORY ALLOCATED: 19148.9345703125 MB torchrun --standalone --nproc_per_node=1 train_sft.py \ --pretrain "/path/to/LLaMa-7B/" \ --model 'llama' \ - --strategy naive \ - --log_interval 10 \ + --strategy ddp \ --save_path /path/to/Coati-7B \ --dataset /path/to/data.json \ --batch_size 1 \ @@ -287,12 +353,12 @@ torchrun --standalone --nproc_per_node=1 train_sft.py \ ``` `colossalai_gemini` strategy can enable a single 24G GPU to train the whole model without using LoRA if you have sufficient CPU memory. You can use the following script. -``` + +```bash torchrun --standalone --nproc_per_node=1 train_sft.py \ --pretrain "/path/to/LLaMa-7B/" \ --model 'llama' \ --strategy colossalai_gemini \ - --log_interval 10 \ --save_path /path/to/Coati-7B \ --dataset /path/to/data.json \ --batch_size 1 \ @@ -304,12 +370,12 @@ torchrun --standalone --nproc_per_node=1 train_sft.py \ ``` If you have 4x32 GB GPUs, you can even train the whole 7B model using our `colossalai_zero2_cpu` strategy! The script is given as follows. -``` + +```bash torchrun --standalone --nproc_per_node=4 train_sft.py \ --pretrain "/path/to/LLaMa-7B/" \ --model 'llama' \ --strategy colossalai_zero2_cpu \ - --log_interval 10 \ --save_path /path/to/Coati-7B \ --dataset /path/to/data.json \ --batch_size 1 \ @@ -319,8 +385,8 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \ --max_epochs 1 \ --grad_checkpoint ``` -
        +
        ## The Plan @@ -335,31 +401,33 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \ - [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain) ### Real-time progress -You will find our progress in github project broad -[Coati](https://github.com/orgs/hpcaitech/projects/17/views/1) +You will find our progress in github [project broad](https://github.com/orgs/hpcaitech/projects/17/views/1). ## Invitation to open-source contribution + Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models from the starting point of replicating ChatGPT! You may contact us or participate in the following ways: + 1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! 2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). 3. Join the Colossal-AI community on -[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), -and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. + [Slack](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack), + and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. 4. Send your official proposal to email contact@hpcaitech.com Thanks so much to all of our amazing contributors! ## Quick Preview + -- An open-source low cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org) +- An open-source low-cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org)

        @@ -386,18 +454,21 @@ Thanks so much to all of our amazing contributors! | Better Cases | 38 ⚔ **41** | **45** ⚔ 33 | | Win Rate | 48% ⚔ **52%** | **58%** ⚔ 42% | | Average Score | 7.06 ⚔ **7.13** | **7.31** ⚔ 6.82 | + - Our Coati-7B model performs better than Alpaca-7B when using GPT-4 to evaluate model performance. The Coati-7B model we evaluate is an old version we trained a few weeks ago and the new version is around the corner. ## Authors Coati is developed by ColossalAI Team: + - [Fazzie](https://fazzie-key.cool/about/index.html) - [FrankLeeeee](https://github.com/FrankLeeeee) - [BlueRum](https://github.com/ht-zhou) - [ver217](https://github.com/ver217) - [ofey404](https://github.com/ofey404) +- [Wenhao Chen](https://github.com/CWHer) -The Phd student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project. +The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project. - [Zangwei Zheng](https://github.com/zhengzangw) - [Xue Fuzhao](https://github.com/XueFuzhao) diff --git a/applications/Chat/benchmarks/README.md b/applications/Chat/benchmarks/README.md index bc8ad8ba98165ebfefd842e6deea27120bdb547e..c13f3485863b9dde3044d0ab0fe4f2061544030b 100644 --- a/applications/Chat/benchmarks/README.md +++ b/applications/Chat/benchmarks/README.md @@ -27,9 +27,12 @@ We also provide various training strategies: We only support `torchrun` to launch now. E.g. -```shell +```bash # run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size -torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --critic_model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0 +torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py \ + --model 125m --critic_model 125m --strategy ddp \ + --experience_batch_size 1 --train_batch_size 1 --lora_rank 0 # run Actor (OPT-1.3B) and Critic (OPT-350M) with lora_rank=4 on single-node 4-GPU -torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 1.3b --critic_model 350m --strategy colossalai_zero2 --lora_rank 4 +torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py \ + --model 1.3b --critic_model 350m --strategy colossalai_zero2 --lora_rank 4 ``` diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py index 7a47624f74d87f188994cf7ee59e5f2d1f2b0b8b..0d0e2a7d34f54642cb712835b60f823a8bff6c8e 100644 --- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py +++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py @@ -8,7 +8,7 @@ from coati.models.base import RewardModel from coati.models.opt import OPTActor, OPTCritic from coati.trainer import PPOTrainer from coati.trainer.callbacks import PerformanceEvaluator -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy +from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy from torch.optim import Adam from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -19,7 +19,7 @@ from colossalai.nn.optimizer import HybridAdam def get_model_numel(model: nn.Module, strategy: Strategy) -> int: numel = sum(p.numel() for p in model.parameters()) - if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init: + if isinstance(strategy, GeminiStrategy) and strategy.shard_init: numel *= dist.get_world_size() return numel @@ -27,7 +27,7 @@ def get_model_numel(model: nn.Module, strategy: Strategy) -> int: def preprocess_batch(samples) -> dict: input_ids = torch.stack(samples) attention_mask = torch.ones_like(input_ids, dtype=torch.long) - return {'input_ids': input_ids, 'attention_mask': attention_mask} + return {"input_ids": input_ids, "attention_mask": attention_mask} def print_rank_0(*args, **kwargs) -> None: @@ -39,32 +39,32 @@ def print_model_numel(model_dict: dict) -> None: B = 1024**3 M = 1024**2 K = 1024 - outputs = '' + outputs = "" for name, numel in model_dict.items(): - outputs += f'{name}: ' + outputs += f"{name}: " if numel >= B: - outputs += f'{numel / B:.2f} B\n' + outputs += f"{numel / B:.2f} B\n" elif numel >= M: - outputs += f'{numel / M:.2f} M\n' + outputs += f"{numel / M:.2f} M\n" elif numel >= K: - outputs += f'{numel / K:.2f} K\n' + outputs += f"{numel / K:.2f} K\n" else: - outputs += f'{numel}\n' + outputs += f"{numel}\n" print_rank_0(outputs) def get_gpt_config(model_name: str) -> OPTConfig: model_map = { - '125m': OPTConfig.from_pretrained('facebook/opt-125m'), - '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), - '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), - '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'), - '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'), - '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), - '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), - '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'), - '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), - '13b': OPTConfig.from_pretrained('facebook/opt-13b'), + "125m": OPTConfig.from_pretrained("facebook/opt-125m"), + "350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), + "700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), + "1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"), + "2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"), + "3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), + "5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), + "6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"), + "10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), + "13b": OPTConfig.from_pretrained("facebook/opt-13b"), } try: return model_map[model_name] @@ -73,20 +73,20 @@ def get_gpt_config(model_name: str) -> OPTConfig: def main(args): - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) - elif args.strategy == 'colossalai_gemini_cpu': - strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) - elif args.strategy == 'colossalai_zero2': - strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') - elif args.strategy == 'colossalai_zero2_cpu': - strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') - elif args.strategy == 'colossalai_zero1': - strategy = ColossalAIStrategy(stage=1, placement_policy='cuda') - elif args.strategy == 'colossalai_zero1_cpu': - strategy = ColossalAIStrategy(stage=1, placement_policy='cpu') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5) + elif args.strategy == "colossalai_gemini_cpu": + strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") + elif args.strategy == "colossalai_zero2_cpu": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") + elif args.strategy == "colossalai_zero1": + strategy = LowLevelZeroStrategy(stage=1, placement_policy="cuda") + elif args.strategy == "colossalai_zero1_cpu": + strategy = LowLevelZeroStrategy(stage=1, placement_policy="cpu") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') @@ -103,92 +103,106 @@ def main(args): if args.use_kernels: from coati.kernels import convert_to_xformer_model - actor, critic, initial_model, reward_model = map(convert_to_xformer_model, - (actor, critic, initial_model, reward_model)) + + actor, critic, initial_model, reward_model = map( + convert_to_xformer_model, (actor, critic, initial_model, reward_model) + ) actor_numel = get_model_numel(actor, strategy) critic_numel = get_model_numel(critic, strategy) initial_model_numel = get_model_numel(initial_model, strategy) reward_model_numel = get_model_numel(reward_model, strategy) - print_model_numel({ - 'Actor': actor_numel, - 'Critic': critic_numel, - 'Initial model': initial_model_numel, - 'Reward model': reward_model_numel - }) - performance_evaluator = PerformanceEvaluator(actor_numel, - critic_numel, - initial_model_numel, - reward_model_numel, - enable_grad_checkpoint=False, - ignore_episodes=1) - - if args.strategy.startswith('colossalai'): + print_model_numel( + { + "Actor": actor_numel, + "Critic": critic_numel, + "Initial model": initial_model_numel, + "Reward model": reward_model_numel, + } + ) + performance_evaluator = PerformanceEvaluator( + actor_numel, + critic_numel, + initial_model_numel, + reward_model_numel, + enable_grad_checkpoint=False, + ignore_episodes=1, + ) + + if args.strategy.startswith("colossalai"): actor_optim = HybridAdam(actor.parameters(), lr=5e-6) critic_optim = HybridAdam(critic.parameters(), lr=5e-6) else: actor_optim = Adam(actor.parameters(), lr=5e-6) critic_optim = Adam(critic.parameters(), lr=5e-6) - tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) - trainer = PPOTrainer(strategy, - actor, - critic, - reward_model, - initial_model, - actor_optim, - critic_optim, - ptx_coef=0, - max_epochs=args.max_epochs, - train_batch_size=args.train_batch_size, - offload_inference_models=args.offload_inference_models, - max_length=512, - do_sample=True, - temperature=1.0, - top_k=50, - use_cache=True, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - callbacks=[performance_evaluator]) - random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device()) - dataloader = DataLoader(random_prompts, - batch_size=args.experience_batch_size, - shuffle=True, - collate_fn=preprocess_batch) - - trainer.fit(dataloader, - None, - num_episodes=args.num_episodes, - max_timesteps=args.max_timesteps, - update_timesteps=args.update_timesteps) - - print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') - - -if __name__ == '__main__': + dataloader = DataLoader( + random_prompts, batch_size=args.experience_batch_size, shuffle=True, collate_fn=preprocess_batch + ) + + trainer = PPOTrainer( + strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + tokenizer=tokenizer, + ptx_coef=0, + train_batch_size=args.train_batch_size, + offload_inference_models=args.offload_inference_models, + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + use_cache=True, + callbacks=[performance_evaluator], + ) + + trainer.fit( + prompt_dataloader=dataloader, + pretrain_dataloader=None, + num_episodes=args.num_episodes, + num_update_steps=args.num_update_steps, + num_collect_steps=args.num_collect_steps, + ) + + print_rank_0(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB") + + +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model', default='125m') - parser.add_argument('--critic_model', default='125m') - parser.add_argument('--strategy', - choices=[ - 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2', - 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu' - ], - default='ddp') - parser.add_argument('--num_episodes', type=int, default=3) - parser.add_argument('--max_timesteps', type=int, default=8) - parser.add_argument('--update_timesteps', type=int, default=8) - parser.add_argument('--max_epochs', type=int, default=1) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0) - parser.add_argument('--cuda_mem_frac', type=float, default=1.0) - parser.add_argument('--offload_inference_models', action='store_true', default=False) - parser.add_argument('--use_kernels', action='store_true', default=False) + parser.add_argument("--model", default="125m") + parser.add_argument("--critic_model", default="125m") + parser.add_argument( + "--strategy", + choices=[ + "ddp", + "colossalai_gemini", + "colossalai_gemini_cpu", + "colossalai_zero2", + "colossalai_zero2_cpu", + "colossalai_zero1", + "colossalai_zero1_cpu", + ], + default="ddp", + ) + parser.add_argument("--num_episodes", type=int, default=3) + parser.add_argument("--num_collect_steps", type=int, default=8) + parser.add_argument("--num_update_steps", type=int, default=1) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--cuda_mem_frac", type=float, default=1.0) + parser.add_argument("--offload_inference_models", action="store_true", default=False) + parser.add_argument("--use_kernels", action="store_true", default=False) args = parser.parse_args() main(args) diff --git a/applications/Chat/benchmarks/ray/1mmt_dummy.py b/applications/Chat/benchmarks/ray/1mmt_dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..98ace3869450901f72948d4a5acbbc969c11a18e --- /dev/null +++ b/applications/Chat/benchmarks/ray/1mmt_dummy.py @@ -0,0 +1,192 @@ +import argparse +import os +import socket +from functools import partial + +import ray +import torch +from coati.quant import llama_load_quant, low_resource_init +from coati.ray.detached_trainer_ppo import DetachedPPOTrainer +from coati.ray.experience_maker_holder import ExperienceMakerHolder +from coati.ray.utils import ( + get_actor_from_args, + get_critic_from_args, + get_receivers_per_sender, + get_reward_model_from_args, + get_strategy_from_args, +) +from torch.utils.data import DataLoader +from transformers import AutoConfig, AutoTokenizer +from transformers.modeling_utils import no_init_weights + + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def get_local_ip(): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + + +def main(args): + master_addr = str(get_local_ip()) + # trainer_env_info + trainer_port = str(get_free_port()) + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] + + # maker_env_info + maker_port = str(get_free_port()) + env_info_maker = { + "local_rank": "0", + "rank": "0", + "world_size": "1", + "master_port": maker_port, + "master_addr": master_addr, + } + + # configure tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + + def model_fn(): + actor_cfg = AutoConfig.from_pretrained(args.pretrain) + critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain) + actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() + critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() + reward_model = ( + get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() + ) + if args.initial_model_quant_ckpt is not None and args.model == "llama": + # quantize initial model + with low_resource_init(), no_init_weights(): + initial_model = get_actor_from_args(args.model, config=actor_cfg) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) + else: + initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() + return actor, critic, reward_model, initial_model + + # configure Experience Maker + experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)], + strategy_fn=partial(get_strategy_from_args, args.maker_strategy), + model_fn=model_fn, + env_info=env_info_maker, + kl_coef=0.1, + debug=args.debug, + # sync_models_from_trainers=True, + # generation kwargs: + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + eval_performance=True, + use_cache=True, + ) + + def trainer_model_fn(): + actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda() + critic = ( + get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain)) + .half() + .cuda() + ) + return actor, critic + + # configure Trainer + trainer_refs = [ + DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=[ + f"maker{x}" for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True) + ], + strategy_fn=partial(get_strategy_from_args, args.trainer_strategy), + model_fn=trainer_model_fn, + env_info=env_info_trainer, + train_batch_size=args.train_batch_size, + buffer_limit=16, + eval_performance=True, + debug=args.debug, + ) + for i, env_info_trainer in enumerate(env_info_trainers) + ] + + dataset_size = args.experience_batch_size * 4 + + def data_gen_fn(): + input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device()) + attn_mask = torch.ones_like(input_ids) + return {"input_ids": input_ids, "attention_mask": attn_mask} + + def build_dataloader(size): + dataset = [data_gen_fn() for _ in range(size)] + dataloader = DataLoader(dataset, batch_size=args.experience_batch_size) + return dataloader + + # uncomment this function if sync_models_from_trainers is True + # ray.get([ + # trainer_ref.sync_models_to_remote_makers.remote() + # for trainer_ref in trainer_refs + # ]) + + wait_tasks = [] + + wait_tasks.append( + experience_holder_ref.workingloop.remote( + partial(build_dataloader, dataset_size), num_steps=args.experience_steps + ) + ) + + total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size) + for trainer_ref in trainer_refs: + wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs)) + + ray.get(wait_tasks) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num_trainers", type=int, default=1) + parser.add_argument( + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) + main(args) diff --git a/applications/Chat/benchmarks/ray/mmmt_dummy.py b/applications/Chat/benchmarks/ray/mmmt_dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..f8860f2979ee0fdf2f677d2e4a80d6350c7d80b8 --- /dev/null +++ b/applications/Chat/benchmarks/ray/mmmt_dummy.py @@ -0,0 +1,209 @@ +import argparse +import os +import socket +from functools import partial + +import ray +import torch +from coati.quant import llama_load_quant, low_resource_init +from coati.ray.detached_trainer_ppo import DetachedPPOTrainer +from coati.ray.experience_maker_holder import ExperienceMakerHolder +from coati.ray.utils import ( + get_actor_from_args, + get_critic_from_args, + get_receivers_per_sender, + get_reward_model_from_args, + get_strategy_from_args, +) +from torch.utils.data import DataLoader +from transformers import AutoConfig, AutoTokenizer +from transformers.modeling_utils import no_init_weights + + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def get_local_ip(): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + + +def main(args): + master_addr = str(get_local_ip()) + # trainer_env_info + trainer_port = str(get_free_port()) + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] + + # maker_env_info + maker_port = str(get_free_port()) + env_info_makers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_makers), + "master_port": maker_port, + "master_addr": master_addr, + } + for rank in range(args.num_makers) + ] + + # configure tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + + def model_fn(): + actor_cfg = AutoConfig.from_pretrained(args.pretrain) + critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain) + actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() + critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() + reward_model = ( + get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() + ) + if args.initial_model_quant_ckpt is not None and args.model == "llama": + # quantize initial model + with low_resource_init(), no_init_weights(): + initial_model = get_actor_from_args(args.model, config=actor_cfg) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) + else: + initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() + return actor, critic, reward_model, initial_model + + # configure Experience Maker + experience_holder_refs = [ + ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=[ + f"trainer{x}" + for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False) + ], + strategy_fn=partial(get_strategy_from_args, args.maker_strategy), + model_fn=model_fn, + env_info=env_info_maker, + kl_coef=0.1, + debug=args.debug, + # sync_models_from_trainers=True, + # generation kwargs: + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + eval_performance=True, + use_cache=True, + ) + for i, env_info_maker in enumerate(env_info_makers) + ] + + def trainer_model_fn(): + actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda() + critic = ( + get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain)) + .half() + .cuda() + ) + return actor, critic + + # configure Trainer + trainer_refs = [ + DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=[ + f"maker{x}" + for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True) + ], + strategy_fn=partial(get_strategy_from_args, args.trainer_strategy), + model_fn=trainer_model_fn, + env_info=env_info_trainer, + train_batch_size=args.train_batch_size, + buffer_limit=16, + eval_performance=True, + debug=args.debug, + ) + for i, env_info_trainer in enumerate(env_info_trainers) + ] + + dataset_size = args.experience_batch_size * 4 + + def data_gen_fn(): + input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device()) + attn_mask = torch.ones_like(input_ids) + return {"input_ids": input_ids, "attention_mask": attn_mask} + + def build_dataloader(size): + dataset = [data_gen_fn() for _ in range(size)] + dataloader = DataLoader(dataset, batch_size=args.experience_batch_size) + return dataloader + + # uncomment this function if sync_models_from_trainers is True + # ray.get([ + # trainer_ref.sync_models_to_remote_makers.remote() + # for trainer_ref in trainer_refs + # ]) + + wait_tasks = [] + + for experience_holder_ref in experience_holder_refs: + wait_tasks.append( + experience_holder_ref.workingloop.remote( + partial(build_dataloader, dataset_size), num_steps=args.experience_steps + ) + ) + + total_steps = ( + args.experience_batch_size + * args.experience_steps + * args.num_makers + // (args.num_trainers * args.train_batch_size) + ) + for trainer_ref in trainer_refs: + wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs)) + + ray.get(wait_tasks) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num_makers", type=int, default=1) + parser.add_argument("--num_trainers", type=int, default=1) + parser.add_argument( + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) + main(args) diff --git a/applications/Chat/coati/dataset/__init__.py b/applications/Chat/coati/dataset/__init__.py index f650668e90b0feddbf6b20f1e0fc084a61ac4fc5..599b5760977599aa6c9d5119c647ce277cb929d8 100644 --- a/applications/Chat/coati/dataset/__init__.py +++ b/applications/Chat/coati/dataset/__init__.py @@ -1,9 +1,13 @@ from .prompt_dataset import PromptDataset from .reward_dataset import HhRlhfDataset, RmStaticDataset -from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset +from .sft_dataset import SFTDataset, SupervisedDataset from .utils import is_rank_0 __all__ = [ - 'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset', - 'DataCollatorForSupervisedDataset', 'PromptDataset' + "RmStaticDataset", + "HhRlhfDataset", + "SFTDataset", + "SupervisedDataset", + "PromptDataset", + "is_rank_0", ] diff --git a/applications/Chat/coati/dataset/conversation.py b/applications/Chat/coati/dataset/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..f2180d96b0d344f4f45e19c1ac85aca3e9bb8691 --- /dev/null +++ b/applications/Chat/coati/dataset/conversation.py @@ -0,0 +1,89 @@ +# Copyright 2023 lm-sys@FastChat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from enum import Enum, auto +from typing import List + + +class SeparatorStyle(Enum): + ADD_EOS_TOKEN = auto() + + +@dataclasses.dataclass +class Conversation: + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN + sep: str = "" + + skip_next: bool = False + + def get_prompt(self): + if self.sep_style == SeparatorStyle.ADD_EOS_TOKEN: + ret = self.system + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ": " + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + ) + + def dict(self): + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + } + + +conv = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_EOS_TOKEN, + sep="", +) + +default_conversation = conv diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py index f8ab2346c4b79c31036c0f5904a3344d6729149e..17120e6064b5da4ed0c507f7db01b871c27b3cc4 100644 --- a/applications/Chat/coati/dataset/prompt_dataset.py +++ b/applications/Chat/coati/dataset/prompt_dataset.py @@ -1,51 +1,45 @@ -import copy -import random from collections import defaultdict -from dataclasses import dataclass, field -from typing import Callable, Dict, Sequence +from typing import Dict import torch -import torch.distributed as dist import transformers from torch.utils.data import Dataset -from tqdm import tqdm from colossalai.logging import get_dist_logger -from .utils import is_rank_0, jload - -logger = get_dist_logger() +from .utils import jload class PromptDataset(Dataset): """Dataset for supervised fine-tuning.""" - def __init__(self, - data_path: str, - tokenizer: transformers.PreTrainedTokenizer, - max_datasets_size: int = None, - max_length: int = 96): + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + max_datasets_size: int = None, + max_length: int = 96, + ): super(PromptDataset, self).__init__() self.keyed_prompt = defaultdict(list) - logger.info("Loading data...") + self.logger = get_dist_logger() + self.logger.info("Loading data...") list_data_dict = jload(data_path) - logger.info(f"Loaded {len(list_data_dict)} examples.") + self.logger.info(f"Loaded {len(list_data_dict)} examples.") if max_datasets_size is not None: - logger.info(f"Limiting dataset to {max_datasets_size} examples.") + self.logger.info(f"Limiting dataset to {max_datasets_size} examples.") list_data_dict = list_data_dict[:max_datasets_size] - for data_dict in list_data_dict: - token = tokenizer(data_dict["instruction"], - return_tensors='pt', - max_length=max_length, - padding='max_length', - truncation=True) - for k, tensor in token.items(): - self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind()) + instructions = [data_dict["instruction"] for data_dict in list_data_dict] + tokens = tokenizer( + instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True + ) + for k, tensor in tokens.items(): + self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind() def __len__(self): - return len(self.keyed_prompt) + return len(self.keyed_prompt["input_ids"]) def __getitem__(self, i) -> Dict[str, torch.Tensor]: return {k: v[i] for k, v in self.keyed_prompt.items()} diff --git a/applications/Chat/coati/dataset/reward_dataset.py b/applications/Chat/coati/dataset/reward_dataset.py index faa1c94d27286a7c69a62c7da699a3377fe14cd9..3afcd7b692380259a77eb2ec5938b2b5ea9b1c83 100644 --- a/applications/Chat/coati/dataset/reward_dataset.py +++ b/applications/Chat/coati/dataset/reward_dataset.py @@ -6,7 +6,7 @@ from tqdm import tqdm from .utils import is_rank_0 -# Dahaos/rm-static +# Dahoas/rm-static class RmStaticDataset(Dataset): """ Dataset for reward model @@ -20,44 +20,31 @@ class RmStaticDataset(Dataset): def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: super().__init__() - self.chosen = [] - self.reject = [] - if special_token is None: - self.end_token = tokenizer.eos_token - else: - self.end_token = special_token - for data in tqdm(dataset, disable=not is_rank_0()): - prompt = data['prompt'] - - chosen = prompt + data['chosen'] + self.end_token - chosen_token = tokenizer(chosen, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.chosen.append({ - "input_ids": chosen_token['input_ids'], - "attention_mask": chosen_token['attention_mask'] - }) - - reject = prompt + data['rejected'] + self.end_token - reject_token = tokenizer(reject, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.reject.append({ - "input_ids": reject_token['input_ids'], - "attention_mask": reject_token['attention_mask'] - }) + self.end_token = tokenizer.eos_token if special_token is None else special_token + + chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + chosen_token = tokenizer( + chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]} + + reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + reject_token = tokenizer( + reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]} def __len__(self): - length = len(self.chosen) + length = self.chosen["input_ids"].shape[0] return length def __getitem__(self, idx): - return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ - "input_ids"], self.reject[idx]["attention_mask"] + return ( + self.chosen["input_ids"][idx], + self.chosen["attention_mask"][idx], + self.reject["input_ids"][idx], + self.reject["attention_mask"][idx], + ) # Anthropic/hh-rlhf @@ -74,39 +61,28 @@ class HhRlhfDataset(Dataset): def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: super().__init__() - self.chosen = [] - self.reject = [] - if special_token is None: - self.end_token = tokenizer.eos_token - else: - self.end_token = special_token - for data in tqdm(dataset, disable=not is_rank_0()): - chosen = data['chosen'] + self.end_token - chosen_token = tokenizer(chosen, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.chosen.append({ - "input_ids": chosen_token['input_ids'], - "attention_mask": chosen_token['attention_mask'] - }) - - reject = data['rejected'] + self.end_token - reject_token = tokenizer(reject, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.reject.append({ - "input_ids": reject_token['input_ids'], - "attention_mask": reject_token['attention_mask'] - }) + self.end_token = tokenizer.eos_token if special_token is None else special_token + + chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + chosen_token = tokenizer( + chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]} + + reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + reject_token = tokenizer( + reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]} def __len__(self): - length = len(self.chosen) + length = self.chosen["input_ids"].shape[0] return length def __getitem__(self, idx): - return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ - "input_ids"], self.reject[idx]["attention_mask"] + return ( + self.chosen["input_ids"][idx], + self.chosen["attention_mask"][idx], + self.reject["input_ids"][idx], + self.reject["attention_mask"][idx], + ) diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py index 3e2453468bbc4d5be00e9c2964803299bf176004..c0e257f54a0782711faa2142d0e5f29a2c2a9568 100644 --- a/applications/Chat/coati/dataset/sft_dataset.py +++ b/applications/Chat/coati/dataset/sft_dataset.py @@ -13,15 +13,13 @@ # limitations under the License. import copy -import random -from dataclasses import dataclass, field -from typing import Callable, Dict, Sequence +from typing import Dict, Optional, Sequence, Tuple import torch -import torch.distributed as dist -import transformers +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from torch.utils.data import Dataset from tqdm import tqdm +from transformers import PreTrainedTokenizer from colossalai.logging import get_dist_logger @@ -31,16 +29,89 @@ logger = get_dist_logger() IGNORE_INDEX = -100 PROMPT_DICT = { - "prompt_input": - ("Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"), - "prompt_no_input": ("Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response:"), + "prompt_input": ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" + ), + "prompt_no_input": ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:" + ), } +def _preprocess( + sources: Sequence[str], + targets: Sequence[str], + tokenizer: PreTrainedTokenizer, + max_length: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Preprocess the data by tokenizing.""" + sequences = [s + t for s, t in zip(sources, targets)] + sequences_token = tokenizer( + sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + sources_token = tokenizer( + sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + + assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently" + labels = copy.deepcopy(sequences_token["input_ids"]) + for i in range(labels.shape[0]): + source_len = sources_token["attention_mask"][i].sum().item() + pad_len = max_length - sequences_token["attention_mask"][i].sum().item() + if tokenizer.padding_side == "right": + # |prompt|completion|eos|pad| + labels[i][:source_len] = IGNORE_INDEX + labels[i][-pad_len:] = IGNORE_INDEX + elif tokenizer.padding_side == "left": + # |pad|prompt|completion|eos| + labels[i][: pad_len + source_len] = IGNORE_INDEX + else: + raise RuntimeError() + + return sequences_token["input_ids"], labels, sequences_token["attention_mask"] + + +def _preprocess_chatglm( + sources: Sequence[str], + targets: Sequence[str], + tokenizer: PreTrainedTokenizer, + max_length: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preprocess the data by tokenizing. + None for attention mask, ChatGLM will calculate attention mask according to input ids + """ + + labels = [] + input_ids = [] + for source, target in zip(sources, targets): + source_id = tokenizer.encode(text=source, add_special_tokens=False) + target_id = tokenizer.encode(text=target, add_special_tokens=False) + input_id = tokenizer.build_inputs_with_special_tokens(source_id, target_id) + # truncate + sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id] + truncate_length = max(0, len(input_id) - max_length) + input_id = input_id[truncate_length:] + if truncate_length == len(source_id) + 1: + input_id = sp_token_list + input_id[1:] + elif truncate_length > len(source_id) + 1: + input_id = sp_token_list + input_id[2:] + + context_length = input_id.index(tokenizer.bos_token_id) + mask_position = context_length - 1 + label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :] + + pad_len = max_length - len(input_id) + input_id = input_id + [tokenizer.pad_token_id] * pad_len + input_ids.append(input_id) + labels.append(label + [IGNORE_INDEX] * pad_len) + return torch.tensor(input_ids), torch.tensor(labels), None + + class SFTDataset(Dataset): """ Dataset for sft model @@ -51,73 +122,45 @@ class SFTDataset(Dataset): max_length: max length of input """ - def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None: + def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None: super().__init__() self.input_ids = [] - for data in tqdm(dataset, disable=not is_rank_0()): - prompt = data['prompt'] + data['completion'] + tokenizer.eos_token - prompt_token = tokenizer(prompt, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") + sources = [data["prompt"] for data in dataset] + targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())] - self.input_ids.append(prompt_token['input_ids'][0]) - self.labels = copy.deepcopy(self.input_ids) + logger.info("Tokenizing inputs... This may take some time...") + if isinstance(tokenizer, ChatGLMTokenizer): + self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm( + sources, targets, tokenizer, max_length + ) + else: + self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) + + logger.info("Loaded dataset.") def __len__(self): - length = len(self.input_ids) + length = self.input_ids.shape[0] return length def __getitem__(self, idx): - return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) - - -def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict: - """Tokenize a list of strings.""" - tokenized_list = [ - tokenizer( - text, - return_tensors="pt", - padding="longest", - max_length=max_length, - truncation=True, - ) for text in strings - ] - input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] - input_ids_lens = labels_lens = [ - tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list - ] - return dict( - input_ids=input_ids, - labels=labels, - input_ids_lens=input_ids_lens, - labels_lens=labels_lens, - ) - - -def preprocess( - sources: Sequence[str], - targets: Sequence[str], - tokenizer: transformers.PreTrainedTokenizer, - max_length: int, -) -> Dict: - """Preprocess the data by tokenizing.""" - examples = [s + t for s, t in zip(sources, targets)] - examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)] - input_ids = examples_tokenized["input_ids"] - labels = copy.deepcopy(input_ids) - for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): - label[:source_len] = IGNORE_INDEX - return dict(input_ids=input_ids, labels=labels) + if self.attention_mask is not None: + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) + else: + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) class SupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" - def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512): - super(SupervisedDataset, self).__init__() + def __init__( + self, + data_path: str, + tokenizer: PreTrainedTokenizer, + max_datasets_size: Optional[int] = None, + max_length: int = 512, + ): + super().__init__() logger.info("Loading data...") list_data_dict = jload(data_path) logger.info(f"Loaded {len(list_data_dict)} examples.") @@ -129,38 +172,27 @@ class SupervisedDataset(Dataset): logger.info("Formatting inputs...") prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] sources = [ - prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) + prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example) for example in list_data_dict ] - targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] + targets = [example["output"] + tokenizer.eos_token for example in list_data_dict] logger.info("Tokenizing inputs... This may take some time...") - data_dict = preprocess(sources, targets, tokenizer, max_length) + if isinstance(tokenizer, ChatGLMTokenizer): + self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm( + sources, targets, tokenizer, max_length + ) + else: + self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) - self.input_ids = data_dict["input_ids"] - self.labels = data_dict["labels"] + logger.info("Loaded dataset.") def __len__(self): - return len(self.input_ids) - - def __getitem__(self, i) -> Dict[str, torch.Tensor]: - return dict(input_ids=self.input_ids[i], labels=self.labels[i]) - - -@dataclass -class DataCollatorForSupervisedDataset(object): - """Collate examples for supervised fine-tuning.""" - - tokenizer: transformers.PreTrainedTokenizer + length = self.input_ids.shape[0] + return length - def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: - input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) - input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, - batch_first=True, - padding_value=self.tokenizer.pad_token_id) - labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) - return dict( - input_ids=input_ids, - labels=labels, - attention_mask=input_ids.ne(self.tokenizer.pad_token_id), - ) + def __getitem__(self, idx): + if self.attention_mask is not None: + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) + else: + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) diff --git a/applications/Chat/coati/experience_buffer/__init__.py b/applications/Chat/coati/experience_buffer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2a48d0a3b205a771ab6a3afddfd92e9324ba844 --- /dev/null +++ b/applications/Chat/coati/experience_buffer/__init__.py @@ -0,0 +1,4 @@ +from .base import ExperienceBuffer +from .naive import NaiveExperienceBuffer + +__all__ = ["ExperienceBuffer", "NaiveExperienceBuffer"] diff --git a/applications/Chat/coati/experience_buffer/base.py b/applications/Chat/coati/experience_buffer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7047785308f31f5835c7d76bc7d5d04265523e13 --- /dev/null +++ b/applications/Chat/coati/experience_buffer/base.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from typing import Any + +from coati.experience_maker.base import Experience + + +class ExperienceBuffer(ABC): + """Experience buffer base class. It stores experience. + + Args: + sample_batch_size (int): Batch size when sampling. + limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. + """ + + def __init__(self, sample_batch_size: int, limit: int = 0) -> None: + super().__init__() + self.sample_batch_size = sample_batch_size + # limit <= 0 means unlimited + self.limit = limit + + @abstractmethod + def append(self, experience: Experience) -> None: + pass + + @abstractmethod + def clear(self) -> None: + pass + + @abstractmethod + def sample(self) -> Experience: + pass + + @abstractmethod + def __len__(self) -> int: + pass + + @abstractmethod + def __getitem__(self, idx: int) -> Any: + pass + + @abstractmethod + def collate_fn(self, batch: Any) -> Experience: + pass diff --git a/applications/Chat/coati/experience_buffer/naive.py b/applications/Chat/coati/experience_buffer/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..d47b67dbe7131df8ff4ba5202cfd8dbcf70eb1ad --- /dev/null +++ b/applications/Chat/coati/experience_buffer/naive.py @@ -0,0 +1,60 @@ +import random +import warnings +from typing import List + +import torch +from coati.experience_maker.base import Experience + +from .base import ExperienceBuffer +from .utils import BufferItem, make_experience_batch, split_experience_batch + + +class NaiveExperienceBuffer(ExperienceBuffer): + """Naive experience buffer class. It stores experience. + + Args: + sample_batch_size (int): Batch size when sampling. + limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. + cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True. + """ + + def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None: + super().__init__(sample_batch_size, limit) + self.cpu_offload = cpu_offload + self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}") + # TODO(ver217): add prefetch + self.items: List[BufferItem] = [] + + @torch.no_grad() + def append(self, experience: Experience) -> None: + if self.cpu_offload: + experience.to_device(torch.device("cpu")) + items = split_experience_batch(experience) + self.items.extend(items) + + if self.limit > 0: + samples_to_remove = len(self.items) - self.limit + if samples_to_remove > 0: + warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.") + self.items = self.items[samples_to_remove:] + + def clear(self) -> None: + self.items.clear() + + @torch.no_grad() + def sample(self) -> Experience: + items = random.sample(self.items, self.sample_batch_size) + experience = make_experience_batch(items) + if self.cpu_offload: + experience.to_device(self.target_device) + return experience + + def __len__(self) -> int: + return len(self.items) + + def __getitem__(self, idx: int) -> BufferItem: + return self.items[idx] + + def collate_fn(self, batch) -> Experience: + experience = make_experience_batch(batch) + return experience diff --git a/applications/Chat/coati/experience_buffer/utils.py b/applications/Chat/coati/experience_buffer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..baedbebd184f7ee5dcf335ded07ac6b50b3cec63 --- /dev/null +++ b/applications/Chat/coati/experience_buffer/utils.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass +from typing import List, Optional + +import torch +import torch.nn.functional as F +from coati.experience_maker.base import Experience + + +@dataclass +class BufferItem: + """BufferItem is an item of experience data. + + Shapes of each tensor: + sequences: (S) + action_log_probs: (A) + values: (1) + reward: (1) + advantages: (1) + attention_mask: (S) + action_mask: (A) + + "A" is the number of actions. + """ + + sequences: torch.Tensor + action_log_probs: torch.Tensor + values: torch.Tensor + reward: torch.Tensor + advantages: torch.Tensor + attention_mask: Optional[torch.LongTensor] + action_mask: Optional[torch.BoolTensor] + + +def split_experience_batch(experience: Experience) -> List[BufferItem]: + batch_size = experience.sequences.size(0) + batch_kwargs = [{} for _ in range(batch_size)] + keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask") + for key in keys: + value = getattr(experience, key) + if isinstance(value, torch.Tensor): + vals = torch.unbind(value) + else: + # None + vals = [value for _ in range(batch_size)] + assert batch_size == len(vals) + for i, v in enumerate(vals): + batch_kwargs[i][key] = v + items = [BufferItem(**kwargs) for kwargs in batch_kwargs] + return items + + +def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor: + assert side in ("left", "right") + max_len = max(seq.size(0) for seq in sequences) + padded_sequences = [] + for seq in sequences: + pad_len = max_len - seq.size(0) + padding = (pad_len, 0) if side == "left" else (0, pad_len) + padded_sequences.append(F.pad(seq, padding)) + return torch.stack(padded_sequences, dim=0) + + +def make_experience_batch(items: List[BufferItem]) -> Experience: + kwargs = {} + to_pad_keys = set(("action_log_probs", "action_mask")) + keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask") + for key in keys: + vals = [getattr(item, key) for item in items] + if key in to_pad_keys: + batch_data = _zero_pad_sequences(vals) + else: + batch_data = torch.stack(vals, dim=0) + kwargs[key] = batch_data + return Experience(**kwargs) diff --git a/applications/Chat/coati/experience_maker/__init__.py b/applications/Chat/coati/experience_maker/__init__.py index 39ca7576b22761807b5804cc15d82e15fdde42cd..06452292e77c4a6e32ec5ed2a7c996e2b7fa2b64 100644 --- a/applications/Chat/coati/experience_maker/__init__.py +++ b/applications/Chat/coati/experience_maker/__init__.py @@ -1,4 +1,4 @@ from .base import Experience, ExperienceMaker from .naive import NaiveExperienceMaker -__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker'] +__all__ = ["Experience", "ExperienceMaker", "NaiveExperienceMaker"] diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py index ff75852576c848625d8786298a46f59e6a395598..0731f6e0f97f3cb6efb7a25621763ce7350d6ab6 100644 --- a/applications/Chat/coati/experience_maker/base.py +++ b/applications/Chat/coati/experience_maker/base.py @@ -3,14 +3,13 @@ from dataclasses import dataclass from typing import Optional import torch -import torch.nn as nn -from coati.models.base import Actor +from coati.models.base import Actor, Critic, RewardModel @dataclass class Experience: """Experience is a batch of data. - These data should have the the sequence length and number of actions. + These data should have the sequence length and number of actions. Left padding for sequences is applied. Shapes of each tensor: @@ -24,6 +23,7 @@ class Experience: "A" is the number of actions. """ + sequences: torch.Tensor action_log_probs: torch.Tensor values: torch.Tensor @@ -58,20 +58,13 @@ class Experience: class ExperienceMaker(ABC): - - def __init__(self, - actor: Actor, - critic: nn.Module, - reward_model: nn.Module, - initial_model: Actor, - kl_coef: float = 0.1) -> None: + def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None: super().__init__() self.actor = actor self.critic = critic self.reward_model = reward_model self.initial_model = initial_model - self.kl_coef = kl_coef @abstractmethod - def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: + def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience: pass diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py index 94546eeb28e787df620d15ec631b1776dc6bf282..941e1994b148ffa45b57287ee560e138f0d90923 100644 --- a/applications/Chat/coati/experience_maker/naive.py +++ b/applications/Chat/coati/experience_maker/naive.py @@ -1,5 +1,9 @@ import torch -from coati.models.utils import compute_reward, normalize +import torch.nn.functional as F +from coati.models.base import Actor, Critic, RewardModel +from coati.models.generation import generate +from coati.models.utils import calc_action_log_probs, compute_reward +from transformers import PreTrainedTokenizer from .base import Experience, ExperienceMaker @@ -9,6 +13,19 @@ class NaiveExperienceMaker(ExperienceMaker): Naive experience maker. """ + def __init__( + self, + actor: Actor, + critic: Critic, + reward_model: RewardModel, + initial_model: Actor, + tokenizer: PreTrainedTokenizer, + kl_coef: float = 0.1, + ) -> None: + super().__init__(actor, critic, reward_model, initial_model) + self.tokenizer = tokenizer + self.kl_coef = kl_coef + @torch.no_grad() def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: self.actor.eval() @@ -16,14 +33,33 @@ class NaiveExperienceMaker(ExperienceMaker): self.initial_model.eval() self.reward_model.eval() - sequences, attention_mask, action_mask = self.actor.generate(input_ids, - return_action_mask=True, - **generate_kwargs) + # generate sequences + sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs) + + # calculate auxiliary tensors + attention_mask = None + pad_token_id = self.tokenizer.pad_token_id + if pad_token_id is not None: + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + + input_len = input_ids.size(1) + eos_token_id = self.tokenizer.eos_token_id + if eos_token_id is None: + action_mask = torch.ones_like(sequences, dtype=torch.bool) + else: + # left padding may be applied, only mask action + action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask[:, :input_len] = False + action_mask = action_mask[:, 1:] + action_mask = action_mask[:, -(sequences.size(1) - input_len) :] num_actions = action_mask.size(1) - action_log_probs = self.actor(sequences, num_actions, attention_mask) - base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask) - value = self.critic(sequences, action_mask, attention_mask) + actor_output = self.actor(sequences, attention_mask)["logits"] + action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions) + base_model_output = self.initial_model(sequences, attention_mask)["logits"] + base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions) + value = self.critic(sequences, attention_mask) r = self.reward_model(sequences, attention_mask) reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask) diff --git a/applications/Chat/coati/kernels/__init__.py b/applications/Chat/coati/kernels/__init__.py index 230eedf7ecba9531dc53d9fafde34a15af806cb3..96d40c7c4709ef7ca56c3040b8d65121f4e4a85d 100644 --- a/applications/Chat/coati/kernels/__init__.py +++ b/applications/Chat/coati/kernels/__init__.py @@ -1,6 +1,6 @@ from .wrapper import convert_to_xformer_model, recover_from_xformer_model __all__ = [ - 'convert_to_xformer_model', - 'recover_from_xformer_model', + "convert_to_xformer_model", + "recover_from_xformer_model", ] diff --git a/applications/Chat/coati/kernels/opt_attn.py b/applications/Chat/coati/kernels/opt_attn.py index c10f341e94a3bd32511f6da0c50dbf075136935d..d1eb139187f392071d6cd15b7c3bc69bdd02a3bb 100644 --- a/applications/Chat/coati/kernels/opt_attn.py +++ b/applications/Chat/coati/kernels/opt_attn.py @@ -21,11 +21,12 @@ class XOPTAttention(OPTAttention): output_attentions: bool = False, ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]: if not self.training: - return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, - output_attentions) + return super().forward( + hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions + ) """Input shape: Batch x Time x Channel""" - assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask' - assert not output_attentions, 'Xformers attention does not support output_attentions' + assert layer_head_mask is None, "Xformers attention does not support layer_head_mask" + assert not output_attentions, "Xformers attention does not support output_attentions" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder @@ -69,15 +70,17 @@ class XOPTAttention(OPTAttention): key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = xops.memory_efficient_attention(query_states, - key_states, - value_states, - attn_bias=xops.LowerTriangularMask(), - p=self.dropout if self.training else 0.0, - scale=self.scaling) + attn_output = xops.memory_efficient_attention( + query_states, + key_states, + value_states, + attn_bias=xops.LowerTriangularMask(), + p=self.dropout if self.training else 0.0, + scale=self.scaling, + ) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. + # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) diff --git a/applications/Chat/coati/models/__init__.py b/applications/Chat/coati/models/__init__.py index 709bc5ac0948ec13ca251b77a944b96e6543e97a..ad4a525b4af29adbf678b492d596c6410ea9e964 100644 --- a/applications/Chat/coati/models/__init__.py +++ b/applications/Chat/coati/models/__init__.py @@ -1,8 +1,15 @@ from .base import Actor, Critic, RewardModel from .lora import LoRAModule, convert_to_lora_module -from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss +from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss __all__ = [ - 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss', - 'LoRAModule', 'convert_to_lora_module' + "Actor", + "Critic", + "RewardModel", + "PolicyLoss", + "ValueLoss", + "LogSigLoss", + "LogExpLoss", + "LoRAModule", + "convert_to_lora_module", ] diff --git a/applications/Chat/coati/models/base/__init__.py b/applications/Chat/coati/models/base/__init__.py index fe4152f2b760bf38924f521ee1d70a89074c8004..5c9905bb2224bb265324dfe37a30a637838b62f4 100644 --- a/applications/Chat/coati/models/base/__init__.py +++ b/applications/Chat/coati/models/base/__init__.py @@ -1,3 +1,5 @@ +from typing import Union + import torch.nn as nn from .actor import Actor @@ -5,10 +7,10 @@ from .critic import Critic from .reward_model import RewardModel -def get_base_model(model: nn.Module) -> nn.Module: +def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module: """Get the base model of our wrapper classes. - For Actor, it's base model is ``actor.model`` and it's usually a ``transformers.PreTrainedModel``. - For Critic and RewardModel, it's base model is itself. + For Actor, Critic and RewardModel, return ``model.model``, + it's usually a ``transformers.PreTrainedModel``. Args: model (nn.Module): model to get base model from @@ -16,9 +18,10 @@ def get_base_model(model: nn.Module) -> nn.Module: Returns: nn.Module: the base model """ - if isinstance(model, Actor): - return model.get_base_model() - return model + assert isinstance( + model, (Actor, Critic, RewardModel) + ), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first." + return model.model -__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model'] +__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"] diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py index 71fbf7bbae7d90875a202b5e4e56e431efc1468e..8b2b81ed071ce18c0a2a0cf060fa22badb164885 100644 --- a/applications/Chat/coati/models/base/actor.py +++ b/applications/Chat/coati/models/base/actor.py @@ -1,12 +1,9 @@ -from typing import Optional, Tuple, Union +from typing import Optional import torch import torch.nn as nn -import torch.nn.functional as F -from ..generation import generate from ..lora import LoRAModule -from ..utils import log_probs_from_logits class Actor(LoRAModule): @@ -19,47 +16,18 @@ class Actor(LoRAModule): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: + def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none") -> None: super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) self.model = model self.convert_to_lora() - @torch.no_grad() - def generate( + def forward( self, - input_ids: torch.Tensor, - return_action_mask: bool = True, - **kwargs - ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: - sequences = generate(self.model, input_ids, **kwargs) - attention_mask = None - pad_token_id = kwargs.get('pad_token_id', None) - if pad_token_id is not None: - attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) - if not return_action_mask: - return sequences, attention_mask, None - input_len = input_ids.size(1) - eos_token_id = kwargs.get('eos_token_id', None) - if eos_token_id is None: - action_mask = torch.ones_like(sequences, dtype=torch.bool) - else: - # left padding may be applied, only mask action - action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 - action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input - action_mask[:, :input_len] = False - action_mask = action_mask[:, 1:] - return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] - - def forward(self, - sequences: torch.LongTensor, - num_actions: int, - attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - """Returns action log probs - """ - output = self.model(sequences, attention_mask=attention_mask) - logits = output['logits'] - log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) - return log_probs[:, -num_actions:] - - def get_base_model(self): - return self.model + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + **model_kwargs, + ) -> torch.Tensor: + """Returns model output.""" + output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs) + return output + diff --git a/applications/Chat/coati/models/base/critic.py b/applications/Chat/coati/models/base/critic.py index e68a743a7762094c255df002565d61bdc611479f..8672365f5783240a63244fae19da4ab0d9ba20b9 100644 --- a/applications/Chat/coati/models/base/critic.py +++ b/applications/Chat/coati/models/base/critic.py @@ -1,10 +1,7 @@ -from typing import Optional - import torch import torch.nn as nn from ..lora import LoRAModule -from ..utils import masked_mean class Critic(LoRAModule): @@ -19,36 +16,19 @@ class Critic(LoRAModule): """ def __init__( - self, - model: nn.Module, - value_head: nn.Module, - lora_rank: int = 0, - lora_train_bias: str = 'none', - use_action_mask: bool = False, + self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none" ) -> None: - super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) self.model = model self.value_head = value_head - self.use_action_mask = use_action_mask self.convert_to_lora() - def forward(self, - sequences: torch.LongTensor, - action_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor: outputs = self.model(sequences, attention_mask=attention_mask) - last_hidden_states = outputs['last_hidden_state'] - - values = self.value_head(last_hidden_states).squeeze(-1) - - if action_mask is not None and self.use_action_mask: - num_actions = action_mask.size(1) - prompt_mask = attention_mask[:, :-num_actions] - values = values[:, :-num_actions] - value = masked_mean(values, prompt_mask, dim=1) - return value - - values = values[:, :-1] - value = values.mean(dim=1) - return value + last_hidden_states = outputs["last_hidden_state"] + sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[ + 0 + ] + sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths] + values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, ) + return values diff --git a/applications/Chat/coati/models/base/reward_model.py b/applications/Chat/coati/models/base/reward_model.py index ce8c0a1d35687da2cb61b55c59be6c91c0f9c394..e9545d1cddafae62d9999b9262bd969a56c8cf31 100644 --- a/applications/Chat/coati/models/base/reward_model.py +++ b/applications/Chat/coati/models/base/reward_model.py @@ -17,11 +17,13 @@ class RewardModel(LoRAModule): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - model: nn.Module, - value_head: Optional[nn.Module] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + model: nn.Module, + value_head: Optional[nn.Module] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) self.model = model self.convert_to_lora() @@ -33,9 +35,12 @@ class RewardModel(LoRAModule): else: self.value_head = nn.Linear(model.config.n_embd, 1) - def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor: outputs = self.model(sequences, attention_mask=attention_mask) - last_hidden_states = outputs['last_hidden_state'] - values = self.value_head(last_hidden_states)[:, :-1] - value = values.mean(dim=1).squeeze(1) # ensure shape is (B) - return value + last_hidden_states = outputs["last_hidden_state"] + sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[ + 0 + ] + sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths] + values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, ) + return values diff --git a/applications/Chat/coati/models/bloom/__init__.py b/applications/Chat/coati/models/bloom/__init__.py index d0e7f7b1ef94e2a8434fa36efe6b5a2e4586f5b8..7af199a67d3b282d48b325795fdc8212e5e26381 100644 --- a/applications/Chat/coati/models/bloom/__init__.py +++ b/applications/Chat/coati/models/bloom/__init__.py @@ -2,4 +2,4 @@ from .bloom_actor import BLOOMActor from .bloom_critic import BLOOMCritic from .bloom_rm import BLOOMRM -__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM'] +__all__ = ["BLOOMActor", "BLOOMCritic", "BLOOMRM"] diff --git a/applications/Chat/coati/models/bloom/bloom_actor.py b/applications/Chat/coati/models/bloom/bloom_actor.py index d7577f0964934955992f8d952c4e209a68a58829..73855a2245e7e2c4dfbb42c74be1e14a152a6262 100644 --- a/applications/Chat/coati/models/bloom/bloom_actor.py +++ b/applications/Chat/coati/models/bloom/bloom_actor.py @@ -1,7 +1,6 @@ from typing import Optional -import torch -from transformers import BloomConfig, BloomForCausalLM, BloomModel +from transformers import BloomConfig, BloomForCausalLM from ..base import Actor @@ -18,12 +17,14 @@ class BLOOMActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = BloomForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/bloom/bloom_critic.py b/applications/Chat/coati/models/bloom/bloom_critic.py index a32fb2e102f9b05bb5c62d97f215343ce466b840..b2d838f7ffc50bbac62c2950ad4d0c701beeb491 100644 --- a/applications/Chat/coati/models/bloom/bloom_critic.py +++ b/applications/Chat/coati/models/bloom/bloom_critic.py @@ -1,8 +1,7 @@ from typing import Optional -import torch import torch.nn as nn -from transformers import BloomConfig, BloomForCausalLM, BloomModel +from transformers import BloomConfig, BloomModel from ..base import Critic @@ -14,25 +13,24 @@ class BLOOMCritic(Critic): Args: pretrained (str): Pretrained model name or path. config (BloomConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): LoRA rank. lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = BloomModel.from_pretrained(pretrained) elif config is not None: model = BloomModel(config) else: model = BloomModel(BloomConfig()) - if checkpoint: - model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/bloom/bloom_rm.py b/applications/Chat/coati/models/bloom/bloom_rm.py index 22cfab441abb6666812a5875e02da77b28357e94..c09457ddc8c73190a9a9bef4a903588520c881ec 100644 --- a/applications/Chat/coati/models/bloom/bloom_rm.py +++ b/applications/Chat/coati/models/bloom/bloom_rm.py @@ -1,7 +1,7 @@ from typing import Optional import torch.nn as nn -from transformers import BloomConfig, BloomForCausalLM, BloomModel +from transformers import BloomConfig, BloomModel from ..base import RewardModel @@ -13,25 +13,24 @@ class BLOOMRM(RewardModel): Args: pretrained (str): Pretrained model name or path. config (BloomConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): LoRA rank. lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = BloomModel.from_pretrained(pretrained) elif config is not None: model = BloomModel(config) else: model = BloomModel(BloomConfig()) - if checkpoint: - model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/chatglm/__init__.py b/applications/Chat/coati/models/chatglm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5956f5a8e91b8095de6b32edd4005051f73dc7bc --- /dev/null +++ b/applications/Chat/coati/models/chatglm/__init__.py @@ -0,0 +1,3 @@ +from .chatglm_actor import ChatGLMActor + +__all__ = ["ChatGLMActor"] diff --git a/applications/Chat/coati/models/chatglm/chatglm_actor.py b/applications/Chat/coati/models/chatglm/chatglm_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..00a61561ee47d0a60dd981db2b95603e4d20eda3 --- /dev/null +++ b/applications/Chat/coati/models/chatglm/chatglm_actor.py @@ -0,0 +1,31 @@ +from typing import Optional + +from ..base import Actor +from .configuration_chatglm import ChatGLMConfig +from .modeling_chatglm import ChatGLMForConditionalGeneration + + +class ChatGLMActor(Actor): + """ + ChatGLM Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (ChatGLMConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + + do not support lora for now. + """ + + def __init__( + self, pretrained: str = None, config: Optional[ChatGLMConfig] = None, checkpoint: bool = False + ) -> None: + if pretrained is not None: + model = ChatGLMForConditionalGeneration.from_pretrained(pretrained) + elif config is not None: + model = ChatGLMForConditionalGeneration(config) + else: + model = ChatGLMForConditionalGeneration(ChatGLMConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model, lora_rank=0, lora_train_bias="none") diff --git a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..221ef044b470566ed2956f799086adc98ee62e09 --- /dev/null +++ b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py @@ -0,0 +1,442 @@ +""" +This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py +""" +"""Tokenization classes for ChatGLM.""" +import os +from typing import Dict, List, Optional, Union + +import numpy as np +import sentencepiece as spm +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_base import BatchEncoding, EncodedInput +from transformers.utils import PaddingStrategy, logging + +logger = logging.get_logger(__name__) + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "THUDM/chatglm-6b": 2048, +} + + +class TextTokenizer: + def __init__(self, model_path): + self.sp = spm.SentencePieceProcessor() + self.sp.Load(model_path) + self.num_tokens = self.sp.vocab_size() + + def encode(self, text): + return self.sp.EncodeAsIds(text) + + def decode(self, ids: List[int]): + return self.sp.DecodeIds(ids) + + def tokenize(self, text): + return self.sp.EncodeAsPieces(text) + + def convert_tokens_to_string(self, tokens): + return self.sp.DecodePieces(tokens) + + def convert_tokens_to_ids(self, tokens): + return [self.sp.PieceToId(token) for token in tokens] + + def convert_token_to_id(self, token): + return self.sp.PieceToId(token) + + def convert_id_to_token(self, idx): + return self.sp.IdToPiece(idx) + + def __len__(self): + return self.num_tokens + + +class SPTokenizer: + def __init__( + self, + vocab_file, + num_image_tokens=20000, + max_blank_length=80, + byte_fallback=True, + ): + assert vocab_file is not None + self.vocab_file = vocab_file + self.num_image_tokens = num_image_tokens + self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "", "", "", "", ""] + self.max_blank_length = max_blank_length + self.byte_fallback = byte_fallback + self.text_tokenizer = TextTokenizer(vocab_file) + + def _get_text_tokenizer(self): + return self.text_tokenizer + + @staticmethod + def get_blank_token(length: int): + assert length >= 2 + return f"<|blank_{length}|>" + + @staticmethod + def get_tab_token(): + return f"<|tab|>" + + @property + def num_text_tokens(self): + return self.text_tokenizer.num_tokens + + @property + def num_tokens(self): + return self.num_image_tokens + self.num_text_tokens + + @staticmethod + def _encode_whitespaces(text: str, max_len: int = 80): + text = text.replace("\t", SPTokenizer.get_tab_token()) + for i in range(max_len, 1, -1): + text = text.replace(" " * i, SPTokenizer.get_blank_token(i)) + return text + + def _preprocess(self, text: str, linebreak=True, whitespaces=True): + if linebreak: + text = text.replace("\n", "") + if whitespaces: + text = self._encode_whitespaces(text, max_len=self.max_blank_length) + return text + + def encode(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[int]: + """ + @param text: Text to encode. + @param linebreak: Whether to encode newline (\n) in text. + @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. + @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. + @param add_dummy_prefix: Whether to add dummy blank space in the beginning. + """ + text = self._preprocess(text, linebreak, whitespaces) + if not add_dummy_prefix: + text = "" + text + tmp = self._get_text_tokenizer().encode(text) + tokens = [x + self.num_image_tokens for x in tmp] + return tokens if add_dummy_prefix else tokens[2:] + + def postprocess(self, text): + text = text.replace("", "\n") + text = text.replace(SPTokenizer.get_tab_token(), "\t") + for i in range(2, self.max_blank_length + 1): + text = text.replace(self.get_blank_token(i), " " * i) + return text + + def decode(self, text_ids: List[int]) -> str: + ids = [int(_id) - self.num_image_tokens for _id in text_ids] + ids = [_id for _id in ids if _id >= 0] + text = self._get_text_tokenizer().decode(ids) + text = self.postprocess(text) + return text + + def decode_tokens(self, tokens: List[str]) -> str: + text = self._get_text_tokenizer().convert_tokens_to_string(tokens) + text = self.postprocess(text) + return text + + def tokenize(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[str]: + """ + @param text: Text to encode. + @param linebreak: Whether to encode newline (\n) in text. + @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. + @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. + @param add_dummy_prefix: Whether to add dummy blank space in the beginning. + """ + text = self._preprocess(text, linebreak, whitespaces) + if not add_dummy_prefix: + text = "" + text + tokens = self._get_text_tokenizer().tokenize(text) + return tokens if add_dummy_prefix else tokens[2:] + + def __getitem__(self, x: Union[int, str]): + if isinstance(x, int): + if x < self.num_image_tokens: + return "".format(x) + else: + return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens) + elif isinstance(x, str): + if x.startswith("") and x[7:-1].isdigit(): + return int(x[7:-1]) + else: + return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens + else: + raise ValueError("The key should be str or int.") + + +class ChatGLMTokenizer(PreTrainedTokenizer): + """ + Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + """ + + vocab_files_names = {"vocab_file": "ice_text.model"} + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask", "position_ids"] + + def __init__( + self, + vocab_file, + do_lower_case=False, + remove_space=False, + bos_token="", + eos_token="", + end_token="", + mask_token="[MASK]", + gmask_token="[gMASK]", + padding_side="left", + pad_token="", + unk_token="", + num_image_tokens=20000, + **kwargs, + ) -> None: + super().__init__( + do_lower_case=do_lower_case, + remove_space=remove_space, + padding_side=padding_side, + bos_token=bos_token, + eos_token=eos_token, + end_token=end_token, + mask_token=mask_token, + gmask_token=gmask_token, + pad_token=pad_token, + unk_token=unk_token, + num_image_tokens=num_image_tokens, + **kwargs, + ) + + self.do_lower_case = do_lower_case + self.remove_space = remove_space + self.vocab_file = vocab_file + + self.bos_token = bos_token + self.eos_token = eos_token + self.end_token = end_token + self.mask_token = mask_token + self.gmask_token = gmask_token + + self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens) + + """ Initialisation """ + + @property + def gmask_token_id(self) -> Optional[int]: + if self.gmask_token is None: + return None + return self.convert_tokens_to_ids(self.gmask_token) + + @property + def end_token_id(self) -> Optional[int]: + """ + `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been + set. + """ + if self.end_token is None: + return None + return self.convert_tokens_to_ids(self.end_token) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_tokenizer.num_tokens + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def preprocess_text(self, inputs): + if self.remove_space: + outputs = " ".join(inputs.strip().split()) + else: + outputs = inputs + + if self.do_lower_case: + outputs = outputs.lower() + + return outputs + + def _tokenize(self, text, **kwargs): + """Returns a tokenized string.""" + text = self.preprocess_text(text) + + seq = self.sp_tokenizer.tokenize(text) + + return seq + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return self.sp_tokenizer.decode_tokens(tokens) + + def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + if len(token_ids) == 0: + return "" + if self.pad_token_id in token_ids: # remove pad + token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) + return super()._decode(token_ids, **kwargs) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_tokenizer[token] + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_tokenizer[index] + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if os.path.isdir(save_directory): + vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"]) + else: + vocab_file = save_directory + + with open(self.vocab_file, "rb") as fin: + proto_str = fin.read() + + with open(vocab_file, "wb") as writer: + writer.write(proto_str) + + return (vocab_file,) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + gmask_id = self.sp_tokenizer[self.gmask_token] + self.sp_tokenizer[self.eos_token] + token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]] + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + bos_token_id = self.sp_tokenizer[self.bos_token] + mask_token_id = self.sp_tokenizer[self.mask_token] + gmask_token_id = self.sp_tokenizer[self.gmask_token] + assert self.padding_side == "left" + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if max_length is not None: + if "attention_mask" not in encoded_inputs: + if bos_token_id in required_input: + context_length = required_input.index(bos_token_id) + else: + context_length = seq_length + attention_mask = np.ones((1, seq_length, seq_length)) + attention_mask = np.tril(attention_mask) + attention_mask[:, :, :context_length] = 1 + attention_mask = np.bool_(attention_mask < 0.5) + encoded_inputs["attention_mask"] = attention_mask + + if "position_ids" not in encoded_inputs: + if bos_token_id in required_input: + context_length = required_input.index(bos_token_id) + else: + context_length = seq_length + position_ids = np.arange(seq_length, dtype=np.int64) + mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id + if mask_token in required_input: + mask_position = required_input.index(mask_token) + position_ids[context_length:] = mask_position + block_position_ids = np.concatenate( + [ + np.zeros(context_length, dtype=np.int64), + np.arange(1, seq_length - context_length + 1, dtype=np.int64), + ] + ) + encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = np.pad( + encoded_inputs["attention_mask"], + pad_width=[(0, 0), (difference, 0), (difference, 0)], + mode="constant", + constant_values=True, + ) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + if "position_ids" in encoded_inputs: + encoded_inputs["position_ids"] = np.pad( + encoded_inputs["position_ids"], pad_width=[(0, 0), (difference, 0)] + ) + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs diff --git a/applications/Chat/coati/models/chatglm/configuration_chatglm.py b/applications/Chat/coati/models/chatglm/configuration_chatglm.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d2ccd187153d293aee713b9665f3be00707d6f --- /dev/null +++ b/applications/Chat/coati/models/chatglm/configuration_chatglm.py @@ -0,0 +1,101 @@ +""" +This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/configuration_chatglm.py +""" + +""" ChatGLM model configuration """ + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class ChatGLMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~ChatGLMModel`]. + It is used to instantiate an ChatGLM model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of + the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 150528): + Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~ChatGLMModel`] or + [`~TFChatGLMModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + inner_hidden_size (`int`, *optional*, defaults to 16384): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + layernorm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether the model should return the last key/values attentions (not used by all models). + Example: + + ```python + >>> from configuration_chatglm import ChatGLMConfig + >>> from modeling_chatglm import ChatGLMModel + + >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration + >>> configuration = ChatGLMConfig() + + >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration + >>> model = ChatGLMModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "chatglm" + + def __init__( + self, + vocab_size=130528, + hidden_size=4096, + num_layers=28, + num_attention_heads=32, + layernorm_epsilon=1e-5, + use_cache=True, + bos_token_id=130004, + eos_token_id=130005, + mask_token_id=130000, + gmask_token_id=130001, + pad_token_id=3, + max_sequence_length=2048, + inner_hidden_size=16384, + position_encoding_2d=True, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs, + ): + self.num_layers = num_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.max_sequence_length = max_sequence_length + self.layernorm_epsilon = layernorm_epsilon + self.inner_hidden_size = inner_hidden_size + self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.mask_token_id = mask_token_id + self.gmask_token_id = gmask_token_id + self.position_encoding_2d = position_encoding_2d + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/applications/Chat/coati/models/chatglm/modeling_chatglm.py b/applications/Chat/coati/models/chatglm/modeling_chatglm.py new file mode 100644 index 0000000000000000000000000000000000000000..d1d15c68ffd8bbc9686741da9f11461c09f02507 --- /dev/null +++ b/applications/Chat/coati/models/chatglm/modeling_chatglm.py @@ -0,0 +1,1477 @@ +""" +This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/modeling_chatglm.py +""" + +""" PyTorch ChatGLM model. """ + +import copy +import math +import os +import re +import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != "darwin": + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm-6b", + # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm +] + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(config.hidden_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2), + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) + + +def gelu(x): + return gelu_impl(x) + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer("inv_freq", inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding( + position_id, sin.squeeze(1) + ).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + + +def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + hidden_size_per_partition, + layer_id, + layer_past=None, + scaling_attention_score=True, + use_cache=False, +): + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=0) + value_layer = torch.cat((past_value, value_layer), dim=0) + + # seqlen, batch, num_attention_heads, hidden_size_per_attention_head + seq_len, b, nh, hidden_size = key_layer.shape + + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_key_layer_scaling_coeff = float(layer_id + 1) + if scaling_attention_score: + query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + matmul_result = torch.zeros( + 1, + 1, + 1, + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + if self.scale_mask_softmax: + self.scale_mask_softmax.scale = query_key_layer_scaling_coeff + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) + else: + if not (attention_mask == 0).all(): + # if auto-regressive, skip + attention_scores.masked_fill_(attention_mask, -10000.0) + dtype = attention_scores.dtype + attention_scores = attention_scores.float() + attention_scores = attention_scores * query_key_layer_scaling_coeff + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, present, attention_probs) + + return outputs + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class SelfAttention(torch.nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=None, + bias=True, + params_dtype=torch.float, + position_encoding_2d=True, + empty_init=True, + ): + if empty_init: + init_method = skip_init + else: + init_method = default_init + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( + self.hidden_size // (self.num_attention_heads * 2) + if position_encoding_2d + else self.hidden_size // self.num_attention_heads, + base=10000, + precision=torch.half, + learnable=False, + ) + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + # Strided linear layer. + self.query_key_value = init_method( + torch.nn.Linear, + hidden_size, + 3 * self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + + self.dense = init_method( + torch.nn.Linear, + self.inner_hidden_size, + hidden_size, + bias=bias, + dtype=params_dtype, + ) + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [seq_len, batch, 3 * hidden_size] + mixed_raw_layer = self.query_key_value(hidden_states) + + # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + position_ids, block_position_ids = ( + position_ids[:, 0, :].transpose(0, 1).contiguous(), + position_ids[:, 1, :].transpose(0, 1).contiguous(), + ) + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) + # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) + + # [seq_len, batch, hidden_size] + context_layer, present, attention_probs = attention_fn( + self=self, + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + hidden_size_per_partition=self.hidden_size_per_partition, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache, + ) + + output = self.dense(context_layer) + + outputs = (output, present) + + if output_attentions: + outputs += (attention_probs,) + + return outputs # output, present, attention_probs + + +class GEGLU(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation_fn = F.gelu + + def forward(self, x): + # dim=-1 breaks in jit for pt<1.10 + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) + return x1 * self.activation_fn(x2) + + +class GLU(torch.nn.Module): + def __init__( + self, + hidden_size, + inner_hidden_size=None, + layer_id=None, + bias=True, + activation_func=gelu, + params_dtype=torch.float, + empty_init=True, + ): + super(GLU, self).__init__() + if empty_init: + init_method = skip_init + else: + init_method = default_init + self.layer_id = layer_id + self.activation_func = activation_func + + # Project to 4h. + self.hidden_size = hidden_size + if inner_hidden_size is None: + inner_hidden_size = 4 * hidden_size + self.inner_hidden_size = inner_hidden_size + self.dense_h_to_4h = init_method( + torch.nn.Linear, + self.hidden_size, + self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + # Project back to h. + self.dense_4h_to_h = init_method( + torch.nn.Linear, + self.inner_hidden_size, + self.hidden_size, + bias=bias, + dtype=params_dtype, + ) + + def forward(self, hidden_states): + """ + hidden_states: [seq_len, batch, hidden_size] + """ + + # [seq_len, batch, inner_hidden_size] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + + intermediate_parallel = self.activation_func(intermediate_parallel) + + output = self.dense_4h_to_h(intermediate_parallel) + + return output + + +class GLMBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + layernorm_epsilon, + layer_id, + inner_hidden_size=None, + hidden_size_per_attention_head=None, + layernorm=LayerNorm, + use_bias=True, + params_dtype=torch.float, + num_layers=28, + position_encoding_2d=True, + empty_init=True, + ): + super(GLMBlock, self).__init__() + # Set output layer initialization if not provided. + + self.layer_id = layer_id + + # Layernorm on the input data. + self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.position_encoding_2d = position_encoding_2d + + # Self attention. + self.attention = SelfAttention( + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=hidden_size_per_attention_head, + bias=use_bias, + params_dtype=params_dtype, + position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init, + ) + + # Layernorm on the input data. + self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.num_layers = num_layers + + # GLU + self.mlp = GLU( + hidden_size, + inner_hidden_size=inner_hidden_size, + bias=use_bias, + layer_id=layer_id, + params_dtype=params_dtype, + empty_init=empty_init, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # Layer norm at the begining of the transformer layer. + # [seq_len, batch, hidden_size] + attention_input = self.input_layernorm(hidden_states) + + # Self attention. + attention_outputs = self.attention( + attention_input, + position_ids, + attention_mask=attention_mask, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + attention_output = attention_outputs[0] + + outputs = attention_outputs[1:] + + # Residual connection. + alpha = (2 * self.num_layers) ** 0.5 + hidden_states = attention_input * alpha + attention_output + + mlp_input = self.post_attention_layernorm(hidden_states) + + # MLP. + mlp_output = self.mlp(mlp_input) + + # Second residual connection. + output = mlp_input * alpha + mlp_output + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, device): + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) + attention_mask.tril_() + for i, context_length in enumerate(context_lengths): + attention_mask[i, :, :context_length] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + return attention_mask + + def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): + batch_size, seq_length = input_ids.shape + if use_gmasks is None: + use_gmasks = [False] * batch_size + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + if self.position_encoding_2d: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + position_ids[i, context_length:] = mask_positions[i] + block_position_ids = [ + torch.cat( + ( + torch.zeros(context_length, dtype=torch.long, device=device), + torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1, + ) + ) + for context_length in context_lengths + ] + block_position_ids = torch.stack(block_position_ids, dim=0) + position_ids = torch.stack((position_ids, block_position_ids), dim=1) + else: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + if not use_gmasks[i]: + position_ids[i, context_length:] = mask_positions[i] + + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ChatGLMModel): + module.gradient_checkpointing = value + + +CHATGLM_6B_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CHATGLM_6B_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`ChatGLM6BTokenizer`]. + See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range `[0, config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert *input_ids* indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", + CHATGLM_6B_START_DOCSTRING, +) +class ChatGLMModel(ChatGLMPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + `is_decoder` argument of the configuration set to `True`. + To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` + argument and `add_cross_attention` set to `True`; an + `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config: ChatGLMConfig, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + # recording parameters + self.max_sequence_length = config.max_sequence_length + self.hidden_size = config.hidden_size + self.params_dtype = torch.half + self.num_attention_heads = config.num_attention_heads + self.vocab_size = config.vocab_size + self.num_layers = config.num_layers + self.layernorm_epsilon = config.layernorm_epsilon + self.inner_hidden_size = config.inner_hidden_size + self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads + self.position_encoding_2d = config.position_encoding_2d + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + + self.word_embeddings = init_method( + torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype + ) + self.gradient_checkpointing = False + + def get_layer(layer_id): + return GLMBlock( + self.hidden_size, + self.num_attention_heads, + self.layernorm_epsilon, + layer_id, + inner_hidden_size=self.inner_hidden_size, + hidden_size_per_attention_head=self.hidden_size_per_attention_head, + layernorm=LayerNorm, + use_bias=True, + params_dtype=self.params_dtype, + position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init, + ) + + self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)]) + + # Final layer norm before output. + self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) + + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + # total_params = sum(p.numel() for p in self.parameters()) + # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.word_embeddings = new_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads, + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + # past_key_values = [(v[0], v[1]) for v in past_key_values] + return past_key_values + + @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if past_key_values is None: + if self.pre_seq_len is not None: + past_key_values = self.get_prompt( + batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype + ) + else: + past_key_values = tuple([None] * len(self.layers)) + + if attention_mask is None: + attention_mask = self.get_masks(input_ids, device=input_ids.device) + + if position_ids is None: + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) + + position_ids = self.get_position_ids( + input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks + ) + + if self.pre_seq_len is not None and attention_mask is not None: + prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( + attention_mask.device + ) + prefix_attention_mask = (prefix_attention_mask < 0.5).bool() + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) + + # [seq_len, batch, hidden_size] + hidden_states = inputs_embeds.transpose(0, 1) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if attention_mask is None: + attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() + else: + attention_mask = attention_mask.to(hidden_states.device) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + layer_past = past_key_values[i] + + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + position_ids, + attention_mask, + torch.tensor(i), + layer_past, + use_cache, + output_attentions, + ) + else: + layer_ret = layer( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + layer_id=torch.tensor(i), + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = layer_ret[0] + + if use_cache: + presents = presents + (layer_ret[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) + + # Final layer norm. + hidden_states = self.final_layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + + # self.hidden_size = config.hidden_size + # self.params_dtype = torch.half + # self.vocab_size = config.vocab_size + self.max_sequence_length = config.max_sequence_length + + self.position_encoding_2d = config.position_encoding_2d + + self.transformer = ChatGLMModel(config, empty_init=empty_init) + + self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half) + + self.config = config + + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3 + ) + new_attention_mask = attention_mask[:, :, -1:].clone() + new_attention_mask[..., -1] = False + model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id[:, 1, :] += 1 + model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1) + + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs, + ) -> dict: + batch_size, seq_length = input_ids.shape + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) + + # only last token for input_ids if past is not None + if past is not None or past_key_values is not None: + last_token = input_ids[:, -1].unsqueeze(-1) + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = attention_mask[:, :, -1:] + else: + attention_mask = None + if position_ids is not None: + position_ids = position_ids[..., -1:] + else: + context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] + if self.position_encoding_2d: + position_ids = torch.tensor( + [ + [mask_position, seq_length - context_length] + for mask_position, context_length in zip(mask_positions, context_lengths) + ], + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(-1) + else: + position_ids = torch.tensor( + [mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device + ).unsqueeze(-1) + + if past is None: + past = past_key_values + return { + "input_ids": last_token, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } + else: + if attention_mask is not None and attention_mask.dtype != torch.bool: + logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") + attention_mask = None + if attention_mask is None: + attention_mask = self.get_masks(input_ids, device=input_ids.device) + if position_ids is None: + position_ids = self.get_position_ids( + input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks + ) + + return { + "input_ids": input_ids, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + punkts = [ + [",", ","], + ["!", "!"], + [":", ":"], + [";", ";"], + ["\?", "?"], + ] + for item in punkts: + response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) + response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) + return response + + @torch.no_grad() + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + num_beams=1, + do_sample=True, + top_p=0.7, + temperature=0.95, + logits_processor=None, + **kwargs, + ): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + "max_length": max_length, + "num_beams": num_beams, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + do_sample=True, + top_p=0.7, + temperature=0.95, + logits_processor=None, + **kwargs, + ): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + "max_length": max_length, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + for outputs in self.stream_generate(**inputs, **gen_kwargs): + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] + response = tokenizer.decode(outputs) + response = self.process_response(response) + new_history = history + [(query, response)] + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + yield input_ids + + def quantize(self, bits: int, empty_init=False, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs) + return self diff --git a/applications/Chat/coati/models/deberta/__init__.py b/applications/Chat/coati/models/deberta/__init__.py deleted file mode 100644 index b66888f34fd0b81b726b28dec39b9ed26d99c945..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/models/deberta/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .deberta_critic import DebertaCritic -from .deberta_rm import DebertaRM - -__all__ = ['DebertaCritic', 'DebertaRM'] diff --git a/applications/Chat/coati/models/deberta/deberta_critic.py b/applications/Chat/coati/models/deberta/deberta_critic.py deleted file mode 100644 index e84c1dbd8380728a6544c50be2b82146821fb3c3..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/models/deberta/deberta_critic.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers import DebertaV2Config, DebertaV2Model - -from ..base import Critic - - -class DebertaCritic(Critic): - """ - Deberta Critic model. - - Args: - pretrained (str): Pretrained model name or path. - config (DebertaV2Config): Model config. - checkpoint (bool): Enable gradient checkpointing. - lora_rank (int): Rank of the LO-RA decomposition. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[DebertaV2Config] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - if pretrained is not None: - model = DebertaV2Model.from_pretrained(pretrained) - elif config is not None: - model = DebertaV2Model(config) - else: - model = DebertaV2Model(DebertaV2Config()) - if checkpoint: - model.gradient_checkpointing_enable() - value_head = nn.Linear(model.config.hidden_size, 1) - super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/deberta/deberta_rm.py b/applications/Chat/coati/models/deberta/deberta_rm.py deleted file mode 100644 index 2448c879ec859ebfd13afd859508e28b63427c06..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/models/deberta/deberta_rm.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers import DebertaV2Config, DebertaV2Model - -from ..base import RewardModel - - -class DebertaRM(RewardModel): - """ - Deberta Reward model. - - Args: - pretrained (str): Pretrained model name or path. - config (DebertaV2Config): Model config. - checkpoint (bool): Enable gradient checkpointing. - lora_rank (int): Rank of the LO-RA decomposition. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__(self, - pretrained: str = None, - config: Optional[DebertaV2Config] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - if pretrained is not None: - model = DebertaV2Model.from_pretrained(pretrained) - elif config is not None: - model = DebertaV2Model(config) - else: - model = DebertaV2Model(DebertaV2Config()) - if checkpoint: - model.gradient_checkpointing_enable() - value_head = nn.Linear(model.config.hidden_size, 1) - value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) - super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py index f57c9458a271131a14d6b035cd039d2a5de1281b..4ab0cdc8a3eaa21afb5b342f0f12cfea7dddd0b0 100644 --- a/applications/Chat/coati/models/generation.py +++ b/applications/Chat/coati/models/generation.py @@ -2,7 +2,9 @@ from typing import Any, Callable, Optional import torch import torch.distributed as dist -import torch.nn as nn +from transformers import PreTrainedTokenizer + +from .base import Actor try: from transformers.generation_logits_process import ( @@ -15,9 +17,9 @@ except ImportError: from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper -def prepare_logits_processor(top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None) -> LogitsProcessorList: +def _prepare_logits_processor( + top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None +) -> LogitsProcessorList: processor_list = LogitsProcessorList() if temperature is not None and temperature != 1.0: processor_list.append(TemperatureLogitsWarper(temperature)) @@ -36,32 +38,34 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: return unfinished_sequences.max() == 0 -def sample(model: nn.Module, - input_ids: torch.Tensor, - max_length: int, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs) -> torch.Tensor: +def _sample( + model: Actor, + input_ids: torch.Tensor, + max_length: int, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs, +) -> torch.Tensor: if input_ids.size(1) >= max_length: return input_ids - logits_processor = prepare_logits_processor(top_k, top_p, temperature) + logits_processor = _prepare_logits_processor(top_k, top_p, temperature) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) for _ in range(input_ids.size(1), max_length): - model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else { - 'input_ids': input_ids - } + model_inputs = ( + prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids} + ) outputs = model(**model_inputs) - next_token_logits = outputs['logits'][:, -1, :] - # pre-process distribution + # NOTE: this is correct only in left padding mode + next_token_logits = outputs["logits"][:, -1, :] next_token_logits = logits_processor(input_ids, next_token_logits) # sample probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) @@ -69,8 +73,7 @@ def sample(model: nn.Module, # finished sentences should have their next token be a padding token if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs for next step @@ -89,20 +92,22 @@ def sample(model: nn.Module, return input_ids -def generate(model: nn.Module, - input_ids: torch.Tensor, - max_length: int, - num_beams: int = 1, - do_sample: bool = True, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs) -> torch.Tensor: +@torch.no_grad() +def generate( + model: Actor, + input_ids: torch.Tensor, + tokenizer: PreTrainedTokenizer, + max_length: int, + num_beams: int = 1, + do_sample: bool = True, + early_stopping: bool = False, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs, +) -> torch.Tensor: """Generate token sequence. The returned sequence is input_ids + generated_tokens. Args: @@ -112,34 +117,35 @@ def generate(model: nn.Module, num_beams (int, optional): number of beams. Defaults to 1. do_sample (bool, optional): whether to do sample. Defaults to True. early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False. - eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None. - pad_token_id (Optional[int], optional): pad token id. Defaults to None. top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None. top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None. temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None. prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. """ - is_greedy_gen_mode = ((num_beams == 1) and do_sample is False) - is_sample_gen_mode = ((num_beams == 1) and do_sample is True) - is_beam_gen_mode = ((num_beams > 1) and do_sample is False) + assert tokenizer.padding_side == "left", "Current generation only supports left padding." + is_greedy_gen_mode = (num_beams == 1) and do_sample is False + is_sample_gen_mode = (num_beams == 1) and do_sample is True + is_beam_gen_mode = (num_beams > 1) and do_sample is False if is_greedy_gen_mode: # run greedy search raise NotImplementedError elif is_sample_gen_mode: # run sample - return sample(model, - input_ids, - max_length, - early_stopping=early_stopping, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - top_k=top_k, - top_p=top_p, - temperature=temperature, - prepare_inputs_fn=prepare_inputs_fn, - update_model_kwargs_fn=update_model_kwargs_fn, - **model_kwargs) + return _sample( + model, + input_ids, + max_length, + early_stopping=early_stopping, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + top_k=top_k, + top_p=top_p, + temperature=temperature, + prepare_inputs_fn=prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn, + **model_kwargs, + ) elif is_beam_gen_mode: raise NotImplementedError else: diff --git a/applications/Chat/coati/models/gpt/__init__.py b/applications/Chat/coati/models/gpt/__init__.py index 63dc5ab0f5ead4450e1f9bffae874520614fa05c..823cf4a75e0d4fe8cdd9de9e37947f0bf45e5657 100644 --- a/applications/Chat/coati/models/gpt/__init__.py +++ b/applications/Chat/coati/models/gpt/__init__.py @@ -2,4 +2,4 @@ from .gpt_actor import GPTActor from .gpt_critic import GPTCritic from .gpt_rm import GPTRM -__all__ = ['GPTActor', 'GPTCritic', 'GPTRM'] +__all__ = ["GPTActor", "GPTCritic", "GPTRM"] diff --git a/applications/Chat/coati/models/gpt/gpt_actor.py b/applications/Chat/coati/models/gpt/gpt_actor.py index ae9d669f1f5669dc63e829332ccdf7ebf991afe2..a7e4b9bc3e22425da707734a49a902a72368c011 100644 --- a/applications/Chat/coati/models/gpt/gpt_actor.py +++ b/applications/Chat/coati/models/gpt/gpt_actor.py @@ -18,13 +18,15 @@ class GPTActor(Actor): lora_train_bias (str): Bias training strategy for the LoRa layer. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = GPT2LMHeadModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/gpt/gpt_critic.py b/applications/Chat/coati/models/gpt/gpt_critic.py index 2e70f5f1fc9632edbd2d83ef7b0067f7accd0eda..22ab36dea276cfeefffed917b2233cacab736b11 100644 --- a/applications/Chat/coati/models/gpt/gpt_critic.py +++ b/applications/Chat/coati/models/gpt/gpt_critic.py @@ -14,25 +14,24 @@ class GPTCritic(Critic): Args: pretrained (str): Pretrained model name or path. config (GPT2Config): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): Rank of the LO-RA decomposition. lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = GPT2Model.from_pretrained(pretrained) elif config is not None: model = GPT2Model(config) else: model = GPT2Model(GPT2Config()) - if checkpoint: - model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.n_embd, 1) super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/gpt/gpt_rm.py b/applications/Chat/coati/models/gpt/gpt_rm.py index 054432e1ce863a36c169ea3b7569198ab1994e6c..8edfc4008466bbfd5ab303e36877cd0d8f5d3a50 100644 --- a/applications/Chat/coati/models/gpt/gpt_rm.py +++ b/applications/Chat/coati/models/gpt/gpt_rm.py @@ -14,25 +14,23 @@ class GPTRM(RewardModel): Args: pretrained (str): Pretrained model name or path. config (GPT2Config): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): Rank of the low-rank approximation. lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = GPT2Model.from_pretrained(pretrained) elif config is not None: model = GPT2Model(config) else: model = GPT2Model(GPT2Config()) - if checkpoint: - model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.n_embd, 1) value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1)) diff --git a/applications/Chat/coati/models/llama/__init__.py b/applications/Chat/coati/models/llama/__init__.py index 9b2a024afdb28d3336d801895da3f26b01f28c56..c87d732538a9f0357902368c5d5a90b9174e1ead 100644 --- a/applications/Chat/coati/models/llama/__init__.py +++ b/applications/Chat/coati/models/llama/__init__.py @@ -2,4 +2,4 @@ from .llama_actor import LlamaActor from .llama_critic import LlamaCritic from .llama_rm import LlamaRM -__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM'] +__all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"] diff --git a/applications/Chat/coati/models/llama/llama_actor.py b/applications/Chat/coati/models/llama/llama_actor.py index 2c7adb390d8bea055e9fc84f75dc66658a8fe0e3..f1d9406835ca889138a465f28eb8fbb997226005 100644 --- a/applications/Chat/coati/models/llama/llama_actor.py +++ b/applications/Chat/coati/models/llama/llama_actor.py @@ -1,7 +1,6 @@ from typing import Optional -import torch -from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM +from transformers import LlamaConfig, LlamaForCausalLM from ..base import Actor @@ -18,13 +17,14 @@ class LlamaActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = LlamaForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py index dd9e5e7bfa1ae6da1c9c09789ae0eda602a35720..000dce17ccf013d510f98ac1ac24885bf5af757c 100644 --- a/applications/Chat/coati/models/llama/llama_critic.py +++ b/applications/Chat/coati/models/llama/llama_critic.py @@ -13,19 +13,18 @@ class LlamaCritic(Critic): Args: pretrained (str): Pretrained model name or path. config (LlamaConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): LoRA rank. lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: - + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = LlamaModel.from_pretrained(pretrained) elif config is not None: @@ -33,9 +32,5 @@ class LlamaCritic(Critic): else: model = LlamaModel(LlamaConfig()) - if checkpoint: - model.gradient_checkpointing_enable() - value_head = nn.Linear(model.config.hidden_size, 1) - super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/llama/llama_rm.py b/applications/Chat/coati/models/llama/llama_rm.py index f936019d62d28bac0bf41161b3f4aaad26e2bbf0..43bc9e638dc7944191cd903cca1134a15dbfe3dc 100644 --- a/applications/Chat/coati/models/llama/llama_rm.py +++ b/applications/Chat/coati/models/llama/llama_rm.py @@ -1,7 +1,7 @@ from typing import Optional import torch.nn as nn -from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel +from transformers import LlamaConfig, LlamaModel from ..base import RewardModel @@ -13,18 +13,17 @@ class LlamaRM(RewardModel): Args: pretrained (str): Pretrained model name or path. config (LlamaConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): LoRA rank. lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = LlamaModel.from_pretrained(pretrained) elif config is not None: @@ -32,8 +31,6 @@ class LlamaRM(RewardModel): else: model = LlamaModel(LlamaConfig()) - if checkpoint: - model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.hidden_size, 1) value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py index 0533a60dc53266d592fe2b5808eb637ac64db875..e9bd7b2ed8f0608e9fd6c4c26a39340d087a454f 100644 --- a/applications/Chat/coati/models/lora.py +++ b/applications/Chat/coati/models/lora.py @@ -1,4 +1,6 @@ +import dataclasses import math +import warnings from typing import Optional import loralib as lora @@ -7,9 +9,16 @@ import torch.nn as nn import torch.nn.functional as F +@dataclasses.dataclass +class LoRAManager: + merge_weights: bool = False + + +LORA_MANAGER = LoRAManager() + + class LoraLinear(lora.LoRALayer, nn.Module): - """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear. - """ + """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.""" def __init__( self, @@ -17,16 +26,12 @@ class LoraLinear(lora.LoRALayer, nn.Module): bias: Optional[nn.Parameter], r: int = 0, lora_alpha: int = 1, - lora_dropout: float = 0., - fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) - merge_weights: bool = True, + lora_dropout: float = 0.0, + # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + fan_in_fan_out: bool = False, ): nn.Module.__init__(self) - lora.LoRALayer.__init__(self, - r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=merge_weights) + lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) self.weight = weight self.bias = bias @@ -47,39 +52,42 @@ class LoraLinear(lora.LoRALayer, nn.Module): self.weight.data = self.weight.data.T def reset_parameters(self): - if hasattr(self, 'lora_A'): - # initialize A the same way as the default for nn.Linear and B to zero + if hasattr(self, "lora_A"): + # Initialize A with the default values for nn.Linear and set B to zero. nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_B) def train(self, mode: bool = True): - - def T(w): - return w.T if self.fan_in_fan_out else w - - nn.Module.train(self, mode) - if self.merge_weights and self.merged: - # Make sure that the weights are not merged - if self.r > 0: - self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling - self.merged = False - - def eval(self): - def T(w): return w.T if self.fan_in_fan_out else w - nn.Module.eval(self) - if self.merge_weights and not self.merged: - # Merge the weights and mark it - if self.r > 0: - self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling - delattr(self, 'lora_A') - delattr(self, 'lora_B') - self.merged = True + self.training = mode + if LORA_MANAGER.merge_weights: + if mode and self.merged: + warnings.warn("Invoke module.train() would unmerge LoRA weights.") + raise NotImplementedError("LoRA unmerge is not tested.") + # Make sure that the weights are not merged + if self.r > 0: + if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"): + # FIXME(csric): temporary fix + self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features))) + self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r))) + self.reset_parameters() + else: + self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling + self.merged = False + elif not mode and not self.merged: + warnings.warn("Invoke module.eval() would merge LoRA weights.") + # Merge the weights and mark it + if self.r > 0: + self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling + delattr(self, "lora_A") + delattr(self, "lora_B") + self.merged = True + + return self def forward(self, x: torch.Tensor): - def T(w): return w.T if self.fan_in_fan_out else w @@ -92,21 +100,23 @@ class LoraLinear(lora.LoRALayer, nn.Module): return F.linear(x, T(self.weight), bias=self.bias) -def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: - assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})' - lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False) +def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: + assert ( + lora_rank <= linear.in_features + ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})" + lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank) return lora_linear -def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: +def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: for name, child in module.named_children(): if isinstance(child, nn.Linear): - setattr(module, name, lora_linear_wrapper(child, lora_rank)) + setattr(module, name, _lora_linear_wrapper(child, lora_rank)) else: - convert_to_lora_recursively(child, lora_rank) + _convert_to_lora_recursively(child, lora_rank) -def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module: +def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module: """Convert a torch.nn.Module to a LoRA module. Args: @@ -118,7 +128,7 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: s """ if lora_rank <= 0: return module - convert_to_lora_recursively(module, lora_rank) + _convert_to_lora_recursively(module, lora_rank) lora.mark_only_lora_as_trainable(module, lora_train_bias) return module @@ -134,7 +144,7 @@ class LoRAModule(nn.Module): Defaults to 'none'. """ - def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: + def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None: super().__init__() self.lora_rank = lora_rank self.lora_train_bias = lora_train_bias diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py index 926c6e2a4e4131ece0693630bce4894e8cbec5f0..687bd0f7bfe76fc8865569f742fd31b3c8d4f868 100644 --- a/applications/Chat/coati/models/loss.py +++ b/applications/Chat/coati/models/loss.py @@ -13,6 +13,7 @@ class GPTLMLoss(nn.Module): def __init__(self): super().__init__() + # NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py self.loss = nn.CrossEntropyLoss() def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: @@ -31,11 +32,13 @@ class PolicyLoss(nn.Module): super().__init__() self.clip_eps = clip_eps - def forward(self, - log_probs: torch.Tensor, - old_log_probs: torch.Tensor, - advantages: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: ratio = (log_probs - old_log_probs).exp() surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages @@ -55,44 +58,21 @@ class ValueLoss(nn.Module): super().__init__() self.clip_eps = clip_eps - def forward(self, - values: torch.Tensor, - old_values: torch.Tensor, - reward: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + values: torch.Tensor, + old_values: torch.Tensor, + reward: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) - surr1 = (values_clipped - reward)**2 - surr2 = (values - reward)**2 + surr1 = (values_clipped - reward) ** 2 + surr2 = (values - reward) ** 2 loss = torch.max(surr1, surr2) loss = loss.mean() return 0.5 * loss -class PPOPtxActorLoss(nn.Module): - """ - To Do: - - PPO-ptx Actor Loss - """ - - def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None: - super().__init__() - self.pretrain_coef = pretrain_coef - self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps) - self.pretrain_loss_fn = pretrain_loss_fn - - def forward(self, - log_probs: torch.Tensor, - old_log_probs: torch.Tensor, - advantages: torch.Tensor, - lm_logits: torch.Tensor, - lm_input_ids: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask) - lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids) - return policy_loss + self.pretrain_coef * lm_loss - - class LogSigLoss(nn.Module): """ Pairwise Loss for Reward Model diff --git a/applications/Chat/coati/models/opt/__init__.py b/applications/Chat/coati/models/opt/__init__.py index 334f4df0032a1da8f8a3d23c9c987448c9dbc0d8..e37d6e45c8fc9deee35841878857b7d29fd7eed3 100644 --- a/applications/Chat/coati/models/opt/__init__.py +++ b/applications/Chat/coati/models/opt/__init__.py @@ -2,4 +2,4 @@ from .opt_actor import OPTActor from .opt_critic import OPTCritic from .opt_rm import OPTRM -__all__ = ['OPTActor', 'OPTCritic', 'OPTRM'] +__all__ = ["OPTActor", "OPTCritic", "OPTRM"] diff --git a/applications/Chat/coati/models/opt/opt_actor.py b/applications/Chat/coati/models/opt/opt_actor.py index c14e4377ffb2b00983f4ea4a1af7c9931e9d1cb9..cd8908e13fb8cc62de664023d491eb6090dc74b6 100644 --- a/applications/Chat/coati/models/opt/opt_actor.py +++ b/applications/Chat/coati/models/opt/opt_actor.py @@ -18,12 +18,14 @@ class OPTActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = OPTForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/opt/opt_critic.py b/applications/Chat/coati/models/opt/opt_critic.py index fcfebd8a8b031785d93d709aead85d5aa6f30d08..f37d28812c27e184431245894ee506fe725696db 100644 --- a/applications/Chat/coati/models/opt/opt_critic.py +++ b/applications/Chat/coati/models/opt/opt_critic.py @@ -14,25 +14,24 @@ class OPTCritic(Critic): Args: pretrained (str): Pretrained model name or path. config (OPTConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): Rank of the low-rank approximation. lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = OPTModel.from_pretrained(pretrained) elif config is not None: model = OPTModel(config) else: model = OPTModel(OPTConfig()) - if checkpoint: - model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.word_embed_proj_dim, 1) super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/opt/opt_rm.py b/applications/Chat/coati/models/opt/opt_rm.py index 50fc0dee8568f86d8d63e712104306e8c1e013b8..893708344ad4c5e0503b64f3118999c4e3aaa91b 100644 --- a/applications/Chat/coati/models/opt/opt_rm.py +++ b/applications/Chat/coati/models/opt/opt_rm.py @@ -13,25 +13,23 @@ class OPTRM(RewardModel): Args: pretrained (str): Pretrained model name or path. config (OPTConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. lora_rank (int): Rank of the low-rank approximation. lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = OPTModel.from_pretrained(pretrained) elif config is not None: model = OPTModel(config) else: model = OPTModel(OPTConfig()) - if checkpoint: - model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.word_embed_proj_dim, 1) value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1)) diff --git a/applications/Chat/coati/models/roberta/__init__.py b/applications/Chat/coati/models/roberta/__init__.py deleted file mode 100644 index 0f4a8de067b1695c4abbc880f2e3dc1f6510ad26..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/models/roberta/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .roberta_actor import RoBERTaActor -from .roberta_critic import RoBERTaCritic -from .roberta_rm import RoBERTaRM - -__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM'] \ No newline at end of file diff --git a/applications/Chat/coati/models/roberta/roberta_actor.py b/applications/Chat/coati/models/roberta/roberta_actor.py deleted file mode 100644 index e35fa6eb19a8053a2ea5cec3f2a07a2dd3c80735..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/models/roberta/roberta_actor.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Optional - -from transformers.models.roberta.configuration_roberta import RobertaConfig -from transformers.models.roberta.modeling_roberta import RobertaForCausalLM - -from ..base import Actor - -class RoBERTaActor(Actor): - """ - RoBERTa Actor model. - - Args: - pretrained (str): Pretrained model name or path. - config (RoBERTaConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. - lora_rank (int): Rank of the low-rank approximation. - lora_train_bias (str): LoRA bias training mode. - """ - - - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[RobertaConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - if pretrained is not None: - model = RobertaForCausalLM.from_pretrained(pretrained) - elif config is not None: - model = RobertaForCausalLM(config) - else: - model = RobertaForCausalLM(RobertaConfig()) - if checkpoint: - model.gradient_checkpointing_enable() - super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/roberta/roberta_critic.py b/applications/Chat/coati/models/roberta/roberta_critic.py deleted file mode 100644 index c8dc0d9e14f2813907a6345ffbf93784e9b8528c..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/models/roberta/roberta_critic.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers.models.roberta.configuration_roberta import RobertaConfig -from transformers.models.roberta.modeling_roberta import RobertaModel - -from ..base import Critic - - -class RoBERTaCritic(Critic): - """ - RoBERTa Critic model. - - Args: - pretrained (str): Pretrained model name or path. - config (RoBERTa Config): Model config. - checkpoint (bool): Enable gradient checkpointing. - lora_rank (int): Rank of the low-rank approximation. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[RobertaConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: - if pretrained is not None: - model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False) - elif config is not None: - model = RobertaModel(config) - else: - model = RobertaModel(RobertaConfig()) - if checkpoint: - model.gradient_checkpointing_enable() - value_head = nn.Linear(model.config.hidden_size, 1) - super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/roberta/roberta_rm.py b/applications/Chat/coati/models/roberta/roberta_rm.py deleted file mode 100644 index 77075052978b56d9bde336b3ac0473a343f6c332..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/models/roberta/roberta_rm.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers import RobertaConfig, RobertaModel - - -from ..base import RewardModel - - -class RoBERTaRM(RewardModel): - """ - RoBERTa Reward model. - - Args: - pretrained (str): Pretrained model name or path. - config (RoBERTaConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. - lora_rank (int): Rank of the low-rank approximation. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[RobertaConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - if pretrained is not None: - model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False) - elif config is not None: - model = RobertaModel(config) - else: - model = RobertaModel(RobertaConfig()) - if checkpoint: - model.gradient_checkpointing_enable() - - value_head = nn.Linear(model.config.hidden_size, 1) - value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1)) - super().__init__(model, value_head, lora_rank, lora_train_bias) \ No newline at end of file diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py index 0ff13181fcd2f9d36aa413c0ff4a3f2b04a1ca0d..1aaef16620d2b18bb452b1083d0a988190840116 100644 --- a/applications/Chat/coati/models/utils.py +++ b/applications/Chat/coati/models/utils.py @@ -1,14 +1,12 @@ from typing import Optional, Union -import loralib as lora import torch -import torch.nn as nn import torch.nn.functional as F -def compute_approx_kl(log_probs: torch.Tensor, - log_probs_base: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: +def _compute_approx_kl( + log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None +) -> torch.Tensor: """ Compute the approximate KL divergence between two distributions. Schulman blog: http://joschu.net/blog/kl-approx.html @@ -19,7 +17,7 @@ def compute_approx_kl(log_probs: torch.Tensor, action_mask: Mask for actions. """ - log_ratio = log_probs - log_probs_base + log_ratio = log_probs_base - log_probs approx_kl = (log_ratio.exp() - 1) - log_ratio if action_mask is not None: approx_kl = masked_mean(approx_kl, action_mask, dim=1) @@ -28,65 +26,44 @@ def compute_approx_kl(log_probs: torch.Tensor, return approx_kl -def compute_reward(r: Union[torch.Tensor, float], - kl_coef: float, - log_probs: torch.Tensor, - log_probs_base: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: +def compute_reward( + r: Union[torch.Tensor, float], + kl_coef: float, + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: if kl_coef <= 0.0: return r - kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) + kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) reward = r - kl_coef * kl return reward -def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: +def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: log_probs = F.log_softmax(logits, dim=-1) log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) return log_probs_labels.squeeze(-1) +def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: + """Calculate action log probs. + + Args: + output (torch.Tensor): Output tensor of Actor.forward.logits. + sequences (torch.LongTensor): Input sequences. + num_actions (int): Number of actions. + + Returns: + torch.Tensor: Action log probs. + """ + log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + return log_probs[:, -num_actions:] + + def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: tensor = tensor * mask tensor = tensor.sum(dim=dim) mask_sum = mask.sum(dim=dim) mean = tensor / (mask_sum + 1e-8) return mean - - -def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor: - tensor = tensor * mask - mean = masked_mean(tensor, mask, dim=dim) - mean_centered = tensor - mean - var = masked_mean(mean_centered**2, mask, dim=dim) - return mean_centered * var.clamp(min=eps).rsqrt() - - -def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor: - mean = tensor.mean(dim) - mean_centered = tensor - mean - var = (mean_centered**2).mean(dim) - norm = mean_centered * var.clamp(min=eps).rsqrt() - return norm - - -def convert_to_lora(model: nn.Module, - input_size: int, - output_size: int, - lora_rank: int = 16, - lora_alpha: int = 1, - lora_dropout: float = 0., - fan_in_fan_out: bool = False, - merge_weights: bool = True): - if lora_rank > min(input_size, output_size): - raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}") - - for name, module in model.named_modules(): - if isinstance(module, nn.Linear): - module._modules[name] = lora.Linear(input_size, - output_size, - r=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - fan_in_fan_out=fan_in_fan_out, - merge_weights=merge_weights) diff --git a/applications/Chat/coati/quant/__init__.py b/applications/Chat/coati/quant/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1765b8091bc3c599c493dcfdf93e8d30c5b340cb --- /dev/null +++ b/applications/Chat/coati/quant/__init__.py @@ -0,0 +1,7 @@ +from .llama_gptq import load_quant as llama_load_quant +from .utils import low_resource_init + +__all__ = [ + "llama_load_quant", + "low_resource_init", +] diff --git a/applications/Chat/coati/quant/llama_gptq/__init__.py b/applications/Chat/coati/quant/llama_gptq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51d5233586ad91a8d092c7ee3095e608385bce8d --- /dev/null +++ b/applications/Chat/coati/quant/llama_gptq/__init__.py @@ -0,0 +1,5 @@ +from .loader import load_quant + +__all__ = [ + "load_quant", +] diff --git a/applications/Chat/coati/quant/llama_gptq/loader.py b/applications/Chat/coati/quant/llama_gptq/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..50486337a7ab55e148ad2c1530dca3491669b587 --- /dev/null +++ b/applications/Chat/coati/quant/llama_gptq/loader.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn + +from .model_utils import find_layers +from .quant import make_quant + + +def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int): + model = model.eval() + layers = find_layers(model) + + # ignore lm head + layers = find_layers(model) + for name in ["lm_head"]: + if name in layers: + del layers[name] + + make_quant(model, layers, wbits, groupsize) + + if checkpoint.endswith(".safetensors"): + from safetensors.torch import load_file as safe_load + + model.load_state_dict(safe_load(checkpoint)) + else: + model.load_state_dict(torch.load(checkpoint)) + + return model diff --git a/applications/Chat/coati/quant/llama_gptq/model_utils.py b/applications/Chat/coati/quant/llama_gptq/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..18e4e476150019b231781fea059b9a4d8232bbab --- /dev/null +++ b/applications/Chat/coati/quant/llama_gptq/model_utils.py @@ -0,0 +1,12 @@ +# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py + +import torch.nn as nn + + +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) + return res diff --git a/applications/Chat/inference/llama_gptq/quant.py b/applications/Chat/coati/quant/llama_gptq/quant.py similarity index 90% rename from applications/Chat/inference/llama_gptq/quant.py rename to applications/Chat/coati/quant/llama_gptq/quant.py index f7d5b7ce4bd8217bf246abbef0736c78be3869a6..5a7e2e72dfc55faa9c1e9172e43a6a7e91fb228f 100644 --- a/applications/Chat/inference/llama_gptq/quant.py +++ b/applications/Chat/coati/quant/llama_gptq/quant.py @@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq): class Quantizer(nn.Module): - def __init__(self, shape=1): super(Quantizer, self).__init__() - self.register_buffer('maxq', torch.tensor(0)) - self.register_buffer('scale', torch.zeros(shape)) - self.register_buffer('zero', torch.zeros(shape)) + self.register_buffer("maxq", torch.tensor(0)) + self.register_buffer("scale", torch.zeros(shape)) + self.register_buffer("zero", torch.zeros(shape)) - def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8): + def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8): self.maxq = torch.tensor(2**bits - 1) self.perchannel = perchannel self.sym = sym @@ -68,7 +67,7 @@ class Quantizer(nn.Module): self.zero = torch.round(-xmin / self.scale) if self.mse: - best = torch.full([x.shape[0]], float('inf'), device=dev) + best = torch.full([x.shape[0]], float("inf"), device=dev) for i in range(int(self.maxshrink * self.grid)): p = 1 - i / self.grid xmin1 = p * xmin @@ -123,13 +122,12 @@ class Quantizer(nn.Module): try: import quant_cuda except: - print('CUDA extension not installed.') + print("CUDA extension not installed.") # Assumes layer is perfectly divisible into 256 * 256 blocks class QuantLinear(nn.Module): - def __init__(self, bits, groupsize, infeatures, outfeatures): super().__init__() if bits not in [2, 3, 4, 8]: @@ -142,11 +140,11 @@ class QuantLinear(nn.Module): groupsize = groupsize if groupsize != -1 else infeatures self.groupsize = groupsize self.register_buffer( - 'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), - dtype=torch.int)) - self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures))) - self.register_buffer('bias', torch.zeros(outfeatures)) - self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)) + "qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int) + ) + self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures))) + self.register_buffer("bias", torch.zeros(outfeatures)) + self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)) self._initialized_quant_state = False def pack(self, linear, scales, zeros): @@ -161,8 +159,10 @@ class QuantLinear(nn.Module): for idx in range(self.infeatures): g_idx = idx // self.groupsize intweight.append( - torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, - None]) + torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[ + :, None + ] + ) intweight = torch.cat(intweight, dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(np.uint32) @@ -271,13 +271,13 @@ class QuantLinear(nn.Module): return y.reshape(outshape) -def make_quant(module, names, bits, groupsize, name=''): +def make_quant(module, names, bits, groupsize, name=""): if isinstance(module, QuantLinear): return for attr in dir(module): tmp = getattr(module, attr) - name1 = name + '.' + attr if name != '' else attr + name1 = name + "." + attr if name != "" else attr if name1 in names: setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)) for name1, child in module.named_children(): - make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) + make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1) diff --git a/applications/Chat/coati/quant/utils.py b/applications/Chat/coati/quant/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d102bb30f52dff6e7996a2d4c1d37f5a4c6d0815 --- /dev/null +++ b/applications/Chat/coati/quant/utils.py @@ -0,0 +1,27 @@ +from contextlib import contextmanager + +import torch + + +def _noop(*args, **kwargs): + pass + + +@contextmanager +def low_resource_init(): + """This context manager disables weight initialization and sets the default float dtype to half.""" + old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_ + old_uniform_ = torch.nn.init.uniform_ + old_normal_ = torch.nn.init.normal_ + dtype = torch.get_default_dtype() + try: + torch.nn.init.kaiming_uniform_ = _noop + torch.nn.init.uniform_ = _noop + torch.nn.init.normal_ = _noop + torch.set_default_dtype(torch.half) + yield + finally: + torch.nn.init.kaiming_uniform_ = old_kaiming_uniform_ + torch.nn.init.uniform_ = old_uniform_ + torch.nn.init.normal_ = old_normal_ + torch.set_default_dtype(dtype) diff --git a/applications/Chat/coati/ray/README.md b/applications/Chat/coati/ray/README.md new file mode 100644 index 0000000000000000000000000000000000000000..79b1db3478279cd55a9a70245f445c75420c6b72 --- /dev/null +++ b/applications/Chat/coati/ray/README.md @@ -0,0 +1,175 @@ +:warning: **This content may be outdated since the major update of Colossal Chat. We will update this content soon.** + +# Distributed PPO Training on Stage 3 + +## Detach Experience Makers and Trainers + +We can completely separate the trainers and makers. + +

        + +

        + +- The experience maker performs inference, produces experience, and remotely delivers it to the trainer (1). +- The trainer consumes experience to train models, and periodically transmits new model parameters to the maker (2.1, 2.2). +- Using an experience buffer to overlap transmission and computing. + +In this manner, each node will work continuously without model idle time, and different optimization strategies can be applied for inference and training to meet the needs of speed or storage. It is also helpful for scalability. + +`DetachedPPOTrainer` and `ExperienceMakerHolder` are Ray Actors (distinguished from Actor Model), representing Trainer and Experience Maker on the graph above, respectively. + +[More about Ray Core](https://docs.ray.io/en/latest/ray-core/walkthrough.html) + +## Usage + +See examples at `ColossalAI/application/Chat/examples/ray` + +### Setup Makers + +- define makers' environment variables : + + ```python + env_info_makers = [{ + 'local_rank': '0', + 'rank': str(rank), + 'world_size': str(num_makers), + 'master_port': maker_port, + 'master_addr': master_addr + } for rank in range(num_makers)] + + ``` + +- define maker models : + + ```python + def model_fn(): + actor = get_actor_from_args(...) + critic = get_critic_from_args(...) + reward_model = get_reward_model_from_args(...) + initial_model = get_actor_from_args(...) + return actor, critic, reward_model, initial_model + + ``` + +- set experience_holder_refs : + + ```python + experience_holder_refs = [ + ExperienceMakerHolder.options( + name=f"maker_{i}", + num_gpus=1, + max_concurrency=2 + ).remote( + detached_trainer_name_list=[f"trainer_{x}" for x in target_trainers(...)], + model_fn=model_fn, + ...) + for i, env_info_maker in enumerate(env_info_makers) + ] + ``` + + The names in the `detached_trainer_name_list` refer to the target trainers that the maker should send experience to. + We set a trainer's name the same as a maker, by `.options(name="str")`. See below. + +### Setup Trainers + +- define trainers' environment variables : + ```python + env_info_trainers = [{ + 'local_rank': '0', + 'rank': str(rank), + 'world_size': str(num_trainers), + 'master_port': trainer_port, + 'master_addr': master_addr + } for rank in range(num_trainers)] + ``` +- define trainer models : + + ```python + def trainer_model_fn(): + actor = get_actor_from_args(...) + critic = get_critic_from_args(...) + return actor, critic + ``` + +- set trainer_refs : + ```python + trainer_refs = [ + DetachedPPOTrainer.options( + name=f"trainer{i}", + num_gpus=1, + max_concurrency=2 + ).remote( + experience_maker_holder_name_list=[f"maker{x}" for x in target_makers(...)], + model_fn = trainer_model_fn(), + ...) + for i, env_info_trainer in enumerate(env_info_trainers) + ] + ``` + The names in `experience_maker_holder_name_list` refer to the target makers that the trainer should send updated models to. + By setting `detached_trainer_name_list` and `experience_maker_holder_name_list`, we can customize the transmission graph. + +### Launch Jobs + +- define data_loader : + + ```python + def data_loader_fn(): + return = torch.utils.data.DataLoader(dataset=dataset) + + ``` + +- launch makers : + + ```python + wait_tasks = [] + for experience_holder_ref in experience_holder_refs: + wait_tasks.append( + experience_holder_ref.workingloop.remote(data_loader_fn(), + num_steps=experience_steps)) + + ``` + +- launch trainers : + + ```python + for trainer_ref in trainer_refs: + wait_tasks.append(trainer_ref.fit.remote(total_steps, update_steps, train_epochs)) + ``` + +- wait for done : + ```python + ray.get(wait_tasks) + ``` + +## Flexible Structure + +We can deploy different strategies to makers and trainers. Here are some notions. + +### 2 Makers 1 Trainer + +

        + +

        + +### 2 Makers 2 Trainer + +

        + +

        + +### Maker Inference Quantization + +

        + +

        + +### Tensor Parallel + +

        + +

        + +## TODO + +- [ ] Support LoRA +- [ ] Support TP & PP diff --git a/applications/Chat/coati/ray/__init__.py b/applications/Chat/coati/ray/__init__.py index 5802c05bc03feb3b755754de7385a2141e547673..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/applications/Chat/coati/ray/__init__.py +++ b/applications/Chat/coati/ray/__init__.py @@ -1,2 +0,0 @@ -from .src.detached_replay_buffer import DetachedReplayBuffer -from .src.detached_trainer_ppo import DetachedPPOTrainer diff --git a/applications/Chat/coati/ray/callbacks/__init__.py b/applications/Chat/coati/ray/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f5e488f383e3846b428ff15a8c013ea4df07663 --- /dev/null +++ b/applications/Chat/coati/ray/callbacks/__init__.py @@ -0,0 +1,9 @@ +from .base import MakerCallback, TrainerCallback +from .performance_evaluator import ExperienceMakerPerformanceEvaluator, TrainerPerformanceEvaluator + +__all__ = [ + "TrainerCallback", + "MakerCallback", + "ExperienceMakerPerformanceEvaluator", + "TrainerPerformanceEvaluator", +] diff --git a/applications/Chat/coati/ray/callbacks/base.py b/applications/Chat/coati/ray/callbacks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..8c5bd8a67776843493a0ca4d4412196a093259d7 --- /dev/null +++ b/applications/Chat/coati/ray/callbacks/base.py @@ -0,0 +1,65 @@ +from abc import ABC + +from coati.experience_maker import Experience + + +class TrainerCallback(ABC): + """ + Base callback class. It defines the interface for callbacks. + """ + + def on_fit_start(self) -> None: + pass + + def on_fit_end(self) -> None: + pass + + def on_episode_start(self, episode: int) -> None: + pass + + def on_episode_end(self, episode: int) -> None: + pass + + def on_epoch_start(self, epoch: int) -> None: + pass + + def on_epoch_end(self, epoch: int) -> None: + pass + + def on_batch_start(self) -> None: + pass + + def on_batch_end(self, metrics: dict, experience: Experience) -> None: + pass + + def on_update_start(self) -> None: + pass + + def on_update_end(self) -> None: + pass + + +class MakerCallback(ABC): + def on_loop_start(self) -> None: + pass + + def on_loop_end(self) -> None: + pass + + def on_make_experience_start(self) -> None: + pass + + def on_make_experience_end(self, experience: Experience) -> None: + pass + + def on_send_start(self) -> None: + pass + + def on_send_end(self) -> None: + pass + + def on_batch_start(self) -> None: + pass + + def on_batch_end(self) -> None: + pass diff --git a/applications/Chat/coati/ray/callbacks/performance_evaluator.py b/applications/Chat/coati/ray/callbacks/performance_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..18798bce7dcebe449ae280c86f4a219b56f3586c --- /dev/null +++ b/applications/Chat/coati/ray/callbacks/performance_evaluator.py @@ -0,0 +1,214 @@ +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +from coati.experience_maker import Experience + +from .base import MakerCallback, TrainerCallback + + +def get_world_size() -> int: + if dist.is_initialized(): + return dist.get_world_size() + return 1 + + +def print_rank_0(*args, **kwargs) -> None: + if not dist.is_initialized() or dist.get_rank() == 0: + print(*args, **kwargs) + + +@torch.no_grad() +def all_reduce_mean(x: float, world_size: int) -> float: + if world_size == 1: + return x + tensor = torch.tensor([x], device=torch.cuda.current_device()) + dist.all_reduce(tensor) + tensor = tensor / world_size + return tensor.item() + + +class Timer: + def __init__(self) -> None: + self.start_time: Optional[float] = None + self.duration: float = 0.0 + + def start(self) -> None: + self.start_time = time() + + def end(self) -> None: + self.duration += time() - self.start_time + + def reset(self) -> None: + self.duration = 0.0 + + +class ExperienceMakerPerformanceEvaluator(MakerCallback): + def __init__( + self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int + ) -> None: + super().__init__() + self.world_size = get_world_size() + self.actor_num_params = actor_num_params + self.critic_num_params = critic_num_params + self.initial_model_num_params = initial_model_num_params + self.reward_model_num_params = reward_model_num_params + + self.batch_timer = Timer() + self.send_timer = Timer() + self.make_experience_timer = Timer() + self.total_samples: int = 0 + self.make_experience_flop: int = 0 + + print_rank_0( + f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}" + ) + + def on_make_experience_start(self) -> None: + self.make_experience_timer.start() + + def on_make_experience_end(self, experience: Experience) -> None: + self.make_experience_timer.end() + + batch_size, seq_len = experience.sequences.shape + + self.total_samples += batch_size + + # actor generate + num_actions = experience.action_mask.size(1) + input_len = seq_len - num_actions + total_seq_len = (input_len + seq_len - 1) * num_actions / 2 + self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2 + # actor forward + self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2 + # critic forward + self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2 + # initial model forward + self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2 + # reward model forward + self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2 + + def on_send_start(self) -> None: + self.send_timer.start() + + def on_send_end(self) -> None: + self.send_timer.end() + + def on_batch_start(self) -> None: + self.batch_timer.start() + + def on_batch_end(self) -> None: + self.batch_timer.end() + + def on_loop_end(self) -> None: + avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size) + avg_overall_duration = all_reduce_mean(self.batch_timer.duration, self.world_size) + avg_send_duration = all_reduce_mean(self.send_timer.duration, self.world_size) + + avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12) + avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12) + avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size) + avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / ( + self.total_samples * self.world_size + ) + avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size) + + print_rank_0( + "Making Experience Performance Summary:\n" + + f"Throughput: {avg_throughput:.3f} samples/sec\n" + + f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n" + + f"Sample time (overall): {avg_time_per_sample:.3f} s\n" + + f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n" + + f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n" + ) + + +class TrainerPerformanceEvaluator(TrainerCallback): + def __init__( + self, + actor_num_params: int, + critic_num_params: int, + enable_grad_checkpoint: bool = False, + ignore_first_episodes: int = 1, + ) -> None: + super().__init__() + self.world_size = get_world_size() + self.actor_num_params = actor_num_params + self.critic_num_params = critic_num_params + self.enable_grad_checkpoint = enable_grad_checkpoint + self.ignore_first_episodes = ignore_first_episodes + self.ignore_this_episode = False + + self.episode_timer = Timer() + self.batch_timer = Timer() + self.update_timer = Timer() + self.total_samples: int = 0 + self.learn_flop: int = 0 + + print_rank_0( + f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}" + ) + + def on_episode_start(self, episodes: int) -> None: + self.ignore_this_episode = episodes < self.ignore_first_episodes + if self.ignore_this_episode: + return + self.episode_timer.start() + + def on_episode_end(self, episodes: int) -> None: + if self.ignore_this_episode: + return + self.episode_timer.end() + + def on_batch_start(self) -> None: + if self.ignore_this_episode: + return + self.batch_timer.start() + + def on_batch_end(self, metrics: dict, experience: Experience) -> None: + if self.ignore_this_episode: + return + self.batch_timer.end() + + batch_size, seq_len = experience.sequences.shape + + self.total_samples += batch_size + + # actor forward-backward, 3 means forward(1) + backward(2) + self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint)) + # critic forward-backward + self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint)) + + def on_update_start(self) -> None: + if self.ignore_this_episode: + return + self.update_timer.start() + + def on_update_end(self) -> None: + if self.ignore_this_episode: + return + self.update_timer.end() + + def on_fit_end(self) -> None: + if self.total_samples == 0: + print_rank_0("No samples are collected, skip trainer performance evaluation") + return + avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size) + avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size) + avg_episode_duration = all_reduce_mean(self.episode_timer.duration, self.world_size) + + avg_throughput = self.total_samples * self.world_size / (avg_episode_duration + 1e-12) + avg_learn_tflops = self.learn_flop / 1e12 / (avg_train_duration + 1e-12) + avg_time_per_sample = (avg_episode_duration + 1e-12) / (self.total_samples * self.world_size) + avg_train_time_per_sample = (avg_train_duration + 1e-12) / (self.total_samples * self.world_size) + avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size) + + print_rank_0( + "Learning Performance Summary:\n" + + f"Throughput: {avg_throughput:.3f} samples/sec\n" + + f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n" + + f"Sample time (overall): {avg_time_per_sample:.3f} s\n" + + f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n" + + f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n" + ) diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/Chat/coati/ray/detached_replay_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..92dab17292f7ba36de6c3e96d5fc8d6be17d6e89 --- /dev/null +++ b/applications/Chat/coati/ray/detached_replay_buffer.py @@ -0,0 +1,70 @@ +from typing import List + +import torch +from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch +from coati.experience_maker.base import Experience + +# from torch.multiprocessing import Queue +from ray.util.queue import Queue + + +class DetachedReplayBuffer: + """ + Detached replay buffer. Share Experience across workers on the same node. + Therefore, a trainer node is expected to have only one instance. + It is ExperienceMakerHolder's duty to call append(exp) method, remotely. + + Args: + sample_batch_size: Batch size when sampling. Exp won't enqueue until they formed a batch. + tp_world_size: Number of workers in the same tp group + limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0. + cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True. + """ + + def __init__(self, sample_batch_size: int, limit: int = 0) -> None: + self.sample_batch_size = sample_batch_size + self.limit = limit + self.items = Queue(self.limit, actor_options={"num_cpus": 1}) + self.batch_collector: List[BufferItem] = [] + + @torch.no_grad() + def append(self, experience: Experience) -> None: + """ + Expected to be called remotely. + """ + items = split_experience_batch(experience) + self.extend(items) + + @torch.no_grad() + def extend(self, items: List[BufferItem]) -> None: + """ + Expected to be called remotely. + """ + self.batch_collector.extend(items) + while len(self.batch_collector) >= self.sample_batch_size: + items = self.batch_collector[: self.sample_batch_size] + experience = make_experience_batch(items) + self.items.put(experience, block=True) + self.batch_collector = self.batch_collector[self.sample_batch_size :] + + def clear(self) -> None: + # self.items.close() + self.items.shutdown() + self.items = Queue(self.limit) + self.worker_state = [False] * self.tp_world_size + self.batch_collector = [] + + @torch.no_grad() + def sample(self, worker_rank=0, to_device="cpu") -> Experience: + ret = self._sample_and_erase() + ret.to_device(to_device) + return ret + + @torch.no_grad() + def _sample_and_erase(self) -> Experience: + ret = self.items.get(block=True) + return ret + + def get_length(self) -> int: + ret = self.items.qsize() + return ret diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/Chat/coati/ray/detached_trainer_base.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf0a472df9e19d51b6a64fea3389d5823fdad60 --- /dev/null +++ b/applications/Chat/coati/ray/detached_trainer_base.py @@ -0,0 +1,179 @@ +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +import ray +import torch +from coati.experience_buffer.utils import BufferItem +from coati.experience_maker import Experience +from torch.utils.data import DataLoader +from tqdm import tqdm + +from .callbacks import TrainerCallback +from .detached_replay_buffer import DetachedReplayBuffer +from .utils import is_rank_0 + + +class DetachedTrainer(ABC): + """ + Base class for detached rlhf trainers. + 'detach' means that the experience maker is detached compared to a normal Trainer. + Please set name attribute during init: + >>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote() + So an ExperienceMakerHolder can reach the detached_replay_buffer by Actor's name. + Args: + detached_strategy (DetachedStrategy): the strategy to use for training + detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training + data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + + """ + + def __init__( + self, + experience_maker_holder_name_list: List[str], + train_batch_size: int = 8, + buffer_limit: int = 0, + dataloader_pin_memory: bool = True, + callbacks: List[TrainerCallback] = [], + debug: bool = False, + ) -> None: + super().__init__() + self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit) + self.dataloader_pin_memory = dataloader_pin_memory + self.callbacks = callbacks + self.target_holder_name_list = experience_maker_holder_name_list + self.target_holder_list = [] + self._is_target_holder_initialized = False + self._debug = debug + + def update_target_holder_list(self): + # as the length of target_holder_list may be zero, we need to check it by a bool flag + if not self._is_target_holder_initialized: + for name in self.target_holder_name_list: + self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"])) + self._is_target_holder_initialized = True + + @abstractmethod + def _update_remote_makers(self, fully_update: bool = False, **kwargs): + pass + + def sync_models_to_remote_makers(self, **kwargs): + self._update_remote_makers(fully_update=True, **kwargs) + + @abstractmethod + def training_step(self, experience: Experience) -> Dict[str, Any]: + pass + + def _learn(self, update_steps: int, train_epochs: int) -> None: + data = [] + # warmup + pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0()) + self._on_epoch_start(0) + self._learn_epoch(pbar, data) + self._on_epoch_end(0) + # item is already a batch + dataloader = DataLoader( + data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0] + ) + for epoch in range(1, train_epochs): + pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0()) + self._on_epoch_start(epoch) + self._learn_epoch(pbar, data) + self._on_epoch_end(epoch) + + def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None: + is_warmup = len(data) == 0 + for x in pbar: + if self._debug: + print("[trainer] training step") + # sample a batch and then train to avoid waiting + experience = x if not is_warmup else self._buffer_sample() + experience.to_device(torch.cuda.current_device()) + self._on_batch_start() + metrics = self.training_step(experience) + self._on_batch_end(metrics, experience) + + if self._debug: + print("[trainer] step over") + experience.to_device("cpu") + if is_warmup: + data.append(experience) + pbar.set_postfix(metrics) + + def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None: + self._on_fit_start() + for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()): + self._on_episode_start(i) + self._learn(update_steps, train_epochs) + self._on_update_start() + self._update_remote_makers() + self._on_update_end() + self._on_episode_end(i) + self._on_fit_end() + + @ray.method(concurrency_group="buffer_length") + def buffer_get_length(self): + # called by ExperienceMakerHolder + if self._debug: + print("[trainer] telling length") + return self.detached_replay_buffer.get_length() + + @ray.method(concurrency_group="buffer_append") + def buffer_append(self, experience: Experience): + # called by ExperienceMakerHolder + if self._debug: + print(f"[trainer] receiving exp.") + self.detached_replay_buffer.append(experience) + + @ray.method(concurrency_group="buffer_append") + def buffer_extend(self, items: List[BufferItem]): + # called by ExperienceMakerHolder + if self._debug: + print(f"[trainer] receiving exp.") + self.detached_replay_buffer.extend(items) + + @ray.method(concurrency_group="buffer_sample") + def _buffer_sample(self): + return self.detached_replay_buffer.sample() + + def _on_fit_start(self) -> None: + for callback in self.callbacks: + callback.on_fit_start() + + def _on_fit_end(self) -> None: + for callback in self.callbacks: + callback.on_fit_end() + + def _on_episode_start(self, episode: int) -> None: + for callback in self.callbacks: + callback.on_episode_start(episode) + + def _on_episode_end(self, episode: int) -> None: + for callback in self.callbacks: + callback.on_episode_end(episode) + + def _on_epoch_start(self, epoch: int) -> None: + for callback in self.callbacks: + callback.on_epoch_start(epoch) + + def _on_epoch_end(self, epoch: int) -> None: + for callback in self.callbacks: + callback.on_epoch_end(epoch) + + def _on_batch_start(self) -> None: + for callback in self.callbacks: + callback.on_batch_start() + + def _on_batch_end(self, metrics: dict, experience: Experience) -> None: + for callback in self.callbacks: + callback.on_batch_end(metrics, experience) + + def _on_update_start(self) -> None: + for callback in self.callbacks: + callback.on_update_start() + + def _on_update_end(self) -> None: + for callback in self.callbacks: + callback.on_update_end() diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/Chat/coati/ray/detached_trainer_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..ef84a1ddba48d63c4debeb575fe76b77d75fbfae --- /dev/null +++ b/applications/Chat/coati/ray/detached_trainer_ppo.py @@ -0,0 +1,191 @@ +from typing import Callable, Dict, List, Tuple + +import ray +import torch +from coati.experience_maker import Experience +from coati.models.base import Actor, Critic +from coati.models.loss import PolicyLoss, ValueLoss +from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy +from torch.optim import Adam + +from colossalai.nn.optimizer import HybridAdam + +from .callbacks import TrainerCallback, TrainerPerformanceEvaluator +from .detached_trainer_base import DetachedTrainer +from .lora_constructor import LoRAConstructor +from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to + + +@ray.remote( + concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1} +) +class DetachedPPOTrainer(DetachedTrainer): + """ + Detached Trainer for PPO algorithm + Args: + strategy (Strategy): the strategy to use for training + model (str) : for actor / critic init + pretrained (str) : for actor / critic init + lora_rank (int) : for actor / critic init + train_batch_size (int, defaults to 8): the batch size to use for training + train_batch_size (int, defaults to 8): the batch size to use for training + buffer_limit (int, defaults to 0): the max_size limitation of replay buffer + buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu + eps_clip (float, defaults to 0.2): the clip coefficient of policy loss + value_clip (float, defaults to 0.4): the clip coefficient of value loss + experience_batch_size (int, defaults to 8): the batch size to use for experience generation + max_epochs (int, defaults to 1): the number of epochs of training process + dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + """ + + def __init__( + self, + experience_maker_holder_name_list: List[str], + strategy_fn: Callable[[], Strategy], + model_fn: Callable[[], Tuple[Actor, Critic]], + env_info: Dict[str, str] = None, + train_batch_size: int = 8, + buffer_limit: int = 0, + eps_clip: float = 0.2, + value_clip: float = 0.4, + dataloader_pin_memory: bool = True, + callbacks: List[TrainerCallback] = [], + eval_performance: bool = False, + debug: bool = False, + update_lora_weights: bool = False, + ) -> None: + # set environment variables + if env_info: + set_dist_env(env_info=env_info) + # configure strategy + self.strategy = strategy_fn() + # configure models, loss and optimizers + with self.strategy.model_init_context(): + self.actor, self.critic = model_fn() + + if eval_performance: + actor_numel = get_model_numel(self.actor) + critic_numel = get_model_numel(self.critic) + evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel) + callbacks = callbacks + [evaluator] + + if isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy)): + self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7) + self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7) + else: + self.actor_optim = Adam(self.actor.parameters(), lr=1e-7) + self.critic_optim = Adam(self.critic.parameters(), lr=1e-7) + + (self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare( + (self.actor, self.actor_optim), (self.critic, self.critic_optim) + ) + + # configure trainer + self.actor_loss_fn = PolicyLoss(eps_clip) + self.critic_loss_fn = ValueLoss(value_clip) + + super().__init__( + experience_maker_holder_name_list, + train_batch_size=train_batch_size, + buffer_limit=buffer_limit, + dataloader_pin_memory=dataloader_pin_memory, + callbacks=callbacks, + debug=debug, + ) + if self._debug: + print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}") + + self._update_lora_weights = update_lora_weights + + @ray.method(concurrency_group="model_io") + @torch.no_grad() + def _update_remote_makers(self, fully_update: bool = False, **config): + # TODO: balance duties + if not fully_update: + config["requires_grad_only"] = True + self.update_target_holder_list() + # mark start, ensure order + tasks = [] + for target_holder in self.target_holder_list: + tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update)) + ray.get(tasks) + # sending loop + tasks = [] + + for state_dict_shard in self._get_model_state_dict_shard(self.actor, fully_update=fully_update, **config): + for target_holder in self.target_holder_list: + tasks.append( + target_holder.update_experience_maker.remote( + new_actor_state_dict=state_dict_shard, + new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor), + fully_update=fully_update, + ) + ) + # sending loop + for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config): + for target_holder in self.target_holder_list: + tasks.append( + target_holder.update_experience_maker.remote( + new_critic_state_dict=state_dict_shard, + new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic), + fully_update=fully_update, + ) + ) + ray.get(tasks) + # mark end + for target_holder in self.target_holder_list: + target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update) + + @ray.method(concurrency_group="compute") + def training_step(self, experience: Experience) -> Dict[str, float]: + self.actor.train() + self.critic.train() + + num_actions = experience.action_mask.size(1) + action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask) + actor_loss = self.actor_loss_fn( + action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + ) + self.strategy.backward(actor_loss, self.actor, self.actor_optim) + self.strategy.optimizer_step(self.actor_optim) + self.actor_optim.zero_grad() + + values = self.critic( + experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask + ) + critic_loss = self.critic_loss_fn( + values, experience.values, experience.reward, action_mask=experience.action_mask + ) + + self.strategy.backward(critic_loss, self.critic, self.critic_optim) + self.strategy.optimizer_step(self.critic_optim) + self.critic_optim.zero_grad() + return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()} + + def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None: + self.strategy.save_model(self.actor, path, only_rank0) + + def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None: + self.strategy.save_model(self.critic, path, only_rank0) + + def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None: + self.strategy.save_optimizer(self.actor_optim, path, only_rank0) + + def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None: + self.strategy.save_optimizer(self.critic_optim, path, only_rank0) + + def _get_model_state_dict_shard(self, model: torch.nn.Module, fully_update=False, **config): + for state_dict in self.strategy.get_model_state_dict_shard(model, **config): + if not self._update_lora_weights or fully_update: + yield state_dict_to(state_dict) + else: + state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(state_dict) + yield state_dict_to(state_dict_lora) + + def _get_model_lora_config_dict(self, model: torch.nn.Module): + if not self._update_lora_weights: + return None + unwrapped_model = self.strategy.unwrap_model(model) + return LoRAConstructor.extract_lora_config(unwrapped_model) diff --git a/applications/Chat/coati/ray/example/1m1t.py b/applications/Chat/coati/ray/example/1m1t.py deleted file mode 100644 index a6527370505b9ce87a1985b8b3d45bcbf0c103a3..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/example/1m1t.py +++ /dev/null @@ -1,153 +0,0 @@ -import argparse -from copy import deepcopy - -import pandas as pd -import torch -from coati.trainer import PPOTrainer - - -from coati.ray.src.experience_maker_holder import ExperienceMakerHolder -from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer - -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy -from coati.experience_maker import NaiveExperienceMaker -from torch.optim import Adam -from transformers import AutoTokenizer, BloomTokenizerFast -from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer - -from colossalai.nn.optimizer import HybridAdam - -import ray -import os -import socket - -def get_free_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) - return s.getsockname()[1] - - -def get_local_ip(): - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', 80)) - return s.getsockname()[0] - -def main(args): - master_addr = str(get_local_ip()) - # trainer_env_info - trainer_port = str(get_free_port()) - env_info_trainer = {'local_rank' : '0', - 'rank' : '0', - 'world_size' : '1', - 'master_port' : trainer_port, - 'master_addr' : master_addr} - - # maker_env_info - maker_port = str(get_free_port()) - env_info_maker = {'local_rank' : '0', - 'rank' : '0', - 'world_size' : '1', - 'master_port' : maker_port, - 'master_addr' : master_addr} - - # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - else: - raise ValueError(f'Unsupported model "{args.model}"') - - # configure Trainer - trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote( - experience_maker_holder_name_list=["maker1"], - strategy=args.trainer_strategy, - model=args.model, - env_info = env_info_trainer, - pretrained=args.pretrain, - lora_rank=args.lora_rank, - train_batch_size=args.train_batch_size, - buffer_limit=16, - experience_batch_size=args.experience_batch_size, - max_epochs=args.max_epochs, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug=args.debug, - ) - - # configure Experience Maker - experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote( - detached_trainer_name_list=["trainer1"], - strategy=args.maker_strategy, - env_info = env_info_maker, - experience_batch_size=args.experience_batch_size, - kl_coef=0.1, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug=args.debug, - ) - - # trainer send its actor and critic to experience holders. - ray.get(trainer_ref.initialize_remote_makers.remote()) - - # configure sampler - dataset = pd.read_csv(args.prompt_path)['prompt'] - - def tokenize_fn(texts): - # MUST padding to max length to ensure inputs of all ranks have the same length - # Different length may lead to hang when using gemini, as different generation steps - batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) - return {k: v.cuda() for k, v in batch.items()} - - trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) - num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance - maker_done_ref = experience_holder_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) - - ray.get([trainer_done_ref, maker_done_ref]) - - # save model checkpoint after fitting - trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) - # save optimizer checkpoint on all ranks - if args.need_optim_ckpt: - trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('prompt_path') - parser.add_argument('--trainer_strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive') - parser.add_argument('--maker_strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--max_timesteps', type=int, default=10) - parser.add_argument('--update_timesteps', type=int, default=10) - parser.add_argument('--max_epochs', type=int, default=5) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - - parser.add_argument('--debug', action='store_true') - args = parser.parse_args() - ray.init(namespace=os.environ["RAY_NAMESPACE"]) - main(args) diff --git a/applications/Chat/coati/ray/example/1m1t.sh b/applications/Chat/coati/ray/example/1m1t.sh deleted file mode 100644 index f7c5054c800eb376ae973a5103482aec97b0511f..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/example/1m1t.sh +++ /dev/null @@ -1,23 +0,0 @@ -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ - | tail -n +2 \ - | nl -v 0 \ - | tee /dev/tty \ - | sort -g -k 2 \ - | awk '{print $1}' \ - | head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} - -set_n_least_used_CUDA_VISIBLE_DEVICES 2 - -export RAY_NAMESPACE="admin" - -python 1m1t.py "/path/to/prompts.csv" \ - --trainer_strategy colossalai_zero2 --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \ - --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ - --max_epochs 10 --debug diff --git a/applications/Chat/coati/ray/example/1m2t.py b/applications/Chat/coati/ray/example/1m2t.py deleted file mode 100644 index 3883c364a8e02fe0adcfa742599870d90f78631d..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/example/1m2t.py +++ /dev/null @@ -1,186 +0,0 @@ -import argparse -from copy import deepcopy - -import pandas as pd -import torch -from coati.trainer import PPOTrainer - - -from coati.ray.src.experience_maker_holder import ExperienceMakerHolder -from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer - -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy -from coati.experience_maker import NaiveExperienceMaker -from torch.optim import Adam -from transformers import AutoTokenizer, BloomTokenizerFast -from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer - -from colossalai.nn.optimizer import HybridAdam - -import ray -import os -import socket - - -def get_free_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) - return s.getsockname()[1] - - -def get_local_ip(): - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', 80)) - return s.getsockname()[0] - -def main(args): - master_addr = str(get_local_ip()) - # trainer_env_info - trainer_port = str(get_free_port()) - env_info_trainer_1 = {'local_rank' : '0', - 'rank' : '0', - 'world_size' : '2', - 'master_port' : trainer_port, - 'master_addr' : master_addr} - env_info_trainer_2 = {'local_rank' : '0', - 'rank' : '1', - 'world_size' : '2', - 'master_port' : trainer_port, - 'master_addr' : master_addr} - # maker_env_info - maker_port = str(get_free_port()) - env_info_maker_1 = {'local_rank' : '0', - 'rank' : '0', - 'world_size' : '2', - 'master_port' : maker_port, - 'master_addr' : master_addr} - print([env_info_trainer_1, - env_info_trainer_2, - env_info_maker_1]) - ray.init(dashboard_port = 1145) - # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - else: - raise ValueError(f'Unsupported model "{args.model}"') - - # configure Trainer - trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( - experience_maker_holder_name_list=["maker1"], - strategy=args.trainer_strategy, - model=args.model, - env_info=env_info_trainer_1, - pretrained=args.pretrain, - lora_rank=args.lora_rank, - train_batch_size=args.train_batch_size, - buffer_limit=16, - experience_batch_size=args.experience_batch_size, - max_epochs=args.max_epochs, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug=args.debug, - ) - - trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( - experience_maker_holder_name_list=["maker1"], - strategy=args.trainer_strategy, - model=args.model, - env_info=env_info_trainer_2, - pretrained=args.pretrain, - lora_rank=args.lora_rank, - train_batch_size=args.train_batch_size, - buffer_limit=16, - experience_batch_size=args.experience_batch_size, - max_epochs=args.max_epochs, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug= args.debug, - ) - - # configure Experience Maker - experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( - detached_trainer_name_list=["trainer1", "trainer2"], - strategy=args.maker_strategy, - env_info=env_info_maker_1, - experience_batch_size=args.experience_batch_size, - kl_coef=0.1, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug=args.debug, - ) - - # trainer send its actor and critic to experience holders. - # TODO: balance duty - ray.get(trainer_1_ref.initialize_remote_makers.remote()) - - # configure sampler - dataset = pd.read_csv(args.prompt_path)['prompt'] - - def tokenize_fn(texts): - # MUST padding to max length to ensure inputs of all ranks have the same length - # Different length may lead to hang when using gemini, as different generation steps - batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) - return {k: v.cuda() for k, v in batch.items()} - - trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) - trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) - num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs * 2 + 3 # +3 for fault tolerance - maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) - - ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref]) - # save model checkpoint after fitting - trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) - trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) - # save optimizer checkpoint on all ranks - if args.need_optim_ckpt: - trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) - trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('prompt_path') - parser.add_argument('--trainer_strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive') - parser.add_argument('--maker_strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--max_timesteps', type=int, default=10) - parser.add_argument('--update_timesteps', type=int, default=10) - parser.add_argument('--max_epochs', type=int, default=5) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - - parser.add_argument('--debug', action='store_true') - args = parser.parse_args() - main(args) diff --git a/applications/Chat/coati/ray/example/1m2t.sh b/applications/Chat/coati/ray/example/1m2t.sh deleted file mode 100644 index 669f4141026c25ef405d2502a506d1bb019520ea..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/example/1m2t.sh +++ /dev/null @@ -1,23 +0,0 @@ -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ - | tail -n +2 \ - | nl -v 0 \ - | tee /dev/tty \ - | sort -g -k 2 \ - | awk '{print $1}' \ - | head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} - -set_n_least_used_CUDA_VISIBLE_DEVICES 2 - -export RAY_NAMESPACE="admin" - -python 1m2t.py "/path/to/prompts.csv" --model gpt2 \ - --maker_strategy naive --trainer_strategy ddp --lora_rank 2 \ - --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ - --max_epochs 10 #--debug \ No newline at end of file diff --git a/applications/Chat/coati/ray/example/2m1t.py b/applications/Chat/coati/ray/example/2m1t.py deleted file mode 100644 index b655de1ab1fa987987fbd5954b6c94db98b27c50..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/example/2m1t.py +++ /dev/null @@ -1,140 +0,0 @@ -import argparse -from copy import deepcopy - -import pandas as pd -import torch -from coati.trainer import PPOTrainer - - -from coati.ray.src.experience_maker_holder import ExperienceMakerHolder -from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer - -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy -from coati.experience_maker import NaiveExperienceMaker -from torch.optim import Adam -from transformers import AutoTokenizer, BloomTokenizerFast -from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer - -from colossalai.nn.optimizer import HybridAdam - -import ray -import os -import socket - - -def main(args): - # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - else: - raise ValueError(f'Unsupported model "{args.model}"') - - # configure Trainer - trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote( - experience_maker_holder_name_list=["maker1", "maker2"], - strategy=args.trainer_strategy, - model=args.model, - pretrained=args.pretrain, - lora_rank=args.lora_rank, - train_batch_size=args.train_batch_size, - buffer_limit=16, - experience_batch_size=args.experience_batch_size, - max_epochs=args.max_epochs, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug=args.debug, - ) - - # configure Experience Maker - experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote( - detached_trainer_name_list=["trainer1"], - strategy=args.maker_strategy, - experience_batch_size=args.experience_batch_size, - kl_coef=0.1, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug=args.debug, - ) - - experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", num_gpus=1, max_concurrency=2).remote( - detached_trainer_name_list=["trainer1"], - strategy=args.maker_strategy, - experience_batch_size=args.experience_batch_size, - kl_coef=0.1, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug=args.debug, - ) - - # trainer send its actor and critic to experience holders. - ray.get(trainer_ref.initialize_remote_makers.remote()) - - # configure sampler - dataset = pd.read_csv(args.prompt_path)['prompt'] - - def tokenize_fn(texts): - # MUST padding to max length to ensure inputs of all ranks have the same length - # Different length may lead to hang when using gemini, as different generation steps - batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) - return {k: v.cuda() for k, v in batch.items()} - - trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) - num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs // 2 + 3 # +3 for fault tolerance - maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) - maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) - - ray.get([trainer_done_ref, maker_1_done_ref, maker_2_done_ref]) - - # save model checkpoint after fitting - trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) - # save optimizer checkpoint on all ranks - if args.need_optim_ckpt: - trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('prompt_path') - parser.add_argument('--trainer_strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive') - parser.add_argument('--maker_strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--max_timesteps', type=int, default=10) - parser.add_argument('--update_timesteps', type=int, default=10) - parser.add_argument('--max_epochs', type=int, default=5) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - - parser.add_argument('--debug', action='store_true') - args = parser.parse_args() - ray.init(namespace=os.environ["RAY_NAMESPACE"]) - main(args) diff --git a/applications/Chat/coati/ray/example/2m1t.sh b/applications/Chat/coati/ray/example/2m1t.sh deleted file mode 100644 index a207d4118d605a8c3a17882bcf3cc6b1a8f32eb3..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/example/2m1t.sh +++ /dev/null @@ -1,23 +0,0 @@ -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ - | tail -n +2 \ - | nl -v 0 \ - | tee /dev/tty \ - | sort -g -k 2 \ - | awk '{print $1}' \ - | head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} - -set_n_least_used_CUDA_VISIBLE_DEVICES 3 - -export RAY_NAMESPACE="admin" - -python 2m1t.py "/path/to/prompts.csv" \ - --trainer_strategy naive --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \ - --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ - --max_epochs 10 # --debug diff --git a/applications/Chat/coati/ray/example/2m2t.py b/applications/Chat/coati/ray/example/2m2t.py deleted file mode 100644 index 435c71915fc2820b9b12bbd2204823f12ec2438f..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/example/2m2t.py +++ /dev/null @@ -1,209 +0,0 @@ -import argparse -from copy import deepcopy - -import pandas as pd -import torch -from coati.trainer import PPOTrainer - - -from coati.ray.src.experience_maker_holder import ExperienceMakerHolder -from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer - -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy -from coati.experience_maker import NaiveExperienceMaker -from torch.optim import Adam -from transformers import AutoTokenizer, BloomTokenizerFast -from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer - -from colossalai.nn.optimizer import HybridAdam - -import ray -import os -import socket - - -def get_free_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) - return s.getsockname()[1] - - -def get_local_ip(): - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', 80)) - return s.getsockname()[0] - -def main(args): - master_addr = str(get_local_ip()) - # trainer_env_info - trainer_port = str(get_free_port()) - env_info_trainer_1 = {'local_rank' : '0', - 'rank' : '0', - 'world_size' : '2', - 'master_port' : trainer_port, - 'master_addr' : master_addr} - env_info_trainer_2 = {'local_rank' : '0', - 'rank' : '1', - 'world_size' : '2', - 'master_port' : trainer_port, - 'master_addr' : master_addr} - # maker_env_info - maker_port = str(get_free_port()) - env_info_maker_1 = {'local_rank' : '0', - 'rank' : '0', - 'world_size' : '2', - 'master_port' : maker_port, - 'master_addr' : master_addr} - env_info_maker_2 = {'local_rank' : '0', - 'rank' : '1', - 'world_size' : '2', - 'master_port': maker_port, - 'master_addr' : master_addr} - print([env_info_trainer_1, - env_info_trainer_2, - env_info_maker_1, - env_info_maker_2]) - ray.init() - # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - else: - raise ValueError(f'Unsupported model "{args.model}"') - - # configure Trainer - trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( - experience_maker_holder_name_list=["maker1", "maker2"], - strategy=args.trainer_strategy, - model=args.model, - env_info=env_info_trainer_1, - pretrained=args.pretrain, - lora_rank=args.lora_rank, - train_batch_size=args.train_batch_size, - buffer_limit=16, - experience_batch_size=args.experience_batch_size, - max_epochs=args.max_epochs, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug=args.debug, - ) - - trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( - experience_maker_holder_name_list=["maker1", "maker2"], - strategy=args.trainer_strategy, - model=args.model, - env_info=env_info_trainer_2, - pretrained=args.pretrain, - lora_rank=args.lora_rank, - train_batch_size=args.train_batch_size, - buffer_limit=16, - experience_batch_size=args.experience_batch_size, - max_epochs=args.max_epochs, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug=args.debug, - ) - - # configure Experience Maker - experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( - detached_trainer_name_list=["trainer1", "trainer2"], - strategy=args.maker_strategy, - env_info=env_info_maker_1, - experience_batch_size=args.experience_batch_size, - kl_coef=0.1, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug=args.debug, - ) - - experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( - detached_trainer_name_list=["trainer1", "trainer2"], - strategy=args.maker_strategy, - env_info=env_info_maker_2, - experience_batch_size=args.experience_batch_size, - kl_coef=0.1, - #kwargs: - max_length=128, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - debug=args.debug, - ) - - # trainer send its actor and critic to experience holders. - # TODO: balance duty - ray.get(trainer_1_ref.initialize_remote_makers.remote()) - - # configure sampler - dataset = pd.read_csv(args.prompt_path)['prompt'] - - def tokenize_fn(texts): - # MUST padding to max length to ensure inputs of all ranks have the same length - # Different length may lead to hang when using gemini, as different generation steps - batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) - return {k: v.cuda() for k, v in batch.items()} - - trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) - trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) - num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance - maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) - maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) - - ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref, maker_2_done_ref]) - # save model checkpoint after fitting - trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) - trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) - # save optimizer checkpoint on all ranks - if args.need_optim_ckpt: - trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) - trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('prompt_path') - parser.add_argument('--trainer_strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive') - parser.add_argument('--maker_strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--max_timesteps', type=int, default=10) - parser.add_argument('--update_timesteps', type=int, default=10) - parser.add_argument('--max_epochs', type=int, default=5) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - - parser.add_argument('--debug', action='store_true') - args = parser.parse_args() - main(args) diff --git a/applications/Chat/coati/ray/example/2m2t.sh b/applications/Chat/coati/ray/example/2m2t.sh deleted file mode 100644 index fb4024766c54182efcd752f991bfc11c1db7588e..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/example/2m2t.sh +++ /dev/null @@ -1,23 +0,0 @@ -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ - | tail -n +2 \ - | nl -v 0 \ - | tee /dev/tty \ - | sort -g -k 2 \ - | awk '{print $1}' \ - | head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} - -set_n_least_used_CUDA_VISIBLE_DEVICES 2 - -export RAY_NAMESPACE="admin" - -python 2m2t.py "path/to/prompts.csv" \ - --maker_strategy naive --trainer_strategy colossalai_zero2 --lora_rank 2 \ - --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ - --max_epochs 10 --debug \ No newline at end of file diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py new file mode 100644 index 0000000000000000000000000000000000000000..4d290f4aba886434dfa86b12e7db05c14ac15d9c --- /dev/null +++ b/applications/Chat/coati/ray/experience_maker_holder.py @@ -0,0 +1,274 @@ +import os +import time +import tracemalloc +from threading import Lock +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union + +import ray +import torch +from coati.experience_buffer.utils import split_experience_batch +from coati.experience_maker import Experience, NaiveExperienceMaker +from coati.models.base import Actor, Critic, RewardModel +from coati.trainer.strategies import Strategy +from torch import Tensor +from tqdm import tqdm + +from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback +from .lora_constructor import LoRAConstructor +from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to + + +@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1}) +class ExperienceMakerHolder: + """ + Args: + detached_trainer_name_list: str list to get ray actor handles + strategy: + kl_coef: the coefficient of kl divergence loss + sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models. + """ + + def __init__( + self, + detached_trainer_name_list: List[str], + strategy_fn: Callable[[], Strategy], + # a function returns (actor, critic, reward_model, initial_model) + model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]], + env_info: Dict[str, str] = None, + sync_models_from_trainers: bool = False, + buffer_cpu_offload: bool = True, + kl_coef: float = 0.1, + callbacks: List[MakerCallback] = [], + eval_performance: bool = False, + debug: bool = False, + update_lora_weights: bool = False, + **generate_kwargs, + ): + # set environment variables + if env_info: + set_dist_env(env_info=env_info) + self.target_trainer_list = [] + assert len(detached_trainer_name_list) > 0 + self._detached_trainer_name_list = detached_trainer_name_list + self.strategy = strategy_fn() + self.buffer_cpu_offload = buffer_cpu_offload + self.kl_coef = kl_coef + # init models + with self.strategy.model_init_context(): + actor, critic, reward_model, initial_model = model_fn() + self.generate_kwargs = _set_default_generate_kwargs(generate_kwargs, actor) + if eval_performance: + actor_numel = get_model_numel(actor) + critic_numel = get_model_numel(critic) + initial_model_numel = get_model_numel(initial_model) + reward_model_numel = get_model_numel(reward_model) + evaluator = ExperienceMakerPerformanceEvaluator( + actor_numel, critic_numel, initial_model_numel, reward_model_numel + ) + callbacks = callbacks + [evaluator] + + actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model) + self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef) + self.callbacks = callbacks + + self._model_visit_lock = Lock() + + self._is_fully_initialized = not sync_models_from_trainers + + self._debug = debug + self._update_lora_weights = update_lora_weights + if self._update_lora_weights: + self.actor_lora_constructor = LoRAConstructor() + self.critic_lora_constructor = LoRAConstructor() + + self.target_auto_balance = False + + self._target_idx = 0 + + if self._debug: + print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}") + if not self._is_fully_initialized: + print(f"[maker{get_rank()}] Waiting for INIT") + + def _get_ready(self): + while not self._fully_initialized(): + time.sleep(1.0) + + def _fully_initialized(self): + return self._is_fully_initialized + + def _init_target_trainer_list(self): + if len(self.target_trainer_list) > 0: + return + for name in self._detached_trainer_name_list: + self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"])) + + # copy from ../trainer/base.py + @ray.method(concurrency_group="compute") + def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience: + if isinstance(inputs, Tensor): + return self.experience_maker.make_experience(inputs, **self.generate_kwargs) + elif isinstance(inputs, dict): + return self.experience_maker.make_experience(**inputs, **self.generate_kwargs) + else: + raise ValueError(f'Unsupported input type "{type(inputs)}"') + + @ray.method(concurrency_group="experience_io") + def _send_items(self, experience: Experience) -> None: + self._init_target_trainer_list() + items = split_experience_batch(experience) + items_per_trainer = [[] for _ in range(len(self.target_trainer_list))] + for item in items: + items_per_trainer[self._target_idx].append(item) + self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list) + for i, target_trainer in enumerate(self.target_trainer_list): + if len(items_per_trainer[i]) > 0: + target_trainer.buffer_extend.remote(items_per_trainer[i]) + + def _inference_step(self, batch) -> None: + self._on_batch_start() + with self._model_visit_lock: + self._on_make_experience_start() + experience = self._make_experience(batch) + self._on_make_experience_end(experience) + self._on_send_start() + if self.buffer_cpu_offload: + experience.to_device("cpu") + self._send_items(experience) + self._on_send_end() + self._on_batch_end() + + def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1, num_steps: int = 0): + """Working loop of the experience maker. + + Args: + dataloader_fn (Callable[[], Iterable]): A function that returns a dataloader. + num_epochs (int, optional): Iterate the dataloader for number of epochs. Defaults to 1. + num_steps (int, optional): Iterate the dataloader for number if steps. If this value > 0, num_epochs will be ignored. Defaults to 0. + """ + self._get_ready() + self._on_loop_start() + dataloader = dataloader_fn() + if num_steps > 0: + # ignore num epochs + it = iter(dataloader) + for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()): + try: + batch = next(it) + except StopIteration: + it = iter(dataloader) + batch = next(it) + self._inference_step(batch) + else: + with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar: + for _ in range(num_epochs): + for batch in dataloader: + self._inference_step(batch) + pbar.update() + self._on_loop_end() + + @ray.method(concurrency_group="model_io") + def update_experience_maker( + self, + new_actor_state_dict: Dict[str, Any] = None, + new_actor_lora_config_dict: Dict[str, Any] = None, + new_critic_state_dict: Dict[str, Any] = None, + new_critic_lora_config_dict: Dict[str, Any] = None, + fully_update: bool = False, + chunk_start: bool = None, + chunk_end: bool = None, + ): + """ + called by trainer + chunk_start: Set True at the first call. Before sending state_dict calls + chunk_end: Set True at the last call. After sending state_dict calls. + fully_update: Set True if you want to sync models when initializing + + TODO: load_state_dict integrate with model-sharding strategy + """ + _watch_memory = self._debug + if chunk_start: + if self._debug: + print("[maker] UPDATE ") + if _watch_memory: + tracemalloc.start() + self._model_visit_lock.acquire() + + with torch.no_grad(): + if new_actor_state_dict is not None: + if not self._update_lora_weights or fully_update: + self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False) + else: + new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device()) + state_dict_increase = self.actor_lora_constructor.reconstruct_increase( + new_actor_state_dict, new_actor_lora_config_dict + ) + self.actor_lora_constructor.load_state_dict_increase( + self.experience_maker.actor.model, state_dict_increase + ) + if new_critic_state_dict is not None: + if not self._update_lora_weights or fully_update: + self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False) + else: + new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device()) + state_dict_increase = self.critic_lora_constructor.reconstruct_increase( + new_critic_state_dict, new_critic_lora_config_dict + ) + self.critic_lora_constructor.load_state_dict_increase( + self.experience_maker.critic, state_dict_increase + ) + + # the lock must be released after both actor and critic being updated + if chunk_end: + self._model_visit_lock.release() + if _watch_memory: + current, peak = tracemalloc.get_traced_memory() + print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB") + tracemalloc.stop() + if fully_update: + self._is_fully_initialized = True + + def _on_make_experience_start(self) -> None: + for callback in self.callbacks: + callback.on_make_experience_start() + + def _on_make_experience_end(self, experience: Experience) -> None: + for callback in self.callbacks: + callback.on_make_experience_end(experience) + + def _on_loop_start(self) -> None: + for callback in self.callbacks: + callback.on_loop_start() + + def _on_loop_end(self) -> None: + for callback in self.callbacks: + callback.on_loop_end() + + def _on_send_start(self) -> None: + for callback in self.callbacks: + callback.on_send_start() + + def _on_send_end(self) -> None: + for callback in self.callbacks: + callback.on_send_end() + + def _on_batch_start(self) -> None: + for callback in self.callbacks: + callback.on_batch_start() + + def _on_batch_end(self) -> None: + for callback in self.callbacks: + callback.on_batch_end() + + +def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None: + origin_model = actor.model + new_kwargs = {**generate_kwargs} + # use huggingface models method directly + if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"): + new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation + + if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"): + new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation + + return new_kwargs diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/Chat/coati/ray/lora_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..8e9f78700e29fcc897245d486fd8c7c2bc1e95bc --- /dev/null +++ b/applications/Chat/coati/ray/lora_constructor.py @@ -0,0 +1,123 @@ +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Dict + +import torch.nn as nn +from coati.models.lora import LoraLinear + + +@dataclass +class LoRAConfig: + r: int = 0 + lora_alpha: int = 1 + lora_dropout: float = 0 + fan_in_fan_out: bool = False + + +class LoRAConstructor: + """ + Tools for reconstructing a model from a remote LoRA model. + (Transferring only LoRA data costs much less!) + Usage: + Step 1 (Sender): + filter_state_dict_lora() + + Step 2 (Sender, Optional): + extract_lora_config() + + Step 3 (Sender): + send state_dict_lora and lora_config_dict + + Step 4 (Receiver): + reconstruct_increase() + + Step 5 (Receiver): + load_state_dict_increase() + + """ + + def __init__(self): + self.lora_config_dict = None + + def register_lora_config(self, lora_config_dict: Dict[str, Any]): + self.lora_config_dict = lora_config_dict + + def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]): + """ + xxx.lora_A, xxx.lora_B -->> xxx.weight + Warning: the xxx.weight here is the increment actually. + """ + if lora_config_dict is not None: + self.register_lora_config(lora_config_dict) + + state_dict_increase = OrderedDict() + config_iter = iter(self.lora_config_dict.items()) + lora_A, lora_B, layer_prefix = None, None, None + for k, v in state_dict_lora.items(): + if k.rpartition(".")[-1] == "lora_A": + lora_A = v + layer_prefix = k.rpartition(".")[0] + elif k.rpartition(".")[-1] == "lora_B": + assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair" + layer_prefix_2, config = next(config_iter) + assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair" + lora_B = v + weight_data_increase = self._compute(lora_A, lora_B, config) + state_dict_increase[layer_prefix + ".weight"] = weight_data_increase + lora_A, lora_B, layer_prefix = None, None, None + else: + raise ValueError("unexpected key") + return state_dict_increase + + def _compute(self, lora_A, lora_B, config=LoRAConfig()): + def T(w): + return w.T if config.fan_in_fan_out else w + + if config.r > 0: + scaling = config.lora_alpha / config.r + weight_data_increase = T(lora_B @ lora_A) * scaling + return weight_data_increase + return 0 + + def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]): + """ + The final reconstruction step + """ + # naive approach + model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False) + + @staticmethod + def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False): + """ + if keep_non_lora, also return non_lora state_dict + """ + state_dict_lora = OrderedDict() + state_dict_non_lora = OrderedDict() + for k, v in state_dict.items(): + if "lora_A" in k or "lora_B" in k: + state_dict_lora[k] = v + elif keep_non_lora: + state_dict_non_lora[k] = v + if keep_non_lora: + return state_dict_lora, state_dict_non_lora + else: + return state_dict_lora, None + + @staticmethod + def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]: + """ + extract LoraLinear model. + return OrderedDict(): name -> LoRAConfig + """ + lora_config_dict = OrderedDict() + + for name, child in model.named_modules(): + if isinstance(child, LoraLinear): + lora_config_dict[name] = LoRAConfig( + r=child.r, + lora_alpha=child.lora_alpha, + lora_dropout=child.lora_dropout, + fan_in_fan_out=child.fan_in_fan_out, + ) + + return lora_config_dict diff --git a/applications/Chat/coati/ray/src/detached_replay_buffer.py b/applications/Chat/coati/ray/src/detached_replay_buffer.py deleted file mode 100644 index 855eee48c5a5ccb252ccd3e697073d728afcd3df..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/src/detached_replay_buffer.py +++ /dev/null @@ -1,88 +0,0 @@ -import torch -import random -from typing import List, Any -# from torch.multiprocessing import Queue -from ray.util.queue import Queue -import ray -import asyncio -from coati.experience_maker.base import Experience -from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch -from coati.replay_buffer import ReplayBuffer -from threading import Lock -import copy - -class DetachedReplayBuffer: - ''' - Detached replay buffer. Share Experience across workers on the same node. - Therefore a trainer node is expected to have only one instance. - It is ExperienceMakerHolder's duty to call append(exp) method, remotely. - - Args: - sample_batch_size: Batch size when sampling. Exp won't enqueue until they formed a batch. - tp_world_size: Number of workers in the same tp group - limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0. - cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True. - ''' - - def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit : int = 0, cpu_offload: bool = True) -> None: - self.cpu_offload = cpu_offload - self.sample_batch_size = sample_batch_size - self.limit = limit - self.items = Queue(self.limit, actor_options={"num_cpus":1}) - self.batch_collector : List[BufferItem] = [] - - ''' - Workers in the same tp group share this buffer and need same sample for one step. - Therefore a held_sample should be returned tp_world_size times before it could be dropped. - worker_state records wheter a worker got the held_sample - ''' - self.tp_world_size = tp_world_size - self.worker_state = [False] * self.tp_world_size - self.held_sample = None - self._worker_state_lock = Lock() - - @torch.no_grad() - def append(self, experience: Experience) -> None: - ''' - Expected to be called remotely. - ''' - if self.cpu_offload: - experience.to_device(torch.device('cpu')) - items = split_experience_batch(experience) - self.batch_collector.extend(items) - while len(self.batch_collector) >= self.sample_batch_size: - items = self.batch_collector[:self.sample_batch_size] - experience = make_experience_batch(items) - self.items.put(experience, block=True) - self.batch_collector = self.batch_collector[self.sample_batch_size:] - - def clear(self) -> None: - # self.items.close() - self.items.shutdown() - self.items = Queue(self.limit) - self.worker_state = [False] * self.tp_world_size - self.batch_collector = [] - - @torch.no_grad() - def sample(self, worker_rank = 0, to_device = "cpu") -> Experience: - self._worker_state_lock.acquire() - if not any(self.worker_state): - self.held_sample = self._sample_and_erase() - self.worker_state[worker_rank] = True - if all(self.worker_state): - self.worker_state = [False] * self.tp_world_size - ret = self.held_sample - else: - ret = copy.deepcopy(self.held_sample) - self._worker_state_lock.release() - ret.to_device(to_device) - return ret - - @torch.no_grad() - def _sample_and_erase(self) -> Experience: - ret = self.items.get(block=True) - return ret - - def get_length(self) -> int: - ret = self.items.qsize() - return ret \ No newline at end of file diff --git a/applications/Chat/coati/ray/src/detached_trainer_base.py b/applications/Chat/coati/ray/src/detached_trainer_base.py deleted file mode 100644 index f1ed1ec71499a9e7f84c80a52e357ef4c2a92201..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/src/detached_trainer_base.py +++ /dev/null @@ -1,121 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Union -from tqdm import tqdm -from coati.trainer.callbacks import Callback -from coati.experience_maker import Experience -import ray -import os - -from .detached_replay_buffer import DetachedReplayBuffer -from .utils import is_rank_0 - -class DetachedTrainer(ABC): - ''' - Base class for detached rlhf trainers. - 'detach' means that the experience maker is detached compared to a normal Trainer. - Please set name attribute during init: - >>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote() - So an ExperienceMakerHolder can reach the detached_replay_buffer by Actor's name. - Args: - detached_strategy (DetachedStrategy): the strategy to use for training - detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training - experience_batch_size (int, defaults to 8): the batch size to use for experience generation - max_epochs (int, defaults to 1): the number of epochs of training process - data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader - callbacks (List[Callback], defaults to []): the callbacks to call during training process - generate_kwargs (dict, optional): the kwargs to use while model generating - ''' - - def __init__(self, - experience_maker_holder_name_list: List[str], - train_batch_size: int = 8, - buffer_limit: int = 0, - buffer_cpu_offload: bool = True, - experience_batch_size: int = 8, - max_epochs: int = 1, - dataloader_pin_memory: bool = True, - callbacks: List[Callback] = [], - **generate_kwargs) -> None: - super().__init__() - self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit, cpu_offload=buffer_cpu_offload) - self.experience_batch_size = experience_batch_size - self.max_epochs = max_epochs - self.dataloader_pin_memory = dataloader_pin_memory - self.callbacks = callbacks - self.generate_kwargs = generate_kwargs - self.target_holder_name_list = experience_maker_holder_name_list - self.target_holder_list = [] - - def update_target_holder_list(self, experience_maker_holder_name_list): - self.target_holder_name_list = experience_maker_holder_name_list - self.target_holder_list = [] - for name in self.target_holder_name_list: - self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"])) - - @abstractmethod - def _update_remote_makers(self): - pass - - @abstractmethod - def training_step(self, experience: Experience) -> Dict[str, Any]: - pass - - def _learn(self): - pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0()) - for _ in pbar: - if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: - print("[trainer] sampling exp") - experience = self._buffer_sample() - if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: - print("[trainer] training step") - metrics = self.training_step(experience) - if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: - print("[trainer] step over") - pbar.set_postfix(metrics) - - def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None: - self._on_fit_start() - for episode in range(num_episodes): - self._on_episode_start(episode) - for timestep in tqdm(range(max_timesteps // update_timesteps), - desc=f'Episode [{episode+1}/{num_episodes}]', - disable=not is_rank_0()): - self._learn() - self._update_remote_makers() - self._on_episode_end(episode) - self._on_fit_end() - - @ray.method(concurrency_group="buffer_length") - def buffer_get_length(self): - # called by ExperienceMakerHolder - if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: - print("[trainer] telling length") - return self.detached_replay_buffer.get_length() - - @ray.method(concurrency_group="buffer_append") - def buffer_append(self, experience: Experience): - # called by ExperienceMakerHolder - if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: - # print(f"[trainer] receiving exp. Current buffer length: {self.detached_replay_buffer.get_length()}") - print(f"[trainer] receiving exp.") - self.detached_replay_buffer.append(experience) - - @ray.method(concurrency_group="buffer_sample") - def _buffer_sample(self): - return self.detached_replay_buffer.sample() - - def _on_fit_start(self) -> None: - for callback in self.callbacks: - callback.on_fit_start() - - def _on_fit_end(self) -> None: - for callback in self.callbacks: - callback.on_fit_end() - - def _on_episode_start(self, episode: int) -> None: - for callback in self.callbacks: - callback.on_episode_start(episode) - - def _on_episode_end(self, episode: int) -> None: - for callback in self.callbacks: - callback.on_episode_end(episode) diff --git a/applications/Chat/coati/ray/src/detached_trainer_ppo.py b/applications/Chat/coati/ray/src/detached_trainer_ppo.py deleted file mode 100644 index 838e82d07f4addb8924e1f808c2a47fed363c1f8..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/src/detached_trainer_ppo.py +++ /dev/null @@ -1,192 +0,0 @@ -from typing import Any, Callable, Dict, List, Optional -import torch -from torch.optim import Adam - -from coati.experience_maker import Experience, NaiveExperienceMaker -from coati.models.base import Actor, Critic -from coati.models.generation_utils import update_model_kwargs_fn -from coati.models.loss import PolicyLoss, ValueLoss -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy -from coati.trainer.callbacks import Callback - -from colossalai.nn.optimizer import HybridAdam - -import ray - - -from .utils import is_rank_0, get_cuda_actor_critic_from_args, get_strategy_from_args, set_dist_env -from .detached_trainer_base import DetachedTrainer - - -@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append":1, "buffer_sample":1,"model_io": 1, "compute": 1}) -class DetachedPPOTrainer(DetachedTrainer): - ''' - Detached Trainer for PPO algorithm - Args: - strategy (Strategy): the strategy to use for training - model (str) : for actor / critic init - pretrained (str) : for actor / critic init - lora_rank (int) : for actor / critic init - train_batch_size (int, defaults to 8): the batch size to use for training - train_batch_size (int, defaults to 8): the batch size to use for training - buffer_limit (int, defaults to 0): the max_size limitation of replay buffer - buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu - eps_clip (float, defaults to 0.2): the clip coefficient of policy loss - value_clip (float, defaults to 0.4): the clip coefficient of value loss - experience_batch_size (int, defaults to 8): the batch size to use for experience generation - max_epochs (int, defaults to 1): the number of epochs of training process - dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader - callbacks (List[Callback], defaults to []): the callbacks to call during training process - generate_kwargs (dict, optional): the kwargs to use while model generating - ''' - - def __init__(self, - experience_maker_holder_name_list: List[str], - strategy: str, - model: str, - env_info: Dict[str, str] = None, - pretrained: str = None, - lora_rank: int = 0, - train_batch_size: int = 8, - buffer_limit: int = 0, - buffer_cpu_offload: bool = True, - eps_clip: float = 0.2, - value_clip: float = 0.4, - experience_batch_size: int = 8, - max_epochs: int = 10, - dataloader_pin_memory: bool = True, - callbacks: List[Callback] = [], - **generate_kwargs) -> None: - # set environment variables - if env_info: - set_dist_env(env_info=env_info) - # configure strategy - self.strategy = get_strategy_from_args(strategy) - # configure models, loss and optimizers - with self.strategy.model_init_context(): - self.actor, self.critic = get_cuda_actor_critic_from_args(model, pretrained, lora_rank) - - if strategy != 'colossalai_gemini': - self.actor.to(torch.float16).to(torch.cuda.current_device()) - self.critic.to(torch.float16).to(torch.cuda.current_device()) - - if strategy.startswith('colossalai'): - self.actor_optim = HybridAdam(self.actor.parameters(), lr=5e-6) - self.critic_optim = HybridAdam(self.critic.parameters(), lr=5e-6) - else: - self.actor_optim = Adam(self.actor.parameters(), lr=5e-6) - self.critic_optim = Adam(self.critic.parameters(), lr=5e-6) - - (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \ - self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim)) - generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor) - - self.actor_loss_fn = PolicyLoss(eps_clip) - self.critic_loss_fn = ValueLoss(value_clip) - - super().__init__(experience_maker_holder_name_list, - train_batch_size=train_batch_size, - buffer_limit=buffer_limit, - buffer_cpu_offload=buffer_cpu_offload, - experience_batch_size=experience_batch_size, - max_epochs=max_epochs, - dataloader_pin_memory=dataloader_pin_memory, - callbacks=callbacks, - **generate_kwargs) - - @ray.method(concurrency_group="model_io") - def _update_remote_makers(self): - # TODO: balance duties - if is_rank_0(): - self.update_target_holder_list(self.target_holder_name_list) - for target_holder in self.target_holder_list: - # TODO: reduce malloc - with torch.no_grad(): - ray.get(target_holder.update_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic())) - - @ray.method(concurrency_group="model_io") - def initialize_remote_makers(self): - # TODO: balance duties - if is_rank_0(): - self.update_target_holder_list(self.target_holder_name_list) - for target_holder in self.target_holder_list: - # TODO: reduce malloc - with torch.no_grad(): - ray.get(target_holder.initialize_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic())) - - @ray.method(concurrency_group="compute") - def training_step(self, experience: Experience) -> Dict[str, float]: - self.actor.train() - self.critic.train() - - experience.to_device(torch.cuda.current_device()) - num_actions = experience.action_mask.size(1) - action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask) - actor_loss = self.actor_loss_fn(action_log_probs, - experience.action_log_probs, - experience.advantages, - action_mask=experience.action_mask) - self.strategy.backward(actor_loss, self.actor, self.actor_optim) - self.strategy.optimizer_step(self.actor_optim) - self.actor_optim.zero_grad() - - values = self.critic(experience.sequences, - action_mask=experience.action_mask, - attention_mask=experience.attention_mask) - critic_loss = self.critic_loss_fn(values, - experience.values, - experience.reward, - action_mask=experience.action_mask) - - self.strategy.backward(critic_loss, self.critic, self.critic_optim) - self.strategy.optimizer_step(self.critic_optim) - self.critic_optim.zero_grad() - return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} - - def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None: - self.strategy.save_model(self.actor, path, only_rank0) - - def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None: - self.strategy.save_model(self.critic, path, only_rank0) - - def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None: - self.strategy.save_optimizer(self.actor_optim, path, only_rank0) - - def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None: - self.strategy.save_optimizer(self.critic_optim, path, only_rank0) - - def _get_unwrapped_actor(self): - if False: - pass - elif isinstance(self.strategy, ColossalAIStrategy): - ret = Actor(self.strategy._unwrap_model(self.actor)) - return ret - elif isinstance(self.strategy, DDPStrategy): - return Actor(self.strategy._unwrap_actor(self.actor)) - elif isinstance(self.strategy, NaiveStrategy): - return self.actor - - def _get_unwrapped_critic(self): - if False: - pass - elif isinstance(self.strategy, ColossalAIStrategy): - ret = self.strategy._unwrap_model(self.critic) - return ret - elif isinstance(self.strategy, DDPStrategy): - return self.critic.module - elif isinstance(self.strategy, NaiveStrategy): - return self.critic - - -def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: - origin_model = strategy._unwrap_actor(actor) - new_kwargs = {**generate_kwargs} - # use huggingface models method directly - if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): - new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation - - if 'update_model_kwargs_fn' not in generate_kwargs: - new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn - - return new_kwargs - \ No newline at end of file diff --git a/applications/Chat/coati/ray/src/experience_maker_holder.py b/applications/Chat/coati/ray/src/experience_maker_holder.py deleted file mode 100644 index 94e4a3d537a57733ad8dbf510e940ef6c9129f6d..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/src/experience_maker_holder.py +++ /dev/null @@ -1,172 +0,0 @@ -import torch -from typing import Any, Callable, Dict, List, Optional, Union -import ray -from ray.exceptions import GetTimeoutError -from torch import Tensor -import torch.nn as nn -from coati.models.base import Actor, Critic, RewardModel -from coati.trainer.strategies.sampler import DistributedSampler -from coati.trainer.strategies import Strategy -from coati.experience_maker import NaiveExperienceMaker, Experience, ExperienceMaker - -from copy import deepcopy -from threading import Lock -import time -import os - - -from .utils import is_rank_0, get_strategy_from_args, set_dist_env - - -@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1}) -class ExperienceMakerHolder: - ''' - Args: - detached_trainer_name_list: str list to get ray actor handleskkk - strategy: - experience_batch_size: batch size of generated experience - kl_coef: the coefficient of kl divergence loss - ''' - - def __init__(self, - detached_trainer_name_list: List[str], - strategy: str, - env_info: Dict[str, str] = None, - experience_batch_size: int = 8, - kl_coef: float = 0.1, - **generate_kwargs): - # set environment variables - if env_info: - set_dist_env(env_info=env_info) - self.target_trainer_list = [] - for name in detached_trainer_name_list: - self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"])) - self.strategy_str = strategy - self.strategy = get_strategy_from_args(strategy) - self.experience_batch_size = experience_batch_size - self.kl_coef = kl_coef - self.generate_kwargs = generate_kwargs - # Need a trainer to give an actor and a critic via initialize_experience_maker(...) - actor, critic, reward_model, initial_model = None, None, None, None - self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef) - self._model_visit_lock = Lock() - self.fully_initialized = False - if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: - print('[maker] Waiting for INIT') - - def _get_ready(self): - while not self.fully_initialized: - time.sleep(1.0) - - def update_target_trainer_list(self, detached_trainer_name_list): - self.target_trainer_list = [] - for name in detached_trainer_name_list: - self.target_trainer_list.append(ray.get_actor(name)) - - # copy from ../trainer/base.py - @ray.method(concurrency_group="compute") - def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience: - self._get_ready() - if isinstance(inputs, Tensor): - return self.experience_maker.make_experience(inputs, **self.generate_kwargs) - elif isinstance(inputs, dict): - return self.experience_maker.make_experience(**inputs, **self.generate_kwargs) - else: - raise ValueError(f'Unsupported input type "{type(inputs)}"') - - @ray.method(concurrency_group="experience_io") - def _send_experience(self, experience): - ''' - ignore it - - # choose a trainer that has the least experience batch in its detached_replay_buffer - chosen_trainer = None - min_length = None - if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: - print("[maker] choosing target trainer") - while chosen_trainer is None: - for target_trainer in self.target_trainer_list: - try: - temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1) - if min_length is None: - min_length = temp_length - chosen_trainer = target_trainer - else: - if temp_length < min_length: - min_length = temp_length - chosen_trainer = target_trainer - except GetTimeoutError: - pass - - if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: - print(f"[maker] sending exp to {chosen_trainer}") - chosen_trainer.buffer_append.remote(experience) - ''' - # - if not hasattr(self, "_target_idx"): - self._target_idx = 0 - chosen_trainer = self.target_trainer_list[self._target_idx] - if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: - print(f"[maker] sending exp to {chosen_trainer}") - chosen_trainer.buffer_append.remote(experience) - self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list) - - def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000): - self._get_ready() - sampler = self.strategy.setup_sampler(dataset) - for _ in range(times): - rand_prompts = sampler.sample(self.experience_batch_size) - if tokenizer is not None: - inputs = tokenizer(rand_prompts) - else: - inputs = rand_prompts - self._model_visit_lock.acquire() - experience = self._make_experience(inputs=inputs) - self._model_visit_lock.release() - self._send_experience(experience=experience) - - @ray.method(concurrency_group="model_io") - def initialize_experience_maker(self, init_actor: Actor, init_critic: Critic): - ''' - called by trainer. Only once. - ''' - # TODO: reduce malloc - if self.fully_initialized: - return - if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: - print('[maker] INIT') - with torch.no_grad(): - with self.strategy.model_init_context(): - actor = init_actor - critic = init_critic - initial_model = deepcopy(actor) - reward_model = RewardModel(deepcopy(critic.model), - deepcopy(critic.value_head)).to(torch.cuda.current_device()) - if self.strategy_str != 'colossalai_gemini': - actor.to(torch.float16).to(torch.cuda.current_device()) - critic.to(torch.float16).to(torch.cuda.current_device()) - initial_model.to(torch.float16).to(torch.cuda.current_device()) - reward_model.to(torch.float16).to(torch.cuda.current_device()) - - self.experience_maker.actor = self.strategy.prepare(actor) - self.experience_maker.critic = self.strategy.prepare(critic) - self.experience_maker.initial_model = self.strategy.prepare(initial_model) - self.experience_maker.reward_model = self.strategy.prepare(reward_model) - self.fully_initialized = True - - @ray.method(concurrency_group="model_io") - def update_experience_maker(self, new_actor: Actor, new_critic: Critic): - ''' - called by trainer - ''' - # TODO: reduce malloc - self._model_visit_lock.acquire() - with torch.no_grad(): - if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: - print("[maker] UPDATE ") - if self.strategy_str != 'colossalai_gemini': - new_actor.to(torch.float16).to(torch.cuda.current_device()) - new_critic.to(torch.float16).to(torch.cuda.current_device()) - self.experience_maker.actor = self.strategy.prepare(new_actor) - self.experience_maker.critic = self.strategy.prepare(new_critic) - self._model_visit_lock.release() diff --git a/applications/Chat/coati/ray/src/pipeline_strategy.py b/applications/Chat/coati/ray/src/pipeline_strategy.py deleted file mode 100644 index 1780839c62ee3477ce84cd412ccf7ac60ab57357..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/src/pipeline_strategy.py +++ /dev/null @@ -1,105 +0,0 @@ -# WIP - - -from coati.trainer.strategies import Strategy -from coati.trainer.strategies import NaiveStrategy -from coati.models.base import Actor, RewardModel, Critic - -import numpy as np -import torch -from torch._C._distributed_rpc import _is_current_rpc_agent_set - -import colossalai -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine -from colossalai.fx import ColoTracer -from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass -from colossalai.pipeline.middleware.adaptor import get_fx_topology - - -import os -from functools import partial -import random - -rpc_is_initialized = _is_current_rpc_agent_set - -class PipelineModel(torch.nn.Module): - ''' - Actor has 2 kinds of jobs: forward and generate. - better to just pipelinize the inner model - ''' - def __init__(self, - model: torch.nn.Module, - stage_num: int, - num_microbatches: int, - data_kwargs = None, - ): - super().__init__() - # create partition module - def create_partition_module(pp_rank:int, stage_num: int, model, data_kwargs): - model.eval() - tracer = ColoTracer() - meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} - graph = tracer.trace(root=model, meta_args=meta_args) - gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) - annotated_model = balanced_split_pass(gm, stage_num) - top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True) - topo = get_fx_topology(top_module) - for submodule in split_submodules: - if isinstance(submodule, torch.fx.GraphModule): - setattr(submodule, '_topo', topo) - return split_submodules[pp_rank + 1] - - def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int): - partition = create_partition_module(pp_rank, stage_num, model, data_kwargs) - return partition - self.inference_engine = OneFOneBPipelineEngine( - partition_fn=partial(partition, model, data_kwargs), - stage_num=stage_num, - num_microbatches=num_microbatches, - device='cuda', - ) - - def forward(self, - **model_inputs): - return self.inference_engine.forward_backward(**model_inputs, forward_only=True) - - - -class PPStrategy(NaiveStrategy): - """ - Strategy for Pipeline inference (inference only!) - - master node only - """ - def __init__( - self, - seed: int = 42 - ): - self.seed = seed - super().__init__() - - - def setup_distributed(self) -> None: - colossalai.launch_from_torch({}, seed=self.seed) - ppg.set_global_info(rank = int(os.environ['RANK']), - world_size=int(os.environ['WORLD_SIZE']), - dp_degree=1, - tp_degree=1, - num_worker_threads=128, - device="cuda") - - def model_init_context(self): - return super().model_init_context() - - def setup_model(self, model: torch.nn.Module) -> torch.nn.Module: - if isinstance(model, Actor) or \ - isinstance(model, RewardModel) or \ - isinstance(model, Critic): - model.model = PipelineModel(model.model) - - def set_seed(self, seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - diff --git a/applications/Chat/coati/ray/src/utils.py b/applications/Chat/coati/ray/src/utils.py deleted file mode 100644 index c750879b6d187b90bef14c15ddf50cd9ca33f2a5..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/ray/src/utils.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch.distributed as dist -from typing import Any, Callable, Dict, List, Optional -from coati.models.bloom import BLOOMActor, BLOOMCritic -from coati.models.gpt import GPTActor, GPTCritic -from coati.models.opt import OPTActor, OPTCritic -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy -import torch -import os - -def is_rank_0() -> bool: - return not dist.is_initialized() or dist.get_rank() == 0 - - -def get_cuda_actor_critic_from_args(model: str, pretrained: str = None, lora_rank=0): - if model == 'gpt2': - actor = GPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) - critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) - elif model == 'bloom': - actor = BLOOMActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) - critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) - elif model == 'opt': - actor = OPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) - critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device()) - else: - raise ValueError(f'Unsupported model "{model}"') - return actor, critic - - -def get_strategy_from_args(strategy: str): - if strategy == 'naive': - strategy_ = NaiveStrategy() - elif strategy == 'ddp': - strategy_ = DDPStrategy() - elif strategy == 'colossalai_gemini': - strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) - elif strategy == 'colossalai_zero2': - strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda') - else: - raise ValueError(f'Unsupported strategy "{strategy}"') - return strategy_ - - -def set_dist_env(env_info: Dict[str, str]): - os.environ["RANK"] = env_info['rank'] - os.environ["LOCAL_RANK"] = env_info['local_rank'] - os.environ["WORLD_SIZE"] = env_info['world_size'] - os.environ['MASTER_PORT'] = env_info['master_port'] - os.environ['MASTER_ADDR'] = env_info['master_addr'] diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b88140c0e036d323a50a6f5cf448e42483b66d8c --- /dev/null +++ b/applications/Chat/coati/ray/utils.py @@ -0,0 +1,140 @@ +import os +from collections import OrderedDict +from typing import Any, Dict + +import torch +import torch.distributed as dist +import torch.nn as nn +from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic +from coati.models.gpt import GPTRM, GPTActor, GPTCritic +from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM +from coati.models.opt import OPTRM, OPTActor, OPTCritic +from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer + + +def is_rank_0() -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 + + +def get_rank() -> int: + return dist.get_rank() if dist.is_initialized() else 0 + + +def get_world_size() -> int: + return dist.get_world_size() if dist.is_initialized() else 1 + + +def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0): + if model == "gpt2": + actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank) + elif model == "bloom": + actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank) + elif model == "opt": + actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank) + elif model == "llama": + actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank) + else: + raise ValueError(f'Unsupported actor model "{model}"') + return actor + + +def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0): + if model == "gpt2": + critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config) + elif model == "bloom": + critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config) + elif model == "opt": + critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config) + elif model == "llama": + critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config) + else: + raise ValueError(f'Unsupported reward model "{model}"') + return critic + + +def get_reward_model_from_args(model: str, pretrained: str = None, config=None): + if model == "gpt2": + reward_model = GPTRM(pretrained=pretrained, config=config) + elif model == "bloom": + reward_model = BLOOMRM(pretrained=pretrained, config=config) + elif model == "opt": + reward_model = OPTRM(pretrained=pretrained, config=config) + elif model == "llama": + reward_model = LlamaRM(pretrained=pretrained, config=config) + else: + raise ValueError(f'Unsupported reward model "{model}"') + return reward_model + + +def get_strategy_from_args(strategy: str): + if strategy == "ddp": + strategy_ = DDPStrategy() + elif strategy == "colossalai_gemini": + strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5) + elif strategy == "colossalai_zero2": + strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda") + elif strategy == "colossalai_gemini_cpu": + strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) + elif strategy == "colossalai_zero2_cpu": + strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu") + else: + raise ValueError(f'Unsupported strategy "{strategy}"') + return strategy_ + + +def get_tokenizer_from_args(model: str, **kwargs): + if model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + elif model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") + elif model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + elif model == "llama": + pretrain_path = kwargs["pretrain"] + tokenizer = AutoTokenizer.from_pretrained(pretrain_path) + else: + raise ValueError(f'Unsupported model "{model}"') + + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +def set_dist_env(env_info: Dict[str, str]): + os.environ["RANK"] = env_info["rank"] + os.environ["LOCAL_RANK"] = env_info["local_rank"] + os.environ["WORLD_SIZE"] = env_info["world_size"] + os.environ["MASTER_PORT"] = env_info["master_port"] + os.environ["MASTER_ADDR"] = env_info["master_addr"] + + +def get_model_numel(model: nn.Module) -> int: + numel = sum(p.numel() for p in model.parameters()) + return numel + + +def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list: + target_receivers = [] + if num_senders <= num_receivers or allow_idle_sender: + # a sender will send data to one or more receivers + # a receiver only has one sender + for i in range(num_receivers): + if i % num_senders == sender_idx: + target_receivers.append(i) + else: + # a sender will send data to one receiver + # a receiver may have more than one sender + target_receivers.append(sender_idx % num_receivers) + return target_receivers + + +def state_dict_to( + state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu") +): + """ + keep state_dict intact + """ + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + new_state_dict[k] = v.to(dtype=dtype, device=device) + return new_state_dict diff --git a/applications/Chat/coati/replay_buffer/__init__.py b/applications/Chat/coati/replay_buffer/__init__.py deleted file mode 100644 index 1ebf60382913ead7247197a6ae7b021ceb7e5d71..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/replay_buffer/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .base import ReplayBuffer -from .naive import NaiveReplayBuffer - -__all__ = ['ReplayBuffer', 'NaiveReplayBuffer'] diff --git a/applications/Chat/coati/replay_buffer/base.py b/applications/Chat/coati/replay_buffer/base.py deleted file mode 100644 index 4c3812461a10120c358d4ddbccdffde4eff7fdfa..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/replay_buffer/base.py +++ /dev/null @@ -1,43 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - -from coati.experience_maker.base import Experience - - -class ReplayBuffer(ABC): - """Replay buffer base class. It stores experience. - - Args: - sample_batch_size (int): Batch size when sampling. - limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. - """ - - def __init__(self, sample_batch_size: int, limit: int = 0) -> None: - super().__init__() - self.sample_batch_size = sample_batch_size - # limit <= 0 means unlimited - self.limit = limit - - @abstractmethod - def append(self, experience: Experience) -> None: - pass - - @abstractmethod - def clear(self) -> None: - pass - - @abstractmethod - def sample(self) -> Experience: - pass - - @abstractmethod - def __len__(self) -> int: - pass - - @abstractmethod - def __getitem__(self, idx: int) -> Any: - pass - - @abstractmethod - def collate_fn(self, batch: Any) -> Experience: - pass diff --git a/applications/Chat/coati/replay_buffer/naive.py b/applications/Chat/coati/replay_buffer/naive.py deleted file mode 100644 index 938f500643c96c7370314bf6e60d3e38d765bc12..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/replay_buffer/naive.py +++ /dev/null @@ -1,57 +0,0 @@ -import random -from typing import List - -import torch -from coati.experience_maker.base import Experience - -from .base import ReplayBuffer -from .utils import BufferItem, make_experience_batch, split_experience_batch - - -class NaiveReplayBuffer(ReplayBuffer): - """Naive replay buffer class. It stores experience. - - Args: - sample_batch_size (int): Batch size when sampling. - limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. - cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True. - """ - - def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None: - super().__init__(sample_batch_size, limit) - self.cpu_offload = cpu_offload - self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}') - # TODO(ver217): add prefetch - self.items: List[BufferItem] = [] - - @torch.no_grad() - def append(self, experience: Experience) -> None: - if self.cpu_offload: - experience.to_device(torch.device('cpu')) - items = split_experience_batch(experience) - self.items.extend(items) - if self.limit > 0: - samples_to_remove = len(self.items) - self.limit - if samples_to_remove > 0: - self.items = self.items[samples_to_remove:] - - def clear(self) -> None: - self.items.clear() - - @torch.no_grad() - def sample(self) -> Experience: - items = random.sample(self.items, self.sample_batch_size) - experience = make_experience_batch(items) - if self.cpu_offload: - experience.to_device(self.target_device) - return experience - - def __len__(self) -> int: - return len(self.items) - - def __getitem__(self, idx: int) -> BufferItem: - return self.items[idx] - - def collate_fn(self, batch) -> Experience: - experience = make_experience_batch(batch) - return experience diff --git a/applications/Chat/coati/replay_buffer/utils.py b/applications/Chat/coati/replay_buffer/utils.py deleted file mode 100644 index 6ad0db2c3b609e0a3cfa4b1126550ebb44dec6a2..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/replay_buffer/utils.py +++ /dev/null @@ -1,73 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional - -import torch -import torch.nn.functional as F -from coati.experience_maker.base import Experience - - -@dataclass -class BufferItem: - """BufferItem is an item of experience data. - - Shapes of each tensor: - sequences: (S) - action_log_probs: (A) - values: (1) - reward: (1) - advantages: (1) - attention_mask: (S) - action_mask: (A) - - "A" is the number of actions. - """ - sequences: torch.Tensor - action_log_probs: torch.Tensor - values: torch.Tensor - reward: torch.Tensor - advantages: torch.Tensor - attention_mask: Optional[torch.LongTensor] - action_mask: Optional[torch.BoolTensor] - - -def split_experience_batch(experience: Experience) -> List[BufferItem]: - batch_size = experience.sequences.size(0) - batch_kwargs = [{} for _ in range(batch_size)] - keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask') - for key in keys: - value = getattr(experience, key) - if isinstance(value, torch.Tensor): - vals = torch.unbind(value) - else: - # None - vals = [value for _ in range(batch_size)] - assert batch_size == len(vals) - for i, v in enumerate(vals): - batch_kwargs[i][key] = v - items = [BufferItem(**kwargs) for kwargs in batch_kwargs] - return items - - -def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor: - assert side in ('left', 'right') - max_len = max(seq.size(0) for seq in sequences) - padded_sequences = [] - for seq in sequences: - pad_len = max_len - seq.size(0) - padding = (pad_len, 0) if side == 'left' else (0, pad_len) - padded_sequences.append(F.pad(seq, padding)) - return torch.stack(padded_sequences, dim=0) - - -def make_experience_batch(items: List[BufferItem]) -> Experience: - kwargs = {} - to_pad_keys = set(('action_log_probs', 'action_mask')) - keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask') - for key in keys: - vals = [getattr(item, key) for item in items] - if key in to_pad_keys: - batch_data = zero_pad_sequences(vals) - else: - batch_data = torch.stack(vals, dim=0) - kwargs[key] = batch_data - return Experience(**kwargs) diff --git a/applications/Chat/coati/trainer/__init__.py b/applications/Chat/coati/trainer/__init__.py index 525b57bf21d351e6522102a8d39ff5eb1aec5250..4be5d27f93b17c28cb947fec815f2ac2c138efc3 100644 --- a/applications/Chat/coati/trainer/__init__.py +++ b/applications/Chat/coati/trainer/__init__.py @@ -1,6 +1,6 @@ -from .base import Trainer +from .base import OnPolicyTrainer, SLTrainer from .ppo import PPOTrainer from .rm import RewardModelTrainer from .sft import SFTTrainer -__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer'] +__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"] diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py index ac3a878be88430f94af1bf19bcde4aaf1d0ec7f2..0a41d450d41ef52553f1d036fff8593271630a69 100644 --- a/applications/Chat/coati/trainer/base.py +++ b/applications/Chat/coati/trainer/base.py @@ -1,54 +1,106 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Union +from contextlib import contextmanager +from typing import List -import torch +import torch.nn as nn +import tqdm +from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_maker import Experience +from torch.optim import Optimizer from .callbacks import Callback from .strategies import Strategy +from .utils import is_rank_0 -class Trainer(ABC): +class SLTrainer(ABC): """ - Base class for rlhf trainers. + Base class for supervised learning trainers. Args: strategy (Strategy):the strategy to use for training max_epochs (int, defaults to 1): the number of epochs of training process + model (nn.Module): the model to train + optim (Optimizer): the optimizer to use for training + """ + + def __init__( + self, + strategy: Strategy, + max_epochs: int, + model: nn.Module, + optimizer: Optimizer, + ) -> None: + super().__init__() + self.strategy = strategy + self.max_epochs = max_epochs + self.model = model + self.optimizer = optimizer + + @abstractmethod + def _train(self, epoch): + raise NotImplementedError() + + @abstractmethod + def _eval(self, epoch): + raise NotImplementedError() + + def _before_fit(self): + raise NotImplementedError() + + def fit(self, *args, **kwargs): + self._before_fit(*args, **kwargs) + for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()): + self._train(epoch) + self._eval(epoch) + + +class OnPolicyTrainer(ABC): + """ + Base class for on-policy rl trainers, e.g. PPO. + + Args: + strategy (Strategy):the strategy to use for training + data_buffer (NaiveExperienceBuffer): the buffer to collect experiences + sample_buffer (bool, defaults to False): whether to sample from buffer dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader callbacks (List[Callback], defaults to []): the callbacks to call during training process - generate_kwargs (dict, optional): the kwargs to use while model generating """ - def __init__(self, - strategy: Strategy, - max_epochs: int = 1, - dataloader_pin_memory: bool = True, - callbacks: List[Callback] = [], - **generate_kwargs) -> None: + def __init__( + self, + strategy: Strategy, + data_buffer: NaiveExperienceBuffer, + sample_buffer: bool, + dataloader_pin_memory: bool, + callbacks: List[Callback] = [], + ) -> None: super().__init__() self.strategy = strategy - self.max_epochs = max_epochs - self.generate_kwargs = generate_kwargs + self.data_buffer = data_buffer + self.sample_buffer = sample_buffer self.dataloader_pin_memory = dataloader_pin_memory self.callbacks = callbacks - # TODO(ver217): maybe simplify these code using context - def _on_fit_start(self) -> None: + @contextmanager + def _fit_ctx(self) -> None: for callback in self.callbacks: callback.on_fit_start() - - def _on_fit_end(self) -> None: - for callback in self.callbacks: - callback.on_fit_end() - - def _on_episode_start(self, episode: int) -> None: + try: + yield + finally: + for callback in self.callbacks: + callback.on_fit_end() + + @contextmanager + def _episode_ctx(self, episode: int) -> None: for callback in self.callbacks: callback.on_episode_start(episode) - - def _on_episode_end(self, episode: int) -> None: - for callback in self.callbacks: - callback.on_episode_end(episode) + try: + yield + finally: + for callback in self.callbacks: + callback.on_episode_end(episode) def _on_make_experience_start(self) -> None: for callback in self.callbacks: @@ -70,6 +122,67 @@ class Trainer(ABC): for callback in self.callbacks: callback.on_learn_batch_start() - def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + def _on_learn_batch_end(self, experience: Experience) -> None: for callback in self.callbacks: - callback.on_learn_batch_end(metrics, experience) + callback.on_learn_batch_end(experience) + + @abstractmethod + def _make_experience(self, collect_step: int): + """ + Implement this method to make experience. + """ + raise NotImplementedError() + + @abstractmethod + def _learn(self, update_step: int): + """ + Implement this method to learn from experience, either + sample from buffer or transform buffer into dataloader. + """ + raise NotImplementedError() + + def _collect_phase(self, collect_step: int): + self._on_make_experience_start() + experience = self._make_experience(collect_step) + self._on_make_experience_end(experience) + self.data_buffer.append(experience) + + def _update_phase(self, update_step: int): + self._on_learn_epoch_start(update_step) + self._learn(update_step) + self._on_learn_epoch_end(update_step) + + def _before_fit(self, *args, **kwargs): + raise NotImplementedError() + + def fit( + self, + num_episodes: int, + num_collect_steps: int, + num_update_steps: int, + *args, + **kwargs, + ): + """ + The main training loop of on-policy rl trainers. + + Args: + num_episodes (int): the number of episodes to train + num_collect_steps (int): the number of collect steps per episode + num_update_steps (int): the number of update steps per episode + """ + self._before_fit(*args, **kwargs) + with self._fit_ctx(): + for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()): + with self._episode_ctx(episode): + for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()): + self._collect_phase(collect_step) + if not self.sample_buffer: + # HACK(cwher): according to the design of boost API, dataloader should also be boosted, + # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted. + # I only call strategy.setup_dataloader() to setup dataloader. + self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory) + for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()): + self._update_phase(update_step) + # NOTE: this is for on-policy algorithms + self.data_buffer.clear() diff --git a/applications/Chat/coati/trainer/callbacks/__init__.py b/applications/Chat/coati/trainer/callbacks/__init__.py index 9ed0ee6f764002fd085f0ecdf957f2dba4d576c8..29c8c4f00a5c62f1888b82206c72a8efadfd28b3 100644 --- a/applications/Chat/coati/trainer/callbacks/__init__.py +++ b/applications/Chat/coati/trainer/callbacks/__init__.py @@ -2,4 +2,4 @@ from .base import Callback from .performance_evaluator import PerformanceEvaluator from .save_checkpoint import SaveCheckpoint -__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint'] +__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"] diff --git a/applications/Chat/coati/trainer/callbacks/base.py b/applications/Chat/coati/trainer/callbacks/base.py index f5616048855b26ceac836e812d1ecdee8035f025..c6e30f04885ccc86b42366f8eca5b65d8ed62868 100644 --- a/applications/Chat/coati/trainer/callbacks/base.py +++ b/applications/Chat/coati/trainer/callbacks/base.py @@ -5,7 +5,7 @@ from coati.experience_maker import Experience class Callback(ABC): """ - Base callback class. It defines the interface for callbacks. + Base callback class. It defines the interface for callbacks. """ def on_fit_start(self) -> None: @@ -35,5 +35,5 @@ class Callback(ABC): def on_learn_batch_start(self) -> None: pass - def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + def on_learn_batch_end(self, experience: Experience) -> None: pass diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py index 925455444597095d67330698cffaca151f8c5e0b..b286c766c2637860a3f0d1fedb5abe9216aa5d6f 100644 --- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py +++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py @@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None: def divide(x: float, y: float) -> float: if y == 0: - return float('inf') - elif y == float('inf'): - return float('nan') + return float("inf") + elif y == float("inf"): + return float("nan") return x / y @@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float: class Timer: - def __init__(self) -> None: self.start_time: Optional[float] = None - self.duration: float = 0. + self.duration: float = 0.0 def start(self) -> None: self.start_time = time() @@ -52,7 +51,7 @@ class Timer: self.start_time = None def reset(self) -> None: - self.duration = 0. + self.duration = 0.0 class PerformanceEvaluator(Callback): @@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback): ignore_episodes: The number of episodes to ignore when calculating the performance. """ - def __init__(self, - actor_num_params: int, - critic_num_params: int, - initial_model_num_params: int, - reward_model_num_params: int, - enable_grad_checkpoint: bool = False, - ignore_episodes: int = 0) -> None: + def __init__( + self, + actor_num_params: int, + critic_num_params: int, + initial_model_num_params: int, + reward_model_num_params: int, + enable_grad_checkpoint: bool = False, + ignore_episodes: int = 0, + ) -> None: super().__init__() self.world_size = get_world_size() self.actor_num_params = actor_num_params @@ -136,7 +137,7 @@ class PerformanceEvaluator(Callback): return self.learn_timer.start() - def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: + def on_learn_batch_end(self, experience: Experience) -> None: if self.disable: return self.learn_timer.end() @@ -155,8 +156,9 @@ class PerformanceEvaluator(Callback): avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size) avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size) - avg_make_experience_throughput = self.make_experience_num_samples * \ - self.world_size / (avg_make_experience_duration + 1e-12) + avg_make_experience_throughput = ( + self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12) + ) avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12) avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12) @@ -171,13 +173,11 @@ class PerformanceEvaluator(Callback): learn_time_per_sample = divide(avg_learn_duration, num_effective_samples) print_rank_0( - f'Performance summary:\n' + - f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n' - + - f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n' - + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' + - f'Overall time per sample: {overall_time_per_sample:.2f} s\n' + - f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n' - + - f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%' + f"Performance summary:\n" + + f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n" + + f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n" + + f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n" + + f"Overall time per sample: {overall_time_per_sample:.2f} s\n" + + f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n" + + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%" ) diff --git a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py index d2dcc0dd4c65f0aae86a87b188cc2801fa0e5281..0d70b6c53073a5e14a57a18a38818416d9db6daa 100644 --- a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py +++ b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py @@ -1,7 +1,7 @@ import os import torch.distributed as dist -from coati.trainer.strategies import ColossalAIStrategy, Strategy +from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy from coati.trainer.utils import is_rank_0 from torch import nn from torch.optim import Optimizer @@ -36,40 +36,41 @@ class SaveCheckpoint(Callback): """ - def __init__(self, - path: str, - interval: int, - strategy: Strategy, - actor: nn.Module = None, - critic: nn.Module = None, - actor_optim: Optimizer = None, - critic_optim: Optimizer = None) -> None: + def __init__( + self, + path: str, + interval: int, + strategy: Strategy, + actor: nn.Module = None, + critic: nn.Module = None, + actor_optim: Optimizer = None, + critic_optim: Optimizer = None, + ) -> None: super().__init__() - self.path = os.path.join(path, 'checkpoint') + self.path = os.path.join(path, "checkpoint") self.interval = interval self.strategy = strategy - self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]} + self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]} def on_episode_end(self, episode: int) -> None: if (episode + 1) % self.interval != 0: return - base_path = os.path.join(self.path, f'episode_{episode}') + base_path = os.path.join(self.path, f"episode_{episode}") if not os.path.exists(base_path): os.makedirs(base_path) for model in self.model_dict.keys(): - # save model if self.model_dict[model][0] is None: # saving only optimizer states is meaningless, so it would be skipped continue - model_path = os.path.join(base_path, f'{model}.pt') + model_path = os.path.join(base_path, f"{model}.pt") self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True) # save optimizer if self.model_dict[model][1] is None: continue - only_rank0 = not isinstance(self.strategy, ColossalAIStrategy) + only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy)) rank = 0 if is_rank_0() else dist.get_rank() - optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt') + optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt") self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0) diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index fe5ae48d9c2f05af14235a0106c43e74959e4c11..d6966689885ec2760247c211ef51dc6813e447fe 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -1,26 +1,38 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Dict, List, Optional -import torch -import torch.nn as nn +from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_maker import Experience, NaiveExperienceMaker -from coati.models.base import Actor, Critic +from coati.models.base import Actor, Critic, RewardModel, get_base_model from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss -from coati.replay_buffer import NaiveReplayBuffer -from torch import Tensor +from coati.models.utils import calc_action_log_probs from torch.optim import Optimizer -from torch.utils.data import DistributedSampler +from torch.utils.data import DataLoader, DistributedSampler from tqdm import tqdm -from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers import PreTrainedTokenizerBase from colossalai.utils import get_current_device -from .base import Trainer +from .base import OnPolicyTrainer from .callbacks import Callback -from .strategies import Strategy -from .utils import is_rank_0, to_device +from .strategies import GeminiStrategy, Strategy +from .utils import CycledDataLoader, is_rank_0, to_device -class PPOTrainer(Trainer): +def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict: + unwrapped_model = strategy.unwrap_model(actor) + hf_model = get_base_model(unwrapped_model) + new_kwargs = {**generate_kwargs} + # use huggingface models method directly + if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"): + new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation + + if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"): + new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation + + return new_kwargs + + +class PPOTrainer(OnPolicyTrainer): """ Trainer for PPO algorithm. @@ -28,60 +40,61 @@ class PPOTrainer(Trainer): strategy (Strategy): the strategy to use for training actor (Actor): the actor model in ppo algorithm critic (Critic): the critic model in ppo algorithm - reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences + reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor actor_optim (Optimizer): the optimizer to use for actor model critic_optim (Optimizer): the optimizer to use for critic model kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss train_batch_size (int, defaults to 8): the batch size to use for training - buffer_limit (int, defaults to 0): the max_size limitation of replay buffer - buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu + buffer_limit (int, defaults to 0): the max_size limitation of buffer + buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu eps_clip (float, defaults to 0.2): the clip coefficient of policy loss vf_coef (float, defaults to 1.0): the coefficient of value loss ptx_coef (float, defaults to 0.9): the coefficient of ptx loss value_clip (float, defaults to 0.4): the clip coefficient of value loss - max_epochs (int, defaults to 1): the number of epochs of training process - sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer + sample_buffer (bool, defaults to False): whether to sample from buffer dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process callbacks (List[Callback], defaults to []): the callbacks to call during training process generate_kwargs (dict, optional): the kwargs to use while model generating """ - def __init__(self, - strategy: Strategy, - actor: Actor, - critic: Critic, - reward_model: nn.Module, - initial_model: Actor, - actor_optim: Optimizer, - critic_optim: Optimizer, - kl_coef: float = 0.1, - ptx_coef: float = 0.9, - train_batch_size: int = 8, - buffer_limit: int = 0, - buffer_cpu_offload: bool = True, - eps_clip: float = 0.2, - vf_coef: float = 1.0, - value_clip: float = 0.4, - max_epochs: int = 1, - sample_replay_buffer: bool = False, - dataloader_pin_memory: bool = True, - offload_inference_models: bool = True, - callbacks: List[Callback] = [], - **generate_kwargs) -> None: - experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) - replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) - generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) - super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs) - - self.experience_maker = experience_maker - self.replay_buffer = replay_buffer - self.sample_replay_buffer = sample_replay_buffer - self.offload_inference_models = offload_inference_models + def __init__( + self, + strategy: Strategy, + actor: Actor, + critic: Critic, + reward_model: RewardModel, + initial_model: Actor, + actor_optim: Optimizer, + critic_optim: Optimizer, + tokenizer: PreTrainedTokenizerBase, + kl_coef: float = 0.1, + ptx_coef: float = 0.9, + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + vf_coef: float = 1.0, + value_clip: float = 0.4, + sample_buffer: bool = False, + dataloader_pin_memory: bool = True, + offload_inference_models: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs, + ) -> None: + if isinstance(strategy, GeminiStrategy): + assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')" + + data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) + super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks) + + self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) + self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef) self.actor = actor self.critic = critic + self.tokenizer = tokenizer self.actor_loss_fn = PolicyLoss(eps_clip) self.critic_loss_fn = ValueLoss(value_clip) @@ -91,123 +104,99 @@ class PPOTrainer(Trainer): self.actor_optim = actor_optim self.critic_optim = critic_optim + self.offload_inference_models = offload_inference_models self.device = get_current_device() - def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience: - if isinstance(inputs, Tensor): - return self.experience_maker.make_experience(inputs, **self.generate_kwargs) - elif isinstance(inputs, dict): - return self.experience_maker.make_experience(**inputs, **self.generate_kwargs) - else: - raise ValueError(f'Unsupported input type "{type(inputs)}"') - - def _learn(self): - # replay buffer may be empty at first, we should rebuild at each training - if not self.sample_replay_buffer: - dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory) - if self.sample_replay_buffer: - pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0()) - for _ in pbar: - experience = self.replay_buffer.sample() - experience.to_device(self.device) - metrics = self.training_step(experience) - pbar.set_postfix(metrics) - else: - for epoch in range(self.max_epochs): - self._on_learn_epoch_start(epoch) - if isinstance(dataloader.sampler, DistributedSampler): - dataloader.sampler.set_epoch(epoch) - pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0()) - for experience in pbar: - self._on_learn_batch_start() - experience.to_device(self.device) - metrics = self.training_step(experience) - self._on_learn_batch_end(metrics, experience) - pbar.set_postfix(metrics) - self._on_learn_epoch_end(epoch) - - def fit(self, - prompt_dataloader, - pretrain_dataloader, - num_episodes: int = 50000, - max_timesteps: int = 500, - update_timesteps: int = 5000) -> None: - time = 0 - self.pretrain_dataloader = pretrain_dataloader - self.prompt_dataloader = prompt_dataloader - self._on_fit_start() - for episode in range(num_episodes): - self._on_episode_start(episode) - for timestep in tqdm(range(max_timesteps), - desc=f'Episode [{episode+1}/{num_episodes}]', - disable=not is_rank_0()): - time += 1 - prompts = next(iter(self.prompt_dataloader)) - self._on_make_experience_start() - if self.offload_inference_models: - # TODO(ver217): this may be controlled by strategy if they are prepared by strategy - self.experience_maker.initial_model.to(self.device) - self.experience_maker.reward_model.to(self.device) - experience = self._make_experience(prompts) - self._on_make_experience_end(experience) - self.replay_buffer.append(experience) - if time % update_timesteps == 0: - if self.offload_inference_models: - self.experience_maker.initial_model.to('cpu') - self.experience_maker.reward_model.to('cpu') - self._learn() - self.replay_buffer.clear() - self._on_episode_end(episode) - self._on_fit_end() - - def training_step(self, experience: Experience) -> Dict[str, float]: + def _before_fit( + self, + prompt_dataloader: DataLoader, + pretrain_dataloader: DataLoader, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): + """ + Args: + prompt_dataloader (DataLoader): the dataloader to use for prompt data + pretrain_dataloader (DataLoader): the dataloader to use for pretrain data + """ + self.prompt_dataloader = CycledDataLoader(prompt_dataloader) + self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) + + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + wandb.init(project="Coati-ppo", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "ppo") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) + + def _make_experience(self, collect_step: int) -> Experience: + prompts = self.prompt_dataloader.next() + if self.offload_inference_models: + # TODO(ver217): this may be controlled by strategy if they are prepared by strategy + self.experience_maker.initial_model.to(self.device) + self.experience_maker.reward_model.to(self.device) + assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"' + return self.experience_maker.make_experience(**prompts, **self.generate_kwargs) + + def _training_step(self, experience: Experience): self.actor.train() self.critic.train() # policy loss - num_actions = experience.action_mask.size(1) - action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask) - actor_loss = self.actor_loss_fn(action_log_probs, - experience.action_log_probs, - experience.advantages, - action_mask=experience.action_mask) + num_actions = experience.action_log_probs.size(1) + actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"] + action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions) + actor_loss = self.actor_loss_fn( + action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + ) + actor_loss = (1 - self.ptx_coef) * actor_loss + self.strategy.backward(actor_loss, self.actor, self.actor_optim) # ptx loss if self.ptx_coef != 0: - batch = next(iter(self.pretrain_dataloader)) + batch = self.pretrain_dataloader.next() batch = to_device(batch, self.device) - ptx_log_probs = self.actor.get_base_model()(batch['input_ids'], - attention_mask=batch['attention_mask'])['logits'] - ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels']) - actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef) + ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"] + ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"]) + self.strategy.backward(ptx_loss, self.actor, self.actor_optim) - self.strategy.backward(actor_loss, self.actor, self.actor_optim) self.strategy.optimizer_step(self.actor_optim) self.actor_optim.zero_grad() # value loss - values = self.critic(experience.sequences, - action_mask=experience.action_mask, - attention_mask=experience.attention_mask) - critic_loss = self.critic_loss_fn(values, - experience.values, - experience.reward, - action_mask=experience.action_mask) + values = self.critic(experience.sequences, attention_mask=experience.attention_mask) + critic_loss = self.critic_loss_fn(values, experience.values, experience.reward) critic_loss = critic_loss * self.vf_coef self.strategy.backward(critic_loss, self.critic, self.critic_optim) self.strategy.optimizer_step(self.critic_optim) self.critic_optim.zero_grad() - return {'reward': experience.reward.mean().item()} - - -def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: - origin_model = strategy.unwrap_model(actor) - new_kwargs = {**generate_kwargs} - # use huggingface models method directly - if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): - new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation - - if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'): - new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation - - return new_kwargs + def _learn(self, update_step: int): + if self.offload_inference_models: + self.experience_maker.initial_model.to("cpu") + self.experience_maker.reward_model.to("cpu") + + # buffer may be empty at first, we should rebuild at each training + if self.sample_buffer: + experience = self.data_buffer.sample() + self._on_learn_batch_start() + experience.to_device(self.device) + self._training_step(experience) + self._on_learn_batch_end(experience) + else: + if isinstance(self.dataloader.sampler, DistributedSampler): + self.dataloader.sampler.set_epoch(update_step) + pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0()) + for experience in pbar: + self._on_learn_batch_start() + experience.to_device(self.device) + self._training_step(experience) + self._on_learn_batch_end(experience) diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py index cdae5108ab00e3506449d8c837f62b2d594d8d31..d7f8c21a5a3d2264ce553b09139f73d761c204e4 100644 --- a/applications/Chat/coati/trainer/rm.py +++ b/applications/Chat/coati/trainer/rm.py @@ -1,35 +1,27 @@ -from datetime import datetime -from typing import List, Optional +from typing import Callable, Optional -import pandas as pd import torch -import torch.distributed as dist -from torch.optim import Optimizer, lr_scheduler -from torch.utils.data import DataLoader, Dataset, DistributedSampler -from tqdm import tqdm -from transformers.tokenization_utils_base import PreTrainedTokenizerBase - -from .base import Trainer -from .callbacks import Callback +import tqdm +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader + +from .base import SLTrainer from .strategies import Strategy from .utils import is_rank_0 -class RewardModelTrainer(Trainer): +class RewardModelTrainer(SLTrainer): """ Trainer to use while training reward model. Args: model (torch.nn.Module): the model to train strategy (Strategy): the strategy to use for training - optim(Optimizer): the optimizer to use for training + optim (Optimizer): the optimizer to use for training + lr_scheduler (_LRScheduler): the lr scheduler to use for training loss_fn (callable): the loss function to use for training - train_dataloader (DataLoader): the dataloader to use for training - valid_dataloader (DataLoader): the dataloader to use for validation - eval_dataloader (DataLoader): the dataloader to use for evaluation - batch_size (int, defaults to 1): the batch size while training max_epochs (int, defaults to 2): the number of epochs to train - callbacks (List[Callback], defaults to []): the callbacks to call during training process """ def __init__( @@ -37,87 +29,95 @@ class RewardModelTrainer(Trainer): model, strategy: Strategy, optim: Optimizer, - loss_fn, - train_dataloader: DataLoader, - valid_dataloader: DataLoader, - eval_dataloader: DataLoader, + lr_scheduler: _LRScheduler, + loss_fn: Callable, max_epochs: int = 1, - callbacks: List[Callback] = [], ) -> None: - super().__init__(strategy, max_epochs, callbacks=callbacks) + super().__init__(strategy, max_epochs, model, optim) + + self.loss_fn = loss_fn + self.scheduler = lr_scheduler + + self.num_train_step = 0 + + def _eval(self, epoch): + if self.eval_dataloader is not None: + self.model.eval() + dist, num_correct, num_samples = 0, 0, 0 + with torch.no_grad(): + for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader: + chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) + chosen_reward = self.model(chosen_ids, attention_mask=c_mask) + reject_reward = self.model(reject_ids, attention_mask=r_mask) + num_samples += chosen_ids.size(0) + num_correct += (chosen_reward > reject_reward).sum().item() + dist += (chosen_reward - reject_reward).mean().item() + self.dist = dist / len(self.eval_dataloader) + self.acc = num_correct / num_samples + + if self.writer: + self.writer.add_scalar("eval/dist", self.dist, epoch) + self.writer.add_scalar("eval/acc", self.acc, epoch) + + def _train(self, epoch): + self.model.train() + step_bar = tqdm.trange( + len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0() + ) + for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: + chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) + c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) + reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) + r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) + chosen_reward = self.model(chosen_ids, attention_mask=c_mask) + reject_reward = self.model(reject_ids, attention_mask=r_mask) + loss = self.loss_fn(chosen_reward, reject_reward) + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer) + self.optimizer.zero_grad() + if self.writer: + self.writer.add_scalar("train/loss", loss.item(), self.num_train_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step) + self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step) + self.writer.add_scalar( + "train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step + ) + self.num_train_step += 1 + if self.num_train_step % 100 == 0: + self.scheduler.step() + step_bar.update() + step_bar.close() + def _before_fit( + self, + train_dataloader: DataLoader, + eval_dataloader: DataLoader, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): + """ + Args: + train_dataloader (DataLoader): the dataloader to use for training + eval_dataloader (DataLoader): the dataloader to use for evaluation + """ self.train_dataloader = train_dataloader - self.valid_dataloader = valid_dataloader self.eval_dataloader = eval_dataloader - self.model = model - self.loss_fn = loss_fn - self.optimizer = optim - self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100) - - def eval_acc(self, dataloader): - dist = 0 - on = 0 - cnt = 0 - self.model.eval() - with torch.no_grad(): - for chosen_ids, c_mask, reject_ids, r_mask in dataloader: - chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) - c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) - reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) - r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) - chosen_reward = self.model(chosen_ids, attention_mask=c_mask) - reject_reward = self.model(reject_ids, attention_mask=r_mask) - for i in range(len(chosen_reward)): - cnt += 1 - if chosen_reward[i] > reject_reward[i]: - on += 1 - dist += (chosen_reward - reject_reward).mean().item() - dist_mean = dist / len(dataloader) - acc = on / cnt - self.model.train() - return dist_mean, acc - - def fit(self): - time = datetime.now() - epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0()) - for epoch in range(self.max_epochs): - step_bar = tqdm(range(self.train_dataloader.__len__()), - desc='Train step of epoch %d' % epoch, - disable=not is_rank_0()) - # train - self.model.train() - cnt = 0 - acc = 0 - dist = 0 - for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: - chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) - c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) - reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) - r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) - chosen_reward = self.model(chosen_ids, attention_mask=c_mask) - reject_reward = self.model(reject_ids, attention_mask=r_mask) - loss = self.loss_fn(chosen_reward, reject_reward) - self.strategy.backward(loss, self.model, self.optimizer) - self.strategy.optimizer_step(self.optimizer) - self.optimizer.zero_grad() - cnt += 1 - if cnt == 100: - self.scheduler.step() - dist, acc = self.eval_acc(self.valid_dataloader) - cnt = 0 - if is_rank_0(): - log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], - columns=['step', 'loss', 'dist', 'acc']) - log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False) - step_bar.update() - step_bar.set_postfix({'dist': dist, 'acc': acc}) - - # eval - dist, acc = self.eval_acc(self.eval_dataloader) - if is_rank_0(): - log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc']) - log.to_csv('log.csv', mode='a', header=False, index=False) - epoch_bar.update() - step_bar.set_postfix({'dist': dist, 'acc': acc}) - step_bar.close() + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + wandb.init(project="Coati-rm", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "rm") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py index 63fde53956ccd387296cdfde6068f93d76d4fd3f..7d0eeec897e573daa2324a3f63f1d83988946425 100644 --- a/applications/Chat/coati/trainer/sft.py +++ b/applications/Chat/coati/trainer/sft.py @@ -1,23 +1,20 @@ -import math -import time -from typing import List, Optional +from typing import Optional import torch import torch.distributed as dist -import wandb +import tqdm from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from tqdm import tqdm -from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from transformers.trainer import get_scheduler -from .base import Trainer -from .callbacks import Callback -from .strategies import ColossalAIStrategy, Strategy +from colossalai.logging import DistributedLogger + +from .base import SLTrainer +from .strategies import GeminiStrategy, Strategy from .utils import is_rank_0, to_device -class SFTTrainer(Trainer): +class SFTTrainer(SLTrainer): """ Trainer to use while training reward model. @@ -25,12 +22,9 @@ class SFTTrainer(Trainer): model (torch.nn.Module): the model to train strategy (Strategy): the strategy to use for training optim(Optimizer): the optimizer to use for training - train_dataloader: the dataloader to use for training - eval_dataloader: the dataloader to use for evaluation - batch_size (int, defaults to 1): the batch size while training + lr_scheduler(_LRScheduler): the lr scheduler to use for training max_epochs (int, defaults to 2): the number of epochs to train - callbacks (List[Callback], defaults to []): the callbacks to call during training process - optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer + accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients """ def __init__( @@ -38,98 +32,99 @@ class SFTTrainer(Trainer): model, strategy: Strategy, optim: Optimizer, - train_dataloader: DataLoader, - eval_dataloader: DataLoader = None, + lr_scheduler: _LRScheduler, max_epochs: int = 2, accumulation_steps: int = 8, - callbacks: List[Callback] = [], ) -> None: - if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3: - raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI") - super().__init__(strategy, max_epochs, callbacks=callbacks) + if accumulation_steps > 1: + assert not isinstance( + strategy, GeminiStrategy + ), "Accumulation steps are not supported in stage 3 of ColossalAI" + + super().__init__(strategy, max_epochs, model, optim) + + self.accumulation_steps = accumulation_steps + self.scheduler = lr_scheduler + + self.num_train_step = 0 + self.num_eval_step = 0 + + def _train(self, epoch: int): + self.model.train() + step_bar = tqdm.trange( + len(self.train_dataloader) // self.accumulation_steps, + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for i, batch in enumerate(self.train_dataloader): + batch = to_device(batch, torch.cuda.current_device()) + outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) + loss = outputs.loss / self.accumulation_steps + self.total_loss += loss.item() + self.strategy.backward(loss, self.model, self.optimizer) + # gradient accumulation + if (i + 1) % self.accumulation_steps == 0: + self.strategy.optimizer_step(self.optimizer) + self.optimizer.zero_grad() + self.scheduler.step() + if self.writer: + self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step) + self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) + self.num_train_step += 1 + self.total_loss = 0 + step_bar.update() + step_bar.close() + + def _eval(self, epoch: int): + if self.eval_dataloader is not None: + self.model.eval() + with torch.no_grad(): + loss_sum, num_seen = 0, 0 + for batch in self.eval_dataloader: + batch = to_device(batch, torch.cuda.current_device()) + outputs = self.model( + batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"] + ) + loss_sum += outputs.loss.item() + num_seen += batch["input_ids"].size(0) + loss_mean = loss_sum / num_seen + if dist.get_rank() == 0: + self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}") + if self.writer: + self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step) + self.num_eval_step += 1 + + def _before_fit( + self, + train_dataloader: DataLoader, + eval_dataloader: Optional[DataLoader] = None, + logger: Optional[DistributedLogger] = None, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): + """ + Args: + train_dataloader: the dataloader to use for training + eval_dataloader: the dataloader to use for evaluation + """ self.train_dataloader = train_dataloader self.eval_dataloader = eval_dataloader - self.model = model - self.optimizer = optim - self.accumulation_steps = accumulation_steps - num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps - max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch) - - self.scheduler = get_scheduler("cosine", - self.optimizer, - num_warmup_steps=math.ceil(max_steps * 0.03), - num_training_steps=max_steps) - - def fit(self, logger, use_wandb: bool = False): - if use_wandb: - wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) - wandb.watch(self.model) - total_loss = 0 - # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0()) - step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs), - desc=f'steps', - disable=not is_rank_0()) - for epoch in range(self.max_epochs): - - # process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0()) - # train - self.model.train() - for batch_id, batch in enumerate(self.train_dataloader): - - batch = to_device(batch, torch.cuda.current_device()) - outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) - - loss = outputs.loss - - if loss >= 2.5 and is_rank_0(): - logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}") - - loss = loss / self.accumulation_steps - - self.strategy.backward(loss, self.model, self.optimizer) - - total_loss += loss.item() - - # gradient accumulation - if (batch_id + 1) % self.accumulation_steps == 0: - self.strategy.optimizer_step(self.optimizer) - self.optimizer.zero_grad() - self.scheduler.step() - if is_rank_0() and use_wandb: - wandb.log({ - "loss": total_loss / self.accumulation_steps, - "lr": self.scheduler.get_last_lr()[0], - "epoch": epoch, - "batch_id": batch_id - }) - total_loss = 0 - step_bar.update() - - # if batch_id % log_interval == 0: - # logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}') - # wandb.log({"loss": loss.item()}) - - # process_bar.update() - - # eval - if self.eval_dataloader is not None: - self.model.eval() - with torch.no_grad(): - loss_sum = 0 - num_seen = 0 - for batch in self.eval_dataloader: - batch = to_device(batch, torch.cuda.current_device()) - outputs = self.model(batch["input_ids"], - attention_mask=batch["attention_mask"], - labels=batch["labels"]) - loss = outputs.loss - - loss_sum += loss.item() - num_seen += batch["input_ids"].size(0) - - loss_mean = loss_sum / num_seen - if dist.get_rank() == 0: - logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}') - - # epoch_bar.update() + self.logger = logger + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + wandb.init(project="Coati-sft", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "sft") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) + + self.total_loss = 0 diff --git a/applications/Chat/coati/trainer/strategies/__init__.py b/applications/Chat/coati/trainer/strategies/__init__.py index f258c9b8a87324d28d03a22af94f485dcb416c24..521dcb5855b1a645e107ebecfea1ad093b08fb34 100644 --- a/applications/Chat/coati/trainer/strategies/__init__.py +++ b/applications/Chat/coati/trainer/strategies/__init__.py @@ -1,6 +1,5 @@ from .base import Strategy -from .colossalai import ColossalAIStrategy +from .colossalai import GeminiStrategy, LowLevelZeroStrategy from .ddp import DDPStrategy -from .naive import NaiveStrategy -__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy'] +__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"] diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py index b1452869179ebffdbbb05a42e64777ba3e20ca83..a78716216ae02a9b58ad3f314e45e1aeb432b72b 100644 --- a/applications/Chat/coati/trainer/strategies/base.py +++ b/applications/Chat/coati/trainer/strategies/base.py @@ -1,63 +1,63 @@ from abc import ABC, abstractmethod from contextlib import nullcontext -from typing import Any, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn -from coati.models.base import Actor, get_base_model -from coati.replay_buffer import ReplayBuffer +from coati.experience_buffer import ExperienceBuffer from torch.optim import Optimizer from torch.utils.data import DataLoader from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from colossalai.booster import Booster +from colossalai.booster.plugin import Plugin + from .sampler import DistributedSampler -ModelOptimPair = Tuple[nn.Module, Optimizer] -ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair] +_BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict] class Strategy(ABC): """ - Base class for training strategies. + Base class for training strategies. """ - def __init__(self) -> None: + def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None: super().__init__() + # NOTE: dist must be initialized before Booster self.setup_distributed() + self.plugin = plugin_initializer() + self.booster = Booster(plugin=self.plugin) + self._post_init() @abstractmethod - def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None: + def _post_init(self) -> None: pass - @abstractmethod + def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None: + self.booster.backward(loss, optimizer) + def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None: - pass + optimizer.step() @abstractmethod def setup_distributed(self) -> None: pass @abstractmethod - def setup_model(self, model: nn.Module) -> nn.Module: - pass - - @abstractmethod - def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer: - pass - - @abstractmethod - def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: + def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader: pass def model_init_context(self): return nullcontext() - def prepare( - self, *models_or_model_optim_pairs: ModelOrModelOptimPair - ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: - """Prepare models or model-optimizer-pairs based on each strategy. + def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]: + """Prepare [model | (model, optimizer) | Dict] based on each strategy. + NOTE: the keys of Dict must be a subset of `self.booster.boost`'s arguments. Example:: + >>> # e.g., include lr_scheduler + >>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler)) >>> # when fine-tuning actor and critic >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model) >>> # or when training reward model @@ -66,67 +66,72 @@ class Strategy(ABC): >>> actor, critic = strategy.prepare(actor, critic) Returns: - Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order. + Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order. """ - def prepare_model(model: nn.Module): - if isinstance(model, Actor): - return Actor(self.setup_model(model.get_base_model())) - return self.setup_model(model) - rets = [] - for arg in models_or_model_optim_pairs: - if isinstance(arg, tuple): - assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"' - model, optimizer = arg - model = prepare_model(model) - optimizer = self.setup_optimizer(optimizer, get_base_model(model)) + for arg in boost_args: + if isinstance(arg, nn.Module): + model, *_ = self.booster.boost(arg) + rets.append(model) + elif isinstance(arg, tuple): + try: + model, optimizer = arg + except ValueError: + raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"') + model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer) rets.append((model, optimizer)) - elif isinstance(arg, nn.Module): - rets.append(prepare_model(arg)) + elif isinstance(arg, Dict): + model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg) + boost_result = dict( + model=model, + optimizer=optimizer, + criterion=criterion, + dataloader=dataloader, + lr_scheduler=lr_scheduler, + ) + # remove None values + boost_result = {key: value for key, value in boost_result.items() if value is not None} + rets.append(boost_result) else: - raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}') + raise RuntimeError(f"Type {type(arg)} is not supported") - if len(rets) == 1: - return rets[0] - return rets + return rets[0] if len(rets) == 1 else rets @staticmethod def unwrap_model(model: nn.Module) -> nn.Module: - """Get the unwrapped model from a wrapped model. Useful for getting original huggingface model. - For Actor, it will unwrap `actor.model`. + """Get the unwrapped model from a wrapped model made by Strategy.prepare. Args: model (nn.Module): the model to unwrap Returns: - nn.Module: the original model (usually a huggingface model) + nn.Module: the original model """ - return get_base_model(model) + return model - @abstractmethod - def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None: - pass + def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None: + self.booster.save_model(model, path, shard=shard, **kwargs) - @abstractmethod - def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None: - pass + def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None: + self.booster.load_model(model, path, strict) - @abstractmethod - def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: - pass + def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None: + self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs) - @abstractmethod - def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None: - pass + def load_optimizer(self, optimizer: Optimizer, path: str) -> None: + self.booster.load_optimizer(optimizer, path) def setup_sampler(self, dataset) -> DistributedSampler: + # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API. return DistributedSampler(dataset, 1, 0) @abstractmethod - def save_pretrained(self, - model: nn.Module, - path: str, - only_rank0: bool = True, - tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + def save_pretrained( + self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None + ) -> None: + pass + + @abstractmethod + def get_model_state_dict_shard(self, model: nn.Module, **config): pass diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 8aa302c77eeec2efa5fb7e3e789b172e2f9578a2..7129edb060efff00724873151e1dc5f506da978c 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -1,47 +1,117 @@ import warnings -from typing import Optional, Union +from typing import Optional -import torch -import torch.distributed as dist import torch.nn as nn -import torch.optim as optim -from coati.models.base import get_base_model -from torch.optim import Optimizer -from transformers.tokenization_utils_base import PreTrainedTokenizerBase import colossalai -from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import CPUAdam, HybridAdam -from colossalai.tensor import ProcessGroup, ShardSpec +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper +from colossalai.zero.gemini.gemini_ddp import GeminiDDP from .ddp import DDPStrategy -logger = get_dist_logger(__name__) +class LowLevelZeroStrategy(DDPStrategy): + """ + The strategy for training with ColossalAI. + + Args: + stage(int): The stage to use in ZeRO. Choose in (1, 2) + precision(str): The precision to use. Choose in ('fp32', 'fp16'). + seed(int): The seed for the random number generator. + placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda') + If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU, + If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest. + reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2. + overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2. + initial_scale(float): The initial scale for the optimizer. + growth_factor(float): The growth factor for the optimizer. + backoff_factor(float): The backoff factor for the optimizer. + growth_interval(int): The growth interval for the optimizer. + hysteresis(int): The hysteresis for the optimizer. + min_scale(float): The minimum scale for the optimizer. + max_scale(float): The maximum scale for the optimizer. + max_norm(float): The maximum norm for the optimizer. + norm_type(float): The norm type for the optimizer. + + """ + + def __init__( + self, + stage: int = 2, + precision: str = "fp16", + seed: int = 42, + placement_policy: str = "cuda", + reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2 + overlap_communication: bool = True, # only for stage 1&2 + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + ) -> None: + assert stage in (1, 2), f'Unsupported stage "{stage}"' + assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"' + assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"' + + plugin_initializer = lambda: LowLevelZeroPlugin( + stage=stage, + precision=precision, + reduce_bucket_size_in_m=reduce_bucket_size, + overlap_communication=overlap_communication, + cpu_offload=(placement_policy == "cpu"), + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type, + ) + + super().__init__(seed, plugin_initializer) + + def _post_init(self) -> None: + assert isinstance( + self.plugin, LowLevelZeroPlugin + ), f"{type(self).__name__}'s plugin is not initialized properly." + + def setup_distributed(self) -> None: + colossalai.launch_from_torch({}, seed=self.seed) + + def unwrap_model(self, model: nn.Module) -> nn.Module: + assert isinstance(model, LowLevelZeroModel) + return model.module + + def get_model_state_dict_shard(self, model: nn.Module, **config): + assert isinstance(model, LowLevelZeroModel) + yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False) -class ColossalAIStrategy(DDPStrategy): + +class GeminiStrategy(DDPStrategy): """ The strategy for training with ColossalAI. Args: - stage(int): The stage to use in ZeRO. Choose in (1, 2, 3) - precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16. seed(int): The seed for the random number generator. shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3. - This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future. + This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future. placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda') If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU, If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest. pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3. force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3. - search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3. + search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3. hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3. - min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3. + min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3. gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3. - reduce_bugket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2. - overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2. initial_scale(float): The initial scale for the optimizer. growth_factor(float): The growth factor for the optimizer. backoff_factor(float): The backoff factor for the optimizer. @@ -55,134 +125,76 @@ class ColossalAIStrategy(DDPStrategy): """ def __init__( - self, - stage: int = 3, - precision: str = 'fp16', - seed: int = 42, - shard_init: bool = False, # only for stage 3 - placement_policy: str = 'cuda', - pin_memory: bool = True, # only for stage 3 - force_outputs_fp32: bool = False, # only for stage 3 - scatter_after_inference: bool = False, # only for stage 3 - search_range_mb: int = 32, # only for stage 3 - hidden_dim: Optional[int] = None, # only for stage 3 - min_chunk_size_mb: float = 32, # only for stage 3 - gpu_margin_mem_ratio: float = 0.0, # only for stage 3 - reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2 - overlap_communication: bool = True, # only for stage 1&2 - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0) -> None: - super().__init__(seed) - assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' - assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"' - self.stage = stage + self, + seed: int = 42, + shard_init: bool = False, # only for stage 3 + placement_policy: str = "auto", + shard_param_frac: float = 1.0, # only for static placement + offload_optim_frac: float = 0.0, # only for static placement + offload_param_frac: float = 0.0, # only for static placement + pin_memory: bool = True, # only for stage 3 + force_outputs_fp32: bool = False, # only for stage 3 + search_range_m: int = 32, # only for stage 3 + hidden_dim: Optional[int] = None, # only for stage 3 + min_chunk_size_m: float = 32, # only for stage 3 + gpu_margin_mem_ratio: float = 0.0, # only for stage 3 + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + ) -> None: # TODO(ver217): support shard_init when using from_pretrained() if shard_init: warnings.warn( - f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()' + f"Shard init is not supported model.from_pretrained() yet. " + "Please load weights after strategy.prepare()" ) - if stage == 3 and precision == 'fp32': - warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.') - precision = 'fp16' - self.precision = precision self.shard_init = shard_init - self.gemini_config = dict(device=get_current_device(), - placement_policy=placement_policy, - pin_memory=pin_memory, - force_outputs_fp32=force_outputs_fp32, - strict_ddp_mode=shard_init, - search_range_mb=search_range_mb, - hidden_dim=hidden_dim, - min_chunk_size_mb=min_chunk_size_mb, - scatter_after_inference=scatter_after_inference) - if stage == 3: - self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio) - else: - self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size, - overlap_communication=overlap_communication, - cpu_offload=(placement_policy == 'cpu')) - self.optim_kwargs = dict(initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - max_norm=max_norm, - norm_type=norm_type) + + warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.") + + # NOTE: dist should be initialized before calling get_current_device() + plugin_initializer = lambda: GeminiPlugin( + chunk_init_device=get_current_device(), + placement_policy=placement_policy, + shard_param_frac=shard_param_frac, + offload_optim_frac=offload_optim_frac, + offload_param_frac=offload_param_frac, + precision="fp16", + pin_memory=pin_memory, + force_outputs_fp32=force_outputs_fp32, + strict_ddp_mode=shard_init, + search_range_m=search_range_m, + hidden_dim=hidden_dim, + min_chunk_size_m=min_chunk_size_m, + gpu_margin_mem_ratio=gpu_margin_mem_ratio, + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type, + ) + + super().__init__(seed, plugin_initializer) + + def _post_init(self) -> None: + assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly." def setup_distributed(self) -> None: colossalai.launch_from_torch({}, seed=self.seed) def model_init_context(self): - if self.stage == 3: - world_size = dist.get_world_size() - shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None - default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None - return ColoInitContext(device=get_current_device(), - dtype=torch.half, - default_pg=shard_pg, - default_dist_spec=default_dist_spec) return super().model_init_context() - def setup_model(self, model: nn.Module) -> nn.Module: - - model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config) - - if self.stage != 3 and self.precision == 'fp16': - model = model.half().cuda() - return model - - def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer: - assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}' - return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs) - - def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None: - optimizer.backward(loss) - - def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None: - optimizer.step() - - def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None: - if only_rank0 and dist.get_rank() != 0 and self.stage != 3: - return - base_model = get_base_model(model) - if self.stage == 3: - assert isinstance(base_model, ZeroDDP) - # for stage 3, state_dict() method should be called on every rank - state_dict = base_model.state_dict(only_rank_0=only_rank0) - else: - # only_rank0 is false or rank == 0 - state_dict = base_model.state_dict() - if only_rank0 and dist.get_rank() != 0: - return - torch.save(state_dict, path) - - def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: - if only_rank0: - raise RuntimeError( - f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.') - torch.save(optimizer.state_dict(), path) - def unwrap_model(self, model: nn.Module) -> nn.Module: - base_model: Union[nn.Module, ZeroDDP] = get_base_model(model) - if self.stage == 3: - assert isinstance(base_model, ZeroDDP) - return base_model.module - return base_model - - def save_pretrained(self, - model: nn.Module, - path: str, - only_rank0: bool = True, - tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: - if self.stage == 3: - raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now') - super().save_pretrained(model, path, only_rank0, tokenizer) + assert isinstance(model, GeminiDDP) + return model.module diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py index 7910b57878f86c5a8ed2781ab570e8302c3c2567..f2a44aeb096138a9c4848bb1b4531a199ed55e9c 100644 --- a/applications/Chat/coati/trainer/strategies/ddp.py +++ b/applications/Chat/coati/trainer/strategies/ddp.py @@ -1,93 +1,136 @@ import os import random -from typing import Optional +from collections import OrderedDict +from typing import Callable, Optional import numpy as np import torch import torch.distributed as dist import torch.nn as nn -from coati.replay_buffer import ReplayBuffer -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer +from coati.experience_buffer import ExperienceBuffer +from coati.models import Actor, Critic, RewardModel from torch.utils.data import DataLoader +from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from .naive import NaiveStrategy +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel + +from .base import Strategy from .sampler import DistributedSampler -class DDPStrategy(NaiveStrategy): +# TODO Move this to a util.py (Moving to ray.util introduces ringed import) +def get_grad_required_state_dict(model: nn.Module): + state_dict = OrderedDict() + for name, parameter in model.named_parameters(): + if parameter.requires_grad: + state_dict[name] = parameter.detach() + return state_dict + + +class DDPStrategy(Strategy): """ - Strategy for distributed training using torch.distributed. + Strategy for distributed training using torch.distributed. """ - def __init__(self, seed: int = 42) -> None: + def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None: self.seed = seed - super().__init__() + super().__init__(plugin_initializer) - def setup_distributed(self) -> None: + def _try_init_dist(self, force: bool = False) -> None: try: - rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - host = os.environ['MASTER_ADDR'] - port = int(os.environ['MASTER_PORT']) + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + host = os.environ["MASTER_ADDR"] + port = int(os.environ["MASTER_PORT"]) + dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank) + torch.cuda.set_device(local_rank) except KeyError as e: - raise RuntimeError( - f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" - ) - dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank) + if force: + raise RuntimeError( + f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" + ) + except Exception as e: + if force: + raise e + + def _post_init(self) -> None: + assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly." + + def setup_distributed(self) -> None: + self._try_init_dist(force=True) self.set_seed(self.seed) - torch.cuda.set_device(local_rank) def set_seed(self, seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - def setup_model(self, model: nn.Module) -> nn.Module: - device = torch.cuda.current_device() - return DDP(model, device_ids=[device]) - - def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: - # DDP only mode, replay buffers on each rank are different. - # sampler = DistributedSampler(replay_buffer, - # num_replicas=dist.get_world_size(), - # rank=dist.get_rank(), - # shuffle=True, - # seed=self.seed, - # drop_last=True) - return DataLoader( - replay_buffer, - batch_size=replay_buffer.sample_batch_size, - # sampler=sampler, + def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader: + return self.plugin.prepare_dataloader( + data_buffer, + batch_size=data_buffer.sample_batch_size, shuffle=True, drop_last=True, pin_memory=pin_memory, - collate_fn=replay_buffer.collate_fn) - - def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None: - if only_rank0 and dist.get_rank() != 0: - return - super().save_model(model, path, only_rank0) - - def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: - if only_rank0 and dist.get_rank() != 0: - return - super().save_optimizer(optimizer, path, only_rank0) + collate_fn=data_buffer.collate_fn, + ) def setup_sampler(self, dataset) -> DistributedSampler: + # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API. return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank()) def unwrap_model(self, model: nn.Module) -> nn.Module: - base_model: DDP = super().unwrap_model(model) - return base_model.module - - def save_pretrained(self, - model: nn.Module, - path: str, - only_rank0: bool = True, - tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: - if only_rank0 and dist.get_rank() != 0: - return - super().save_pretrained(model, path, only_rank0, tokenizer) + assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel." + return model.unwrap() + + def save_pretrained( + self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None + ) -> None: + if dist.get_rank() == 0: + unwrapped_model = self.unwrap_model(model) + assert isinstance(unwrapped_model, (Actor, Critic, RewardModel)) + pretrained_model = unwrapped_model.model + assert isinstance(pretrained_model, PreTrainedModel) + # HACK: only use hf save_pretrained to save config + pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None) + if tokenizer is not None: + tokenizer.save_pretrained(path) + + model_path = os.path.join(path, "pytorch_model.bin") + self.save_model(model, model_path, shard=shard) + def _replace_keys(model_path: str, replace_fn: Callable): + state_dict = torch.load(model_path, map_location="cpu") + state_dict = {replace_fn(k): v for k, v in state_dict.items()} + torch.save(state_dict, model_path) + # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin + # HACK: rename keys of pytorch_model.bin + if dist.get_rank() == 0: + _replace_keys(model_path, lambda k: k.replace("model.", "", 1)) + + + def get_model_state_dict_shard(self, model: nn.Module, **config): + # TODO: implement sharding on naive strategy + model = self.unwrap_model(model) + if "requires_grad_only" in config and config["requires_grad_only"] == True: + state_dict = get_grad_required_state_dict(model) + else: + state_dict = model.state_dict() + + if "shard_size" in config: + shard_size = config["shard_size"] + accumulate_size = 0 + state_dict_shard = OrderedDict() + for name, param in state_dict.items(): + state_dict_shard[name] = param + accumulate_size += param.numel() * param.element_size() + if accumulate_size >= shard_size: + accumulate_size = 0 + yield state_dict_shard + state_dict_shard = OrderedDict() + if accumulate_size > 0: + yield state_dict_shard + else: + yield state_dict diff --git a/applications/Chat/coati/trainer/strategies/naive.py b/applications/Chat/coati/trainer/strategies/naive.py deleted file mode 100644 index 4d94026ce9320ef754632680d5b44de6baae4862..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/trainer/strategies/naive.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import Any, Optional - -import torch -import torch.nn as nn -import torch.optim as optim -from coati.models.base import get_base_model -from coati.replay_buffer import ReplayBuffer -from torch.optim import Optimizer -from torch.utils.data import DataLoader -from transformers.modeling_utils import PreTrainedModel -from transformers.tokenization_utils_base import PreTrainedTokenizerBase - -from .base import Strategy - - -class NaiveStrategy(Strategy): - """ - Strategy for single GPU. No parallelism is used. - """ - - def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None: - loss.backward() - - def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None: - optimizer.step() - - def setup_distributed(self) -> None: - pass - - def setup_model(self, model: nn.Module) -> nn.Module: - return model - - def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer: - return optimizer - - def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: - return DataLoader(replay_buffer, - batch_size=replay_buffer.sample_batch_size, - shuffle=True, - drop_last=True, - pin_memory=pin_memory, - collate_fn=replay_buffer.collate_fn) - - def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None: - base_model = get_base_model(model) - state_dict = base_model.state_dict() - torch.save(state_dict, path) - - def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None: - base_model = get_base_model(model) - state_dict = torch.load(path, map_location=map_location) - base_model.load_state_dict(state_dict, strict=strict) - - def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: - torch.save(optimizer.state_dict(), path) - - def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None: - state_dict = torch.load(path, map_location=map_location) - optimizer.load_state_dict(state_dict) - - def save_pretrained(self, - model: nn.Module, - path: str, - only_rank0: bool = True, - tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: - unwrapped_model = self.unwrap_model(model) - assert isinstance(unwrapped_model, PreTrainedModel) - unwrapped_model.save_pretrained(path) - if tokenizer is not None: - tokenizer.save_pretrained(path) diff --git a/applications/Chat/coati/trainer/strategies/sampler.py b/applications/Chat/coati/trainer/strategies/sampler.py index d726fa640fa201b8bdec5c7601cc2895c4357316..6e811bef11a59303ea9873d8df3ebea331576a54 100644 --- a/applications/Chat/coati/trainer/strategies/sampler.py +++ b/applications/Chat/coati/trainer/strategies/sampler.py @@ -4,7 +4,6 @@ import numpy as np class DistributedSampler: - def __init__(self, dataset, num_replicas: int, rank: int) -> None: self.dataset = dataset self.num_replicas = num_replicas @@ -12,7 +11,7 @@ class DistributedSampler: if len(self.dataset) % self.num_replicas != 0: self.num_samples = math.ceil( - (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] ) else: self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) @@ -20,10 +19,10 @@ class DistributedSampler: self.total_size = self.num_samples * self.num_replicas indices = list(range(len(self.dataset))) - indices = indices[:self.total_size] + indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples self.indices = indices diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py index 9cccb5c9260395859611c55207f7dbaf817ef76f..7811e7365eeb08d3fcefc1161075773ff3d18683 100644 --- a/applications/Chat/coati/trainer/utils.py +++ b/applications/Chat/coati/trainer/utils.py @@ -3,6 +3,38 @@ from typing import Any import torch import torch.distributed as dist from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader + + +class CycledDataLoader: + """ + Why do we need this class? + In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" is used to sample a batch of prompts/pretrain. + However, this may be inefficient due to frequent re-initialization of the dataloader. (re-initialize workers...) + NOTE: next(iter(dataloader)) is not equivalent to for batch in dataloader: break, it causes slightly different behavior. + """ + + def __init__( + self, + dataloader: DataLoader, + ) -> None: + self.dataloader = dataloader + + self.count = 0 + self.dataloader_iter = None + + def next(self): + # defer initialization + if self.dataloader_iter is None: + self.dataloader_iter = iter(self.dataloader) + + self.count += 1 + try: + return next(self.dataloader_iter) + except StopIteration: + self.count = 0 + self.dataloader_iter = iter(self.dataloader) + return next(self.dataloader_iter) def is_rank_0() -> bool: @@ -10,7 +42,6 @@ def is_rank_0() -> bool: def to_device(x: Any, device: torch.device) -> Any: - def _to(t: Any): if isinstance(t, torch.Tensor): return t.to(device) diff --git a/applications/Chat/coati/utils/__init__.py b/applications/Chat/coati/utils/__init__.py deleted file mode 100644 index 112b82b9706444013013777d12798b4f6b62e52a..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tokenizer_utils import prepare_llama_tokenizer_and_embedding, smart_tokenizer_and_embedding_resize - -__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding'] \ No newline at end of file diff --git a/applications/Chat/coati/utils/tokenizer_utils.py b/applications/Chat/coati/utils/tokenizer_utils.py deleted file mode 100644 index e0d96cfc8be2711397d23bcfed725bd0ba10a2bf..0000000000000000000000000000000000000000 --- a/applications/Chat/coati/utils/tokenizer_utils.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict - -import transformers - -DEFAULT_PAD_TOKEN = "[PAD]" -DEFAULT_EOS_TOKEN = "" -DEFAULT_BOS_TOKEN = "" -DEFAULT_UNK_TOKEN = "" - - -def prepare_llama_tokenizer_and_embedding( - tokenizer: transformers.PreTrainedTokenizer, - model: transformers.PreTrainedModel, - special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN), -): - """prepare llama tokenizer and embedding. - - """ - - if tokenizer.pad_token is None: - smart_tokenizer_and_embedding_resize( - special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), - tokenizer=tokenizer, - model=model, - ) - - tokenizer.add_special_tokens({ - "eos_token": DEFAULT_EOS_TOKEN, - "bos_token": DEFAULT_BOS_TOKEN, - "unk_token": DEFAULT_UNK_TOKEN, - }) - - return tokenizer - - -def smart_tokenizer_and_embedding_resize( - tokenizer: transformers.PreTrainedTokenizer, - model: transformers.PreTrainedModel, - special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN), -): - """Resize tokenizer and embedding. - - Note: This is the unoptimized version that may make your embedding size not be divisible by 64. - """ - - if tokenizer.pad_token is None: - num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) - - model.resize_token_embeddings(len(tokenizer)) - - if num_new_tokens > 0: - input_embeddings = model.get_input_embeddings().weight.data - output_embeddings = model.get_output_embeddings().weight.data - - input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) - output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) - - input_embeddings[-num_new_tokens:] = input_embeddings_avg - output_embeddings[-num_new_tokens:] = output_embeddings_avg diff --git a/applications/Chat/evaluate/README.md b/applications/Chat/evaluate/README.md deleted file mode 100644 index 7ace4bfe6d1871bb97e3e522dc35128d692a7cb7..0000000000000000000000000000000000000000 --- a/applications/Chat/evaluate/README.md +++ /dev/null @@ -1,182 +0,0 @@ -# Evaluation - -In this directory, we introduce how you can evaluate your model with GPT-4. - -## Evaluation Pipeline - -The whole evaluation process undergoes the following three steps: -1. Prepare the questions following the internal data structure in the data format section (described below). -2. Generate answers from different models: - * Generate answers using GPT-3.5: [`generate_gpt35_answers.py`](generate_gpt35_answers.py). - * Generate answers using your own models: [`generate_answers.py`](generate_answers.py). -3. Evaluate models using GPT-4: [`evaluate.py`](evaluate.py). - -### Generate Answers -#### Generate Answers Using GPT-3.5 -You can provide your own OpenAI key to generate answers from GPT-3.5 using [`generate_gpt35_answers.py`](./generate_gpt35_answers.py). - -An example script is provided as follows: -```shell -python generate_gpt35_answers.py \ - --dataset "path to the question dataset" \ - --answer_path "path to answer folder" \ - --num_workers 4 \ - --openai_key "your openai key" \ - --max_tokens 512 \ -``` - -#### Generate Answers Using our Own Model -You can also generate answers using your own models. The generation process is divided into two stages: -1. Generate answers using multiple GPUs (optional) with batch processing: [`generate_answers.py`](./generate_answers.py). -2. Merge multiple shards and output a single file: [`merge.py`](./merge.py). - -An example script is given as follows: - -```shell -device_number=number of your devices -model_name="name of your model" -model_path="path to your model" -dataset="path to the question dataset" -answer_path="path to save the model answers" - -torchrun --standalone --nproc_per_node=$device_number generate_answers.py \ - --model 'llama' \ - --strategy ddp \ - --model_path $model_path \ - --model_name $model_name \ - --dataset $dataset \ - --batch_size 8 \ - --max_datasets_size 80 \ - --answer_path $answer_path \ - --max_length 512 - -python merge.py \ - --model_name $model_name \ - --shards $device_number \ - --answer_path $answer_path \ - -for (( i=0; i scores[1]: - worse_count += 1 - worse_file.append(review_jsons[idx]) - elif scores[0] < scores[1]: - better_count += 1 - better_file.append(review_jsons[idx]) - else: - tie_count += 1 - tie_file.append(review_jsons[idx]) - ans1_score += scores[0] - ans2_score += scores[1] - - output_review_file.append(review_jsons[idx]) - - better_file.sort(key=lambda x: x['id']) - worse_file.sort(key=lambda x: x['id']) - tie_file.sort(key=lambda x: x['id']) - invalid_file.sort(key=lambda x: x['id']) - output_review_file.sort(key=lambda x: x['id']) - - name1 = os.path.basename(args.answer_file_list[0]).split("_answers")[0] - name2 = os.path.basename(args.answer_file_list[1]).split("_answers")[0] - prefix = f"{name1}_vs_{name2}" - - jdump(better_file, os.path.join( - args.output_folder, prefix, f"{prefix}_better.json")) - jdump(worse_file, os.path.join( - args.output_folder, prefix, f"{prefix}_worse.json")) - jdump(tie_file, os.path.join( - args.output_folder, prefix, f"{prefix}_tie.json")) - jdump(invalid_file, os.path.join( - args.output_folder, prefix, f"{prefix}_invalid.json")) - jdump(output_review_file, os.path.join( - args.output_folder, prefix, f"{prefix}_review.json")) - - if os.path.exists(os.path.join(args.output_folder, "results.json")): - results = jload(os.path.join(args.output_folder, "results.json")) - else: - results = {} - results[prefix] = {'model': [name1, name2], 'better': better_count, 'worse': worse_count, 'tie': tie_count, 'win_rate': better_count / - (len(reviews)-invalid_count), 'score': [ans1_score/(len(reviews)-invalid_count), ans2_score/(len(reviews)-invalid_count)]} - jdump(results, os.path.join(args.output_folder, "results.json")) - - logger.info(f' Total {invalid_count} invalid score pair(s).') - logger.info(f' Model {name2} has {better_count} better answer(s).') - logger.info(f' Model {name2} has {worse_count} worse answer(s).') - logger.info(f' {tie_count} answer(s) play(s) to a tie.') - logger.info( - f' Win rate of model {name2}: {better_count/(len(reviews)-invalid_count):.2f}') - logger.info( - f' Model {name1} average score: {ans1_score/(len(reviews)-invalid_count):.2f}') - logger.info( - f' Model {name2} average score: {ans2_score/(len(reviews)-invalid_count):.2f}') - - -if __name__ == '__main__': - parser = argparse.ArgumentParser( - description='Model evaluation.') - parser.add_argument('--answer_file_list', nargs='+', default=[]) - parser.add_argument('--prompt_file') - parser.add_argument('--reviewer_file') - parser.add_argument('--output_folder', type=str, default="./output") - parser.add_argument('--openai_key', type=str, default=None) - parser.add_argument('--model', type=str, default="gpt-4") - parser.add_argument('--num_workers', type=int, default=8) - parser.add_argument('--max_tokens', type=int, default=512, - help='maximum number of tokens produced in the output') - args = parser.parse_args() - - if args.openai_key is not None: - os.environ["OPENAI_API_KEY"] = args.openai_key - openai.api_key = os.getenv("OPENAI_API_KEY") - - evaluate(args) diff --git a/applications/Chat/evaluate/evaluate.sh b/applications/Chat/evaluate/evaluate.sh deleted file mode 100755 index c51aa941019e55e38f57a459f88dc6ae85264c0e..0000000000000000000000000000000000000000 --- a/applications/Chat/evaluate/evaluate.sh +++ /dev/null @@ -1,9 +0,0 @@ -python evaluate.py \ - --answer_file_list "path to answers of model 1" "path to answers of model 2" \ - --prompt_file "path to prompt file" \ - --reviewer_file "path to reviewer file" \ - --output_folder "path to output folder" \ - --openai_key "your openai key" \ - --model "gpt-4" \ - --num_workers 8 \ - --max_tokens 512 \ diff --git a/applications/Chat/evaluate/generate_answers.py b/applications/Chat/evaluate/generate_answers.py deleted file mode 100644 index fbebf5c5e6f6835575ba6d9733604924d9e91deb..0000000000000000000000000000000000000000 --- a/applications/Chat/evaluate/generate_answers.py +++ /dev/null @@ -1,173 +0,0 @@ -import argparse -import os -import random -import copy -import math -from tqdm import tqdm - -import torch -import torch.distributed as dist -import transformers - -from coati.models.bloom import BLOOMActor -from coati.models.gpt import GPTActor -from coati.models.opt import OPTActor -from coati.models.roberta import RoBERTaActor -from coati.models.llama import LlamaActor -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy -from transformers import AutoTokenizer, RobertaTokenizer -from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer - -from colossalai.logging import get_dist_logger - -from utils import jload, jdump, is_rank_0 - - -logger = get_dist_logger() - -PROMPT_DICT = { - "prompt_input": - ("Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"), - "prompt_no_input": ("Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response:"), -} - - -def generate(args): - # torch.cuda.set_per_process_memory_fraction(0.4) - if args.strategy == 'naive': - strategy = NaiveStrategy() - elif args.strategy == 'ddp': - strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') - elif args.strategy == 'colossalai_zero2': - strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') - elif args.strategy == 'colossalai_zero2_cpu': - strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') - else: - raise ValueError(f'Unsupported strategy "{args.strategy}"') - - world_size = dist.get_world_size() - rank = dist.get_rank() - - with strategy.model_init_context(): - if args.model == 'gpt2': - actor = GPTActor(pretrained=args.model_path).to( - torch.cuda.current_device()) - elif args.model == 'bloom': - actor = BLOOMActor(pretrained=args.model_path).to( - torch.cuda.current_device()) - elif args.model == 'opt': - actor = OPTActor(pretrained=args.model_path).to( - torch.cuda.current_device()) - elif args.model == 'roberta': - actor = RoBERTaActor(pretrained=args.model_path).to( - torch.cuda.current_device()) - elif args.model == 'llama': - actor = LlamaActor(pretrained=args.model_path).to( - torch.float16).to(torch.cuda.current_device()) - else: - raise ValueError(f'Unsupported model "{args.model}"') - - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': - tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') - tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') - elif args.model == 'roberta': - tokenizer = RobertaTokenizer.from_pretrained("roberta-base") - elif args.model == 'llama': - tokenizer = AutoTokenizer.from_pretrained(args.model_path, - padding_side="right", - use_fast=False, - ) - tokenizer.eos_token = '<\s>' - else: - raise ValueError(f'Unsupported model "{args.model}"') - - questions = [] - if args.max_datasets_size is not None: - questions = random.sample(jload(args.dataset), args.max_datasets_size) - if is_rank_0(): - logger.info( - f"Limiting dataset to {args.max_datasets_size} examples.") - questions = questions[rank:args.max_datasets_size:world_size] - - answers = copy.deepcopy(questions) - - prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] - sources = [ - prompt_input.format_map(example) if example.get( - "input", "") != "" else prompt_no_input.format_map(example) - for example in questions - ] - - if is_rank_0(): - logger.info("Tokenizing inputs... This may take some time...") - - input_ids_list = [] - - for string in sources: - input_ids = tokenizer.encode(string, return_tensors='pt').squeeze(0) - input_ids_list.append(input_ids) - - bar = tqdm(range(math.ceil(len(input_ids_list)/args.batch_size)), - desc=f'steps', disable=not is_rank_0()) - - actor.eval() - with torch.no_grad(): - for i in range(0, len(input_ids_list), args.batch_size): - batch = input_ids_list[i:i+args.batch_size] - batch = [i.flip(dims=[0]) for i in batch] - batch = torch.nn.utils.rnn.pad_sequence(batch, - batch_first=True, - padding_value=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0).to(torch.cuda.current_device()) - batch = batch.flip(dims=[1]) - attention_mask = batch.ne(tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0) - - outputs = actor.model.generate(batch, attention_mask=attention_mask, - max_length=args.max_length, - do_sample=True, - top_k=50, - top_p=0.95, - num_return_sequences=1) - - outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) - for j in range(batch.size(0)): - answers[i + - j]['output'] = outputs[j].split("### Response:")[1].strip() - - bar.update() - - jdump(answers, os.path.join(args.answer_path, - f'{args.model_name}_answers_rank{rank}.json')) - - if is_rank_0(): - logger.info( - f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB') - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--strategy', - choices=['naive', 'ddp', 'colossalai_gemini', - 'colossalai_zero2', 'colossalai_zero2_cpu'], - default='naive') - parser.add_argument('--model', default='gpt2', - choices=['gpt2', 'bloom', 'opt', 'roberta', 'llama']) - parser.add_argument('--model_path', type=str, default=None) - parser.add_argument('--model_name', type=str, default='model') - parser.add_argument('--dataset', type=str, default=None) - parser.add_argument('--batch_size', type=int, default=1) - parser.add_argument('--max_datasets_size', type=int, default=None) - parser.add_argument('--answer_path', type=str, default="answer") - parser.add_argument('--max_length', type=int, default=1024) - args = parser.parse_args() - generate(args) diff --git a/applications/Chat/evaluate/generate_answers.sh b/applications/Chat/evaluate/generate_answers.sh deleted file mode 100755 index 36881f5f4f292885a153e4558eb1240823c14bb5..0000000000000000000000000000000000000000 --- a/applications/Chat/evaluate/generate_answers.sh +++ /dev/null @@ -1,25 +0,0 @@ -device_number=number of your devices -model_name="name of your model" -model_path="path to your model" -dataset="path to the question dataset" -answer_path="path to save the model answers" - -torchrun --standalone --nproc_per_node=$device_number generate_answers.py \ - --model 'llama' \ - --strategy ddp \ - --model_path $model_path \ - --model_name $model_name \ - --dataset $dataset \ - --batch_size 8 \ - --max_datasets_size 80 \ - --answer_path $answer_path \ - --max_length 512 - -python merge.py \ - --model_name $model_name \ - --shards $device_number \ - --answer_path $answer_path \ - -for (( i=0; i bool: - return not dist.is_initialized() or dist.get_rank() == 0 - -def _make_w_io_base(f, mode: str): - if not isinstance(f, io.IOBase): - f_dirname = os.path.dirname(f) - if f_dirname != "": - os.makedirs(f_dirname, exist_ok=True) - f = open(f, mode=mode) - return f - -def _make_r_io_base(f, mode: str): - if not isinstance(f, io.IOBase): - f = open(f, mode=mode) - return f - -def jdump(obj, f, mode="w", indent=4, default=str): - """Dump a str or dictionary to a file in json format. - Args: - obj: An object to be written. - f: A string path to the location on disk. - mode: Mode for opening the file. - indent: Indent for storing json dictionaries. - default: A function to handle non-serializable entries; defaults to `str`. - """ - f = _make_w_io_base(f, mode) - if isinstance(obj, (dict, list)): - json.dump(obj, f, indent=indent, default=default) - elif isinstance(obj, str): - f.write(obj) - else: - raise ValueError(f"Unexpected type: {type(obj)}") - f.close() - -def jload(f, mode="r"): - """Load a .json file into a dictionary.""" - f = _make_r_io_base(f, mode) - jdict = json.load(f) - f.close() - return jdict - -def get_json_list(file_path): - with open(file_path, 'r') as f: - json_list = [] - for line in f: - json_list.append(json.loads(line)) - return json_list diff --git a/applications/Chat/examples/README.md b/applications/Chat/examples/README.md index 561ace2205ed559846560e9352a13852eddd44bd..9438aafd126811cdfd967334be6168637db4f941 100644 --- a/applications/Chat/examples/README.md +++ b/applications/Chat/examples/README.md @@ -6,6 +6,7 @@ - [Table of Contents](#table-of-contents) - [Install requirements](#install-requirements) - [Supervised datasets collection](#supervised-datasets-collection) + - [Conversation dataset generation](#conversation-dataset-generation) - [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning) - [Arg List](#arg-list) - [Stage2 - Training reward model](#stage2---training-reward-model) @@ -16,7 +17,7 @@ - [Arg List](#arg-list-2) - [Inference example - After Stage3](#inference-example---after-stage3) - [Attention](#attention) - - [data](#data) + - [data](#data) - [Support Model](#support-model) - [GPT](#gpt) - [BLOOM](#bloom) @@ -24,12 +25,11 @@ - [LLaMA](#llama) - [Add your own models](#add-your-own-models) - [Actor model](#actor-model) - - [LM model](#lm-model) - [Reward model](#reward-model) - [Critic model](#critic-model) - --- + ## Install requirements ```shell @@ -38,27 +38,74 @@ pip install -r requirements.txt ## Supervised datasets collection -We collected 104K bilingual dataset of Chinese and English, and you can find the datasets in this repo -[InstructionWild](https://github.com/XueFuzhao/InstructionWild). +We collected 104K bilingual datasets of Chinese and English, and you can find the datasets in this repo +[InstructionWild](https://github.com/XueFuzhao/InstructionWild) and in this [file](https://github.com/XueFuzhao/InstructionWild/blob/main/data/README.md). + +Here is how we collected the data -The following pic shows how we collected the data.

        +### Conversation dataset generation + +In order to further improve the model's ability to handle multi-turn conversations, we need to include samples with multi-turn conversations in the dataset. However, the samples in InstructWild and Alpaca datasets currently consist of only single-turn conversations, and their dataset organization is not suitable for storing multi-turn conversations. Additionally, after converting the aforementioned datasets, we also need to include multi-turn conversation datasets like ShareGPT, and we should transform them into the training format supported by ColossalChat. + +A sample of conversation dataset should have the following fields: + +- `type` (str, optional): The type of the data sample. +- `language` (str, optional): The language of the data sample. +- `dataset` (str, optional): The dataset the data sample originates from. +- `conversations` (str, compulsory): Conversation content of the data sample. +- `id` (int, optional): The ID of the data sample. + +A simple example: + +```json +{ + "type": "instruction", + "language": "English", + "dataset": "Alpaca", + "conversations": [ + { + "from": "human", + "value": "Give three tips for staying healthy." + }, + { + "from": "gpt", + "value": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule." + } + ], + "id": 1 +} +``` + +> **NOTE:** Only key `conversations` is compulsary for training and other keys serve as metadata. The length of `conversations` varies. + +You can run the `examples/generate_conversation_dataset.py` to generate a conversation dataset supported by ColossalChat. + +You can use the following cmd to generate conversation dataset. + +```bash +python generate_conversation_dataset.py \ + --dataset "All" + --save_path "/path/to/dataset" +``` + ## Stage1 - Supervised instructs tuning Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model. +[[Stage1 tutorial video]](https://www.youtube.com/watch?v=-qFBZFmOJfg) You can run the `examples/train_sft.sh` to start a supervised instructs fine-tuning. You can also use the following cmd to start a supervised instructs fine-tuning with your own settings. -``` + +```bash torchrun --standalone --nproc_per_node=4 train_sft.py \ --pretrain "/path/to/LLaMa-7B/" \ --model 'llama' \ --strategy colossalai_zero2 \ - --log_interval 10 \ --save_path /path/to/Coati-7B \ --dataset /path/to/data.json \ --batch_size 4 \ @@ -68,27 +115,44 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \ --max_epochs 1 \ --grad_checkpoint ``` + +**Note**: the supervised dataset follows the following format, + +```json +[ + { + "instruction": "Provide a list of the top 10 most popular mobile games in Asia", + "input": "", + "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", + "id": 0 + }, + ... +] +``` + ### Arg List -- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' -- --model: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' -- --pretrain: pretrain model, type=str, default=None -- --max_datasets_size: the max size of dataset, type=int, default=None -- --save_path: path to save the model, type=str, default='output' -- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False -- --max_epochs: max epochs for training, type=int, default=3 -- --batch_size: batch size while training, type=int, default=4 -- --lora_rank: low-rank adaptation matrices rank, type=int, default=0 -- --log_interval: how many steps to log, type=int, default=100 -- --grad_checkpoint: enable gradient checkpointing, type=bool, default=False + +- `--strategy`: the strategy using for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' +- `--model`: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' +- `--pretrain`: pretrain model, type=str, default=None +- `--max_datasets_size`: the max size of dataset, type=int, default=None +- `--save_path`: path to save the model, type=str, default='output' +- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False +- `--max_epochs`: max epochs for training, type=int, default=3 +- `--batch_size`: batch size while training, type=int, default=4 +- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0 +- `--grad_checkpoint`: enable gradient checkpointing, type=bool, default=False ## Stage2 - Training reward model We train a reward model in stage 2, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model. +[[Stage2 tutorial video]](https://www.youtube.com/watch?v=gMx2CApKhuo) You can run the `examples/train_rm.sh` to start a reward model training. You can also use the following cmd to start training a reward model. -``` + +```bash torchrun --standalone --nproc_per_node=4 train_reward_model.py \ --pretrain "/path/to/LLaMa-7B/" \ --model 'llama' \ @@ -96,16 +160,19 @@ torchrun --standalone --nproc_per_node=4 train_reward_model.py \ --loss_fn 'log_exp'\ --save_path 'rmstatic.pt' \ ``` + ### Features and tricks in RM training + - We support [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)and[rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets. -- We support 2 kinds of loss_function named 'log_sig'(used by OpenAI) and 'log_exp'(used by Anthropic). -- We change the loss to valid_acc and pair_dist to monitor progress during training. +- We support 2 kinds of loss function named `log_sig`(used by OpenAI) and `log_exp`(used by Anthropic). +- We change the loss to `valid_acc` and `pair_dist` to monitor progress during training. - We add special token to the end of the sequence to get better result. - We use cosine-reducing lr-scheduler for RM training. - We set value_head as 1 liner layer and initialize the weight of value_head using N(0,1/(d_model + 1)) distribution. - We train a Bloom-560m reward model for 1 epoch and find the test acc of the model achieve the performance mentions in [Anthropics paper](https://arxiv.org/abs/2204.05862). ### Experiment result + Model performance in [Anthropics paper](https://arxiv.org/abs/2204.05862):
        image @@ -117,20 +184,20 @@ Model performance in [Anthropics paper](https://arxiv.org/abs/2204.05862):
        We also train the reward model based on LLaMA-7B, which reaches the ACC of 72.06% after 1 epoch, performing almost the same as Anthropic's best RM. ### Arg List -- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' -- --model: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' -- --pretrain: pretrain model, type=str, default=None -- --model_path: the path of rm model(if continue to train), type=str, default=None -- --save_path: path to save the model, type=str, default='output' -- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False -- --max_epochs: max epochs for training, type=int, default=3 -- --dataset: dataset name, type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'] -- --subset: subset of the dataset, type=str, default=None -- --batch_size: batch size while training, type=int, default=4 -- --lora_rank: low-rank adaptation matrices rank, type=int, default=0 -- --loss_func: which kind of loss function, choices=['log_sig', 'log_exp'] -- --max_len: max sentence length for generation, type=int, default=512 -- --test: whether is only testing, if it's true, the dataset will be small + +- `--strategy`: the strategy using for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' +- `--model`: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' +- `--pretrain`: pretrain model, type=str, default=None +- `--model_path`: the path of rm model(if continue to train), type=str, default=None +- `--save_path`: path to save the model, type=str, default='output' +- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False +- `--max_epochs`: max epochs for training, type=int, default=3 +- `--dataset`: dataset name, type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'] +- `--subset`: subset of the dataset, type=str, default=None +- `--batch_size`: batch size while training, type=int, default=4 +- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0 +- `--loss_func`: which kind of loss function, choices=['log_sig', 'log_exp'] +- `--max_len`: max sentence length for generation, type=int, default=512 ## Stage3 - Training model using prompts with RL @@ -141,53 +208,89 @@ Stage3 uses reinforcement learning algorithm, which is the most complex part of

        You can run the `examples/train_prompts.sh` to start PPO training. + You can also use the cmd following to start PPO training. +[[Stage3 tutorial video]](https://www.youtube.com/watch?v=Z8wwSHxPL9g) -``` +```bash torchrun --standalone --nproc_per_node=4 train_prompts.py \ - --pretrain "/path/to/LLaMa-7B/" \ - --model 'llama' \ - --strategy colossalai_zero2 \ - --prompt_dataset /path/to/your/prompt_dataset \ - --pretrain_dataset /path/to/your/pretrain_dataset \ - --rm_pretrain /your/pretrain/rm/defination \ - --rm_path /your/rm/model/path + --pretrain "/path/to/LLaMa-7B/" \ + --model 'llama' \ + --strategy colossalai_zero2 \ + --prompt_dataset /path/to/your/prompt_dataset \ + --pretrain_dataset /path/to/your/pretrain_dataset \ + --rm_pretrain /your/pretrain/rm/definition \ + --rm_path /your/rm/model/path ``` -Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/example_data_reformat.py) to reformat [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild. +Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/generate_prompt_dataset.py) which samples `instinwild_en.json` or `instinwild_ch.json` in [InstructionWild](https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data) to generate the prompt dataset. Pretrain dataset: the pretrain dataset including the instruction and corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) in stage 1 supervised instructs tuning. +**Note**: the required datasets follow the following format, + +- `pretrain dataset` + + ```json + [ + { + "instruction": "Provide a list of the top 10 most popular mobile games in Asia", + "input": "", + "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", + "id": 0 + }, + ... + ] + ``` + +- `prompt dataset` + + ```json + [ + { + "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", + "id": 0 + }, + { + "instruction": "Write a descriptive paragraph about a memorable vacation you went on", + "id": 1 + }, + ... + ] + ``` + ### Arg List -- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' -- --model: model type of actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' -- --pretrain: pretrain model, type=str, default=None -- --rm_model: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None -- --rm_pretrain: pretrain model for reward model, type=str, default=None -- --rm_path: the path of rm model, type=str, default=None -- --save_path: path to save the model, type=str, default='output' -- --prompt_dataset: path of the prompt dataset, type=str, default=None -- --pretrain_dataset: path of the ptx dataset, type=str, default=None -- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False -- --num_episodes: num of episodes for training, type=int, default=10 -- --max_epochs: max epochs for training in one episode, type=int, default=5 -- --max_timesteps: max episodes in one batch, type=int, default=10 -- --update_timesteps: timesteps to update, type=int, default=10 -- --train_batch_size: batch size while training, type=int, default=8 -- --ptx_batch_size: batch size to compute ptx loss, type=int, default=1 -- --experience_batch_size: batch size to make experience, type=int, default=8 -- --lora_rank: low-rank adaptation matrices rank, type=int, default=0 -- --kl_coef: kl_coef using for computing reward, type=float, default=0.1 -- --ptx_coef: ptx_coef using for computing policy loss, type=float, default=0.9 + +- `--strategy`: the strategy using for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' +- `--model`: model type of actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' +- `--pretrain`: pretrain model, type=str, default=None +- `--rm_model`: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None +- `--rm_pretrain`: pretrain model for reward model, type=str, default=None +- `--rm_path`: the path of rm model, type=str, default=None +- `--save_path`: path to save the model, type=str, default='output' +- `--prompt_dataset`: path of the prompt dataset, type=str, default=None +- `--pretrain_dataset`: path of the ptx dataset, type=str, default=None +- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False +- `--num_episodes`: num of episodes for training, type=int, default=10 +- `--num_update_steps`: number of steps to update policy per episode, type=int +- `--num_collect_steps`: number of steps to collect experience per episode, type=int +- `--train_batch_size`: batch size while training, type=int, default=8 +- `--ptx_batch_size`: batch size to compute ptx loss, type=int, default=1 +- `--experience_batch_size`: batch size to make experience, type=int, default=8 +- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0 +- `--kl_coef`: kl_coef using for computing reward, type=float, default=0.1 +- `--ptx_coef`: ptx_coef using for computing policy loss, type=float, default=0.9 ## Inference example - After Stage3 + We support different inference options, including int8 and int4 quantization. For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference). - ## Attention + The examples are demos for the whole training process.You need to change the hyper-parameters to reach great performance. #### data + - [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) - [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [ ] [openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback) @@ -197,14 +300,16 @@ The examples are demos for the whole training process.You need to change the hyp ## Support Model ### GPT -- [x] GPT2-S (s) -- [x] GPT2-M (m) -- [x] GPT2-L (l) -- [x] GPT2-XL (xl) -- [x] GPT2-4B (4b) -- [ ] GPT2-6B (6b) + +- [x] GPT2-S (s) +- [x] GPT2-M (m) +- [x] GPT2-L (l) +- [x] GPT2-XL (xl) +- [x] GPT2-4B (4b) +- [ ] GPT2-6B (6b) ### BLOOM + - [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m) - [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1) - [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b) @@ -212,6 +317,7 @@ The examples are demos for the whole training process.You need to change the hyp - [ ] [BLOOM-175b](https://huggingface.co/bigscience/bloom) ### OPT + - [x] [OPT-125M](https://huggingface.co/facebook/opt-125m) - [x] [OPT-350M](https://huggingface.co/facebook/opt-350m) - [x] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b) @@ -221,10 +327,11 @@ The examples are demos for the whole training process.You need to change the hyp - [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b) ### [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) -- [x] LLaMA-7B -- [x] LLaMA-13B -- [ ] LLaMA-33B -- [ ] LLaMA-65B + +- [x] LLaMA-7B +- [x] LLaMA-13B +- [ ] LLaMA-33B +- [ ] LLaMA-65B ## Add your own models @@ -237,12 +344,12 @@ if it is supported in huggingface [transformers](https://github.com/huggingface/ r you can build your own model by yourself. ### Actor model -``` + +```python from ..base import Actor from transformers.models.coati import CoatiModel class CoatiActor(Actor): - def __init__(self, pretrained: Optional[str] = None, checkpoint: bool = False, @@ -257,7 +364,8 @@ class CoatiActor(Actor): ``` ### Reward model -``` + +```python from ..base import RewardModel from transformers.models.coati import CoatiModel @@ -280,12 +388,11 @@ class CoatiRM(RewardModel): ### Critic model -``` +```python from ..base import Critic from transformers.models.coati import CoatiModel class CoatiCritic(Critic): - def __init__(self, pretrained: Optional[str] = None, checkpoint: bool = False, diff --git a/applications/Chat/examples/community/README.md b/applications/Chat/examples/community/README.md index c9c645032288f5beccd6b45f2d53f6ee5677c4e7..e14ac1767fc12aca65e0570827f960d6bfc0294b 100644 --- a/applications/Chat/examples/community/README.md +++ b/applications/Chat/examples/community/README.md @@ -1,5 +1,9 @@ +:warning: **This content may be outdated since the major update of Colossal Chat. We will update this content soon.** + # Community Examples + --- + We are thrilled to announce the latest updates to ColossalChat, an open-source solution for cloning ChatGPT with a complete RLHF (Reinforcement Learning with Human Feedback) pipeline. As Colossal-AI undergoes major updates, we are actively maintaining ColossalChat to stay aligned with the project's progress. With the introduction of Community-driven example, we aim to create a collaborative platform for developers to contribute exotic features built on top of ColossalChat. @@ -14,11 +18,12 @@ For more information about community pipelines, please have a look at this [issu Community examples consist of both inference and training examples that have been added by the community. Please have a look at the following table to get an overview of all community examples. Click on the Code Example to get a copy-and-paste ready code example that you can try out. If a community doesn't work as expected, please open an issue and ping the author on it. -| Example | Description | Code Example | Colab | Author | -|:---------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------:| -| Peft | Adding Peft support for SFT and Prompts model training | [Huggingface Peft](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/peft) | - | [YY Lin](https://github.com/yynil) | -| Train prompts on Ray | A Ray based implementation of Train prompts example | [Huggingface Peft](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/ray) | - | [MisterLin1995](https://github.com/MisterLin1995) | -|...|...|...|...|...| +| Example | Description | Code Example | Colab | Author | +| :------------------- | :----------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------- | :---- | ------------------------------------------------: | +| Peft | Adding Peft support for SFT and Prompts model training | [Huggingface Peft](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/peft) | - | [YY Lin](https://github.com/yynil) | +| Train prompts on Ray | A Ray based implementation of Train prompts example | [Training On Ray](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/ray) | - | [MisterLin1995](https://github.com/MisterLin1995) | +| ... | ... | ... | ... | ... | ### How to get involved + To join our community-driven initiative, please visit the [ColossalChat GitHub repository](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples), review the provided information, and explore the codebase. To contribute, create a new issue outlining your proposed feature or enhancement, and our team will review and provide feedback. We look forward to collaborating with you on this exciting project! diff --git a/applications/Chat/examples/community/peft/README.md b/applications/Chat/examples/community/peft/README.md index eabb56fd8294ea89ad7632bd73ffece758e29e56..ada3a16296afab4facd94d92ab6ef0765f572b28 100644 --- a/applications/Chat/examples/community/peft/README.md +++ b/applications/Chat/examples/community/peft/README.md @@ -1,3 +1,5 @@ +:warning: **This content may be outdated since the major update of Colossal Chat. We will update this content soon.** + # Add Peft support for SFT and Prompts model training The original implementation just adopts the loralib and merges the layers into the final model. The huggingface peft is a better lora model implementation and can be easily training and distributed. @@ -5,7 +7,9 @@ The original implementation just adopts the loralib and merges the layers into t Since reward model is relative small, I just keep it as original one. I suggest train full model to get the proper reward/critic model. # Preliminary installation + Since the current pypi peft package(0.2) has some bugs, please install the peft package using source. + ``` git clone https://github.com/huggingface/peft cd peft @@ -13,12 +17,14 @@ pip install . ``` # Usage + For SFT training, just call train_peft_sft.py -Its arguments are almost identical to train_sft.py instead adding a new eval_dataset if you have a eval_dataset file. The data file is just a plain datafile, please check the format in the easy_dataset.py. +Its arguments are almost identical to train_sft.py instead adding a new eval_dataset if you have an eval_dataset file. The data file is just a plain datafile, please check the format in the easy_dataset.py. For stage-3 rlhf training, call train_peft_prompts.py. -Its arguments are almost idential to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data file. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported. +Its arguments are almost identical to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data file. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported. # Dataformat + Please refer the formats in test_sft.txt, test_prompts.txt, test_pretrained.txt. diff --git a/applications/Chat/examples/community/peft/easy_dataset.py b/applications/Chat/examples/community/peft/easy_dataset.py index 24ea4f0a86186c4789062ab148c83c916ec86edc..d4b17689e9cb38dea77573cc3fb9d4912c47d8bd 100644 --- a/applications/Chat/examples/community/peft/easy_dataset.py +++ b/applications/Chat/examples/community/peft/easy_dataset.py @@ -3,7 +3,6 @@ import json from typing import Dict, Sequence import torch -from datasets import load_dataset from torch.utils.data import Dataset from tqdm import tqdm from transformers import AutoTokenizer @@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i padding="longest", max_length=max_length, truncation=True, - ) for text in strings + ) + for text in strings ] input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] input_ids_lens = labels_lens = [ @@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo class EasySupervisedDataset(Dataset): - def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None: super(EasySupervisedDataset, self).__init__() with open(data_file, "r", encoding="UTF-8") as f: all_lines = f.readlines() - #split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:" + # split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:" sources, targets = [], [] for line in all_lines: if "回答:" in line: sep_index = line.index("回答:") - sources.append(line[:sep_index + 3]) - targets.append(line[sep_index + 3:] + tokenizer.eos_token) + sources.append(line[: sep_index + 3]) + targets.append(line[sep_index + 3 :] + tokenizer.eos_token) else: sources.append(line) targets.append("" + tokenizer.eos_token) @@ -83,15 +82,17 @@ class EasySupervisedDataset(Dataset): class EasyPromptsDataset(Dataset): - def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None: super(EasyPromptsDataset, self).__init__() with open(data_file, "r", encoding="UTF-8") as f: all_lines = f.readlines() - all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines] + all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines] self.prompts = [ - tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length', - truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0) + tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[ + "input_ids" + ] + .to(torch.cuda.current_device()) + .squeeze(0) for line in tqdm(all_lines) ] self.data_file = data_file @@ -110,7 +111,6 @@ class EasyPromptsDataset(Dataset): class EasyRewardDataset(Dataset): - def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None: super(EasyRewardDataset, self).__init__() self.chosen = [] @@ -120,44 +120,42 @@ class EasyRewardDataset(Dataset): else: self.end_token = special_token print(self.end_token) - #read all lines in the train_file to a list + # read all lines in the train_file to a list with open(train_file, "r", encoding="UTF-8") as f: all_lines = f.readlines() for line in tqdm(all_lines): data = json.loads(line) - prompt = "提问:" + data['prompt'] + " 回答:" - - chosen = prompt + data['chosen'] + self.end_token - chosen_token = tokenizer(chosen, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.chosen.append({ - "input_ids": chosen_token['input_ids'], - "attention_mask": chosen_token['attention_mask'] - }) - - reject = prompt + data['rejected'] + self.end_token - reject_token = tokenizer(reject, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.reject.append({ - "input_ids": reject_token['input_ids'], - "attention_mask": reject_token['attention_mask'] - }) + prompt = "提问:" + data["prompt"] + " 回答:" + + chosen = prompt + data["chosen"] + self.end_token + chosen_token = tokenizer( + chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.chosen.append( + {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]} + ) + + reject = prompt + data["rejected"] + self.end_token + reject_token = tokenizer( + reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.reject.append( + {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]} + ) def __len__(self): length = len(self.chosen) return length def __getitem__(self, idx): - return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ - "input_ids"], self.reject[idx]["attention_mask"] - - #python representation of the object and the string representation of the object + return ( + self.chosen[idx]["input_ids"], + self.chosen[idx]["attention_mask"], + self.reject[idx]["input_ids"], + self.reject[idx]["attention_mask"], + ) + + # python representation of the object and the string representation of the object def __repr__(self): return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" @@ -165,30 +163,29 @@ class EasyRewardDataset(Dataset): return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" -''' +""" Easy SFT just accept a text file which can be read line by line. However the datasets will group texts together to max_length so LLM will learn the texts meaning better. If individual lines are not related, just set is_group_texts to False. -''' +""" class EasySFTDataset(Dataset): - def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None: super().__init__() - #read the data_file line by line + # read the data_file line by line with open(data_file, "r", encoding="UTF-8") as f: - #encode the text data line by line and put raw python list input_ids only to raw_input_ids list + # encode the text data line by line and put raw python list input_ids only to raw_input_ids list raw_input_ids = [] for line in f: encoded_ids = tokenizer.encode(line) - #if the encoded_ids is longer than max_length, then split it into several parts + # if the encoded_ids is longer than max_length, then split it into several parts if len(encoded_ids) > max_length: for i in range(0, len(encoded_ids), max_length): - raw_input_ids.append(encoded_ids[i:i + max_length]) + raw_input_ids.append(encoded_ids[i : i + max_length]) else: raw_input_ids.append(encoded_ids) - grouped_inpup_ids = [] + grouped_input_ids = [] current_input_ids = [] attention_mask = [] if tokenizer.pad_token_id is None: @@ -196,30 +193,33 @@ class EasySFTDataset(Dataset): if is_group_texts: for input_ids in raw_input_ids: if len(current_input_ids) + len(input_ids) > max_length: - #pad the current_input_ids to max_length with tokenizer.pad_token_id + # pad the current_input_ids to max_length with tokenizer.pad_token_id padded_length = max_length - len(current_input_ids) current_input_ids.extend([tokenizer.pad_token_id] * padded_length) - grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) + grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) attention_mask.append( - torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long) + ) current_input_ids = [] else: current_input_ids.extend(input_ids) if len(current_input_ids) > 0: padded_length = max_length - len(current_input_ids) current_input_ids.extend([tokenizer.pad_token_id] * padded_length) - grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) + grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) attention_mask.append( - torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long) + ) else: - #just append the raw_input_ids to max_length + # just append the raw_input_ids to max_length for input_ids in raw_input_ids: padded_length = max_length - len(input_ids) input_ids.extend([tokenizer.pad_token_id] * padded_length) attention_mask.append( - torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) - grouped_inpup_ids.append(torch.tensor(input_ids, dtype=torch.long)) - self.input_ids = grouped_inpup_ids + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long) + ) + grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long)) + self.input_ids = grouped_input_ids self.labels = copy.deepcopy(self.input_ids) self.file_name = data_file self.attention_mask = attention_mask @@ -227,14 +227,14 @@ class EasySFTDataset(Dataset): def __len__(self): return len(self.input_ids) - #get item from dataset + # get item from dataset def __getitem__(self, idx): return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) - #generate the dataset description to be printed by print in python + # generate the dataset description to be printed by print in python def __repr__(self): return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" - #generate the dataset description to be printed by print in python + # generate the dataset description to be printed by print in python def __str__(self): return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" diff --git a/applications/Chat/examples/community/peft/easy_models.py b/applications/Chat/examples/community/peft/easy_models.py index fe294868159dde227cae9757da41ee71b5778a25..db629e50ed94f6ba13b3d70b8646053c0dcc94b3 100644 --- a/applications/Chat/examples/community/peft/easy_models.py +++ b/applications/Chat/examples/community/peft/easy_models.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from coati.models.generation import generate -from coati.models.utils import log_probs_from_logits, masked_mean +from coati.models.utils import log_probs_from_logits from peft import PeftModel from torch.nn.modules import Module from transformers import BloomConfig, BloomForCausalLM @@ -24,38 +24,33 @@ class Actor(Module): @torch.no_grad() def generate( - self, - input_ids: torch.Tensor, - return_action_mask: bool = True, - **kwargs + self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: sequences = generate(self.model, input_ids, **kwargs) attention_mask = None - pad_token_id = kwargs.get('pad_token_id', None) + pad_token_id = kwargs.get("pad_token_id", None) if pad_token_id is not None: attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) if not return_action_mask: return sequences, attention_mask, None input_len = input_ids.size(1) - eos_token_id = kwargs.get('eos_token_id', None) + eos_token_id = kwargs.get("eos_token_id", None) if eos_token_id is None: action_mask = torch.ones_like(sequences, dtype=torch.bool) else: # left padding may be applied, only mask action action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 - action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input action_mask[:, :input_len] = False action_mask = action_mask[:, 1:] - return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] + return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :] - def forward(self, - sequences: torch.LongTensor, - num_actions: int, - attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - """Returns action log probs - """ + def forward( + self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Returns action log probs""" output = self.model(sequences, attention_mask=attention_mask) - logits = output['logits'] + logits = output["logits"] log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) return log_probs[:, -num_actions:] @@ -75,11 +70,13 @@ class BLOOMActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - checkpoint: bool = False, - lora_path: str = None) -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_path: str = None, + ) -> None: if pretrained is not None: model = BloomForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py index 0e277021e917a7da8ce5c3df18b0165e22b9a0e2..99a024f1463c34c8c77b9bea062ce12326be0ace 100644 --- a/applications/Chat/examples/community/peft/train_peft_prompts.py +++ b/applications/Chat/examples/community/peft/train_peft_prompts.py @@ -1,19 +1,16 @@ import argparse -import pandas as pd import torch import torch.distributed as dist -from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset +from coati.dataset import DataCollatorForSupervisedDataset from coati.models.bloom import BLOOMRM, BLOOMCritic -from coati.models.gpt import GPTRM, GPTActor, GPTCritic -from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM -from coati.models.opt import OPTRM, OPTActor, OPTCritic +from coati.models.gpt import GPTRM, GPTCritic +from coati.models.llama import LlamaCritic, LlamaRM +from coati.models.opt import OPTRM, OPTCritic from coati.trainer import PPOTrainer -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy -from coati.utils import prepare_llama_tokenizer_and_embedding +from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from easy_dataset import EasyPromptsDataset, EasySupervisedDataset from easy_models import BLOOMActor -from peft import PeftModel from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -24,26 +21,24 @@ from colossalai.nn.optimizer import HybridAdam def main(args): # configure strategy - if args.strategy == 'naive': - strategy = NaiveStrategy() - elif args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) - elif args.strategy == 'colossalai_zero2': - strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') if args.rm_path is not None: - state_dict = torch.load(args.rm_path, map_location='cpu') + state_dict = torch.load(args.rm_path, map_location="cpu") # configure model - if args.model == 'bloom': + if args.model == "bloom": # initial_model = BLOOMActor(pretrained=args.pretrain) - print('Using peft lora to load Bloom model as inital_model') + print("Using peft lora to load Bloom model as initial_model") initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) - print('Using peft lora to load Bloom model as initial_model (Done)') + print("Using peft lora to load Bloom model as initial_model (Done)") else: raise ValueError(f'Unsupported actor model "{args.model}"') @@ -52,59 +47,59 @@ def main(args): else: rm_model_name = args.rm_model - if rm_model_name == 'gpt2': + if rm_model_name == "gpt2": reward_model = GPTRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'bloom': + elif rm_model_name == "bloom": print("load bloom reward model ", args.rm_pretrain) reward_model = BLOOMRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'opt': + elif rm_model_name == "opt": reward_model = OPTRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'llama': + elif rm_model_name == "llama": reward_model = LlamaRM(pretrained=args.rm_pretrain) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') if args.rm_path is not None: - print('Loading reward model from', args.rm_path) + print("Loading reward model from", args.rm_path) reward_model.load_state_dict(state_dict) - if args.strategy != 'colossalai_gemini': + if args.strategy != "colossalai_gemini": initial_model.to(torch.float16).to(torch.cuda.current_device()) reward_model.to(torch.float16).to(torch.cuda.current_device()) with strategy.model_init_context(): - if args.model == 'bloom': + if args.model == "bloom": # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - print('Using peft lora to load Bloom model as Actor') + print("Using peft lora to load Bloom model as Actor") actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) - print('Using peft lora to load Bloom model as Actor (Done)') + print("Using peft lora to load Bloom model as Actor (Done)") else: raise ValueError(f'Unsupported actor model "{args.model}"') - if rm_model_name == 'gpt2': + if rm_model_name == "gpt2": critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'bloom': + elif rm_model_name == "bloom": print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True) critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) print("load bloom critic (Done) ") - elif rm_model_name == 'opt': + elif rm_model_name == "opt": critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'llama': + elif rm_model_name == "llama": critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') if args.rm_path is not None: - print('Loading reward model from', args.rm_path) + print("Loading reward model from", args.rm_path) critic.load_state_dict(state_dict) del state_dict - if args.strategy != 'colossalai_gemini': + if args.strategy != "colossalai_gemini": critic.to(torch.float16).to(torch.cuda.current_device()) actor.to(torch.float16).to(torch.cuda.current_device()) # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): actor_optim = HybridAdam(actor.parameters(), lr=1e-7) critic_optim = HybridAdam(critic.parameters(), lr=1e-7) else: @@ -112,23 +107,22 @@ def main(args): critic_optim = Adam(critic.parameters(), lr=1e-7) # configure tokenizer - if args.model == 'gpt2': + if args.model == "gpt2": tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain) - elif args.model == 'bloom': + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "bloom": tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain) - elif args.model == 'opt': + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "opt": tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain) - elif args.model == 'llama': + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "llama": tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) - tokenizer.eos_token = '<\s>' + tokenizer.eos_token = "<\s>" + tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') - if args.model == 'llama': - tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor) - else: - tokenizer.pad_token = tokenizer.eos_token - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer) @@ -136,26 +130,27 @@ def main(args): prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) else: prompt_sampler = None - prompt_dataloader = DataLoader(prompt_dataset, - shuffle=(prompt_sampler is None), - sampler=prompt_sampler, - batch_size=args.train_batch_size) + prompt_dataloader = DataLoader( + prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size + ) pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer) if dist.is_initialized() and dist.get_world_size() > 1: pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) else: pretrain_sampler = None - pretrain_dataloader = DataLoader(pretrain_dataset, - shuffle=(pretrain_sampler is None), - sampler=pretrain_sampler, - batch_size=args.ptx_batch_size, - collate_fn=data_collator) + pretrain_dataloader = DataLoader( + pretrain_dataset, + shuffle=(pretrain_sampler is None), + sampler=pretrain_sampler, + batch_size=args.ptx_batch_size, + collate_fn=data_collator, + ) def tokenize_fn(texts): # MUST padding to max length to ensure inputs of all ranks have the same length # Different length may lead to hang when using gemini, as different generation steps - batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True) return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) @@ -171,7 +166,6 @@ def main(args): critic_optim, kl_coef=args.kl_coef, ptx_coef=args.ptx_coef, - max_epochs=args.max_epochs, train_batch_size=args.train_batch_size, experience_batch_size=args.experience_batch_size, tokenizer=tokenize_fn, @@ -183,46 +177,46 @@ def main(args): eos_token_id=tokenizer.eos_token_id, ) - trainer.fit(prompt_dataloader=prompt_dataloader, - pretrain_dataloader=pretrain_dataloader, - num_episodes=args.num_episodes, - max_timesteps=args.max_timesteps, - update_timesteps=args.update_timesteps) + trainer.fit( + prompt_dataloader=prompt_dataloader, + pretrain_dataloader=pretrain_dataloader, + num_episodes=args.num_episodes, + num_update_steps=args.num_update_steps, + num_collect_steps=args.num_collect_steps, + ) # save model checkpoint after fitting trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(actor_optim, - 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset') - parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') - parser.add_argument('--strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive', - help='strategy to use') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--sft_lora_path', type=str, default=None) - parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--rm_path', type=str, default=None) - parser.add_argument('--rm_pretrain', type=str, default=None) - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--max_timesteps', type=int, default=10) - parser.add_argument('--update_timesteps', type=int, default=10) - parser.add_argument('--max_epochs', type=int, default=5) - parser.add_argument('--train_batch_size', type=int, default=2) - parser.add_argument('--ptx_batch_size', type=int, default=1) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--kl_coef', type=float, default=0.1) - parser.add_argument('--ptx_coef', type=float, default=0.9) + parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset") + parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset") + parser.add_argument( + "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use" + ) + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--sft_lora_path", type=str, default=None) + parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--rm_path", type=str, default=None) + parser.add_argument("--rm_pretrain", type=str, default=None) + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--num_episodes", type=int, default=10) + parser.add_argument("--num_collect_steps", type=int, default=10) + parser.add_argument("--num_update_steps", type=int, default=5) + parser.add_argument("--train_batch_size", type=int, default=2) + parser.add_argument("--ptx_batch_size", type=int, default=1) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--kl_coef", type=float, default=0.1) + parser.add_argument("--ptx_coef", type=float, default=0.9) args = parser.parse_args() main(args) diff --git a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/Chat/examples/community/peft/train_peft_sft.py index 9bd0ebc12a836d6c699c90ce30538a74c65858f6..3bbef7208374c8ba831bb13fb6430ccb509889ab 100644 --- a/applications/Chat/examples/community/peft/train_peft_sft.py +++ b/applications/Chat/examples/community/peft/train_peft_sft.py @@ -1,19 +1,10 @@ import argparse import os -import loralib as lora import torch import torch.distributed as dist -from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset -from coati.models.base import RewardModel -from coati.models.bloom import BLOOMLM -from coati.models.gpt import GPTLM -from coati.models.llama import LlamaLM -from coati.models.opt import OPTLM from coati.trainer import SFTTrainer -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy -from coati.utils import prepare_llama_tokenizer_and_embedding -from datasets import load_dataset +from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from easy_dataset import EasyDataset from peft import LoraConfig, PeftModel, TaskType, get_peft_model from torch.optim import Adam @@ -30,80 +21,76 @@ from colossalai.tensor import ColoParameter def train(args): # configure strategy - if args.strategy == 'naive': - strategy = NaiveStrategy() - elif args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') - elif args.strategy == 'colossalai_zero2': - strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="static") + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') # configure model with strategy.model_init_context(): - print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested') + print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested") model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device()) - #if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json - if os.path.exists(args.save_path) and os.path.exists(args.save_path+'/adapter_config.json') \ - and os.path.exists(args.save_path+'/adapter_model.bin'): + # if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json + if ( + os.path.exists(args.save_path) + and os.path.exists(args.save_path + "/adapter_config.json") + and os.path.exists(args.save_path + "/adapter_model.bin") + ): print("loading from saved peft model ", args.save_path) model = PeftModel.from_pretrained(model, args.save_path) else: - #we'll use peft lora library to do the lora + # we'll use peft lora library to do the lora lora_rank = args.lora_rank if args.lora_rank > 0 else 32 - #config lora with rank of lora_rank - lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, - inference_mode=False, - r=lora_rank, - lora_alpha=32, - lora_dropout=0.1) + # config lora with rank of lora_rank + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1 + ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + elif args.model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': + elif args.model == "opt": tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - elif args.model == 'llama': + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "llama": tokenizer = AutoTokenizer.from_pretrained( args.pretrain, padding_side="right", use_fast=False, ) - tokenizer.eos_token = '<\s>' + tokenizer.eos_token = "<\s>" + tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') - tokenizer.pad_token = tokenizer.eos_token - if args.model == 'llama': - tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model) - - if args.strategy == 'colossalai_gemini': - # this is a hack to deal with the resized embedding - # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatiblity - for name, param in model.named_parameters(): - if not isinstance(param, ColoParameter): - sub_module_name = '.'.join(name.split('.')[:-1]) - weight_name = name.split('.')[-1] - sub_module = model.get_submodule(sub_module_name) - setattr(sub_module, weight_name, ColoParameter(param)) - else: - tokenizer.pad_token = tokenizer.eos_token + + if args.model == "llama" and args.strategy == "colossalai_gemini": + # this is a hack to deal with the resized embedding + # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility + for name, param in model.named_parameters(): + if not isinstance(param, ColoParameter): + sub_module_name = ".".join(name.split(".")[:-1]) + weight_name = name.split(".")[-1] + sub_module = model.get_submodule(sub_module_name) + setattr(sub_module, weight_name, ColoParameter(param)) # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) else: optim = Adam(model.parameters(), lr=args.lr) logger = get_dist_logger() - logger.set_level('WARNING') + logger.set_level("WARNING") # configure dataset law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) @@ -114,47 +101,57 @@ def train(args): eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) data_collator = default_collate if dist.is_initialized() and dist.get_world_size() > 1: - train_sampler = DistributedSampler(train_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + train_sampler = DistributedSampler( + train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) if eval_dataset is not None: - eval_sampler = DistributedSampler(eval_dataset, - shuffle=False, - seed=42, - drop_last=False, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + eval_sampler = DistributedSampler( + eval_dataset, + shuffle=False, + seed=42, + drop_last=False, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) else: train_sampler = None eval_sampler = None - train_dataloader = DataLoader(train_dataset, - shuffle=(train_sampler is None), - sampler=train_sampler, - batch_size=args.batch_size, - collate_fn=data_collator, - pin_memory=True) + train_dataloader = DataLoader( + train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True, + ) if eval_dataset is not None: - eval_dataloader = DataLoader(eval_dataset, - shuffle=(eval_sampler is None), - sampler=eval_sampler, - batch_size=args.batch_size, - collate_fn=data_collator, - pin_memory=True) + eval_dataloader = DataLoader( + eval_dataset, + shuffle=(eval_sampler is None), + sampler=eval_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True, + ) else: eval_dataloader = None - trainer = SFTTrainer(model=model, - strategy=strategy, - optim=optim, - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - batch_size=args.batch_size, - max_epochs=args.max_epochs, - accumulation_steps=args.accumulation_steps) + trainer = SFTTrainer( + model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + batch_size=args.batch_size, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps, + ) trainer.fit(logger=logger, log_interval=args.log_interval) @@ -162,29 +159,27 @@ def train(args): trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(trainer.optimizer, - 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--dataset', type=str, default=None) - parser.add_argument('--eval_dataset', type=str, default=None) - parser.add_argument('--save_path', type=str, default='output') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--max_epochs', type=int, default=3) - parser.add_argument('--batch_size', type=int, default=4) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") - parser.add_argument('--lr', type=float, default=5e-6) - parser.add_argument('--accumulation_steps', type=int, default=8) - parser.add_argument('--enable_peft_lora', action='store_true', default=False) - parser.add_argument("--is_short_text", action='store_true', default=False) + parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp") + parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom") + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--eval_dataset", type=str, default=None) + parser.add_argument("--save_path", type=str, default="output") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--max_epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log") + parser.add_argument("--lr", type=float, default=5e-6) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--enable_peft_lora", action="store_true", default=False) + parser.add_argument("--is_short_text", action="store_true", default=False) args = parser.parse_args() train(args) diff --git a/applications/Chat/examples/community/ray/README.md b/applications/Chat/examples/community/ray/README.md index 64360bd73ddc8d5627594332f85a7f701c26e1b9..a679a58336a7a661a9cecb21da1c5e8c58b4c80e 100644 --- a/applications/Chat/examples/community/ray/README.md +++ b/applications/Chat/examples/community/ray/README.md @@ -1,17 +1,31 @@ +:warning: **This content may be outdated since the major update of Colossal Chat. We will update this content soon.** + # ColossalAI on Ray + ## Abstract + This is an experimental effort to run ColossalAI Chat training on Ray + ## How to use? + ### 1. Setup Ray clusters + Please follow the official [Ray cluster setup instructions](https://docs.ray.io/en/latest/cluster/getting-started.html) to setup an cluster with GPU support. Record the cluster's api server endpoint, it should be something similar to http://your.head.node.addrees:8265 + ### 2. Clone repo + Clone this project: + ```shell git clone https://github.com/hpcaitech/ColossalAI.git ``` + ### 3. Submit the ray job + ```shell python applications/Chat/examples/community/ray/ray_job_script.py http://your.head.node.addrees:8265 ``` + ### 4. View your job on the Ray Dashboard + Open your ray cluster dashboard http://your.head.node.addrees:8265 to view your submitted training job. diff --git a/applications/Chat/examples/community/ray/ray_job_script.py b/applications/Chat/examples/community/ray/ray_job_script.py index 53f304d379fec54d82d3775552863e73f8dfcbc4..e8a1175a9c32ae2f87b4a3425dd4d13182e41f52 100644 --- a/applications/Chat/examples/community/ray/ray_job_script.py +++ b/applications/Chat/examples/community/ray/ray_job_script.py @@ -6,16 +6,25 @@ from ray.job_submission import JobSubmissionClient def main(api_server_endpoint="http://127.0.0.1:8265"): client = JobSubmissionClient(api_server_endpoint) client.submit_job( - entrypoint= - "python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv", + entrypoint="python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv", runtime_env={ - "working_dir": - "applications/Chat", + "working_dir": "applications/Chat", "pip": [ - "torch==1.13.1", "transformers>=4.20.1", "datasets", "loralib", "colossalai>=0.2.4", "langchain", - "tokenizers", "fastapi", "sse_starlette", "wandb", "sentencepiece", "gpustat" - ] - }) + "torch==1.13.1", + "transformers>=4.20.1", + "datasets", + "loralib", + "colossalai>=0.2.4", + "langchain", + "tokenizers", + "fastapi", + "sse_starlette", + "wandb", + "sentencepiece", + "gpustat", + ], + }, + ) if __name__ == "__main__": diff --git a/applications/Chat/examples/community/ray/train_prompts_on_ray.py b/applications/Chat/examples/community/ray/train_prompts_on_ray.py index 289330ad841516a8bbc17ce80cf022d21f30643a..8abd83a8b249e52cbd14a01110b511d05a17a629 100644 --- a/applications/Chat/examples/community/ray/train_prompts_on_ray.py +++ b/applications/Chat/examples/community/ray/train_prompts_on_ray.py @@ -15,7 +15,7 @@ from coati.models.lora import LoRAModule from coati.models.loss import PolicyLoss, ValueLoss from coati.models.opt import OPTActor, OPTCritic from coati.models.utils import compute_reward -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from torch.optim import Adam @@ -26,9 +26,14 @@ from colossalai.nn.optimizer import HybridAdam class ExperienceCompositionRefs: - - def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef, - base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None: + def __init__( + self, + sequences_attention_mask_action_mask_ref: ray.ObjectRef, + action_log_probs_ref: ray.ObjectRef, + base_action_log_probs_ref: ray.ObjectRef, + value_ref: ray.ObjectRef, + r_ref: ray.ObjectRef, + ) -> None: self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref self.action_log_probs_ref = action_log_probs_ref self.base_action_log_probs_ref = base_action_log_probs_ref @@ -37,14 +42,14 @@ class ExperienceCompositionRefs: class ExperienceMaker: - def __init__(self, kl_coef) -> None: self.kl_coef = kl_coef @torch.no_grad() def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs): sequences, attention_mask, action_mask = ray.get( - experiment_computation_refs.sequences_attention_mask_action_mask_ref) + experiment_computation_refs.sequences_attention_mask_action_mask_ref + ) action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref) base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref) r = ray.get(experiment_computation_refs.r_ref) @@ -58,11 +63,10 @@ class ExperienceMaker: class DistributedTorchRayActor: - def __init__(self, world_size, rank, local_rank, master_addr, master_port): - logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', - level=logging.INFO, - datefmt='%Y-%m-%d %H:%M:%S') + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" + ) self._model = None self._world_size = world_size self._rank = rank @@ -82,7 +86,7 @@ class DistributedTorchRayActor: @staticmethod def _get_free_port(): with socket.socket() as sock: - sock.bind(('', 0)) + sock.bind(("", 0)) return sock.getsockname()[1] def get_master_addr_port(self): @@ -90,7 +94,6 @@ class DistributedTorchRayActor: class BasePPORole(DistributedTorchRayActor): - def add_experience_maker(self, kl_coef: float = 0.1): self._experience_maker = ExperienceMaker(kl_coef) @@ -99,19 +102,17 @@ class BasePPORole(DistributedTorchRayActor): def _init_strategy(self, strategy: str): # configure strategy - if strategy == 'naive': - self._strategy = NaiveStrategy() - elif strategy == 'ddp': + if strategy == "ddp": self._strategy = DDPStrategy() - elif strategy == 'colossalai_gemini': - self._strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) - elif strategy == 'colossalai_zero2': - self._strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif strategy == "colossalai_gemini": + self._strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + elif strategy == "colossalai_zero2": + self._strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{strategy}"') def _init_optimizer(self): - if isinstance(self._strategy, ColossalAIStrategy): + if isinstance(self._strategy, (GeminiStrategy, LowLevelZeroStrategy)): self._optimizer = HybridAdam(self._model.parameters(), lr=5e-6) else: self._optimizer = Adam(self._model.parameters(), lr=5e-6) @@ -126,11 +127,9 @@ class BasePPORole(DistributedTorchRayActor): def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str): raise NotImplementedError() - def init_model_from_pretrained(self, - strategy: str, - model_class: Type[LoRAModule], - pretrain: str, - has_optimizer=False): + def init_model_from_pretrained( + self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer=False + ): self._init_strategy(strategy) self._load_model_from_pretrained(model_class, pretrain) self._prepare_model_with_strategy(has_optimizer) @@ -140,7 +139,6 @@ class BasePPORole(DistributedTorchRayActor): class TrainablePPORole(BasePPORole): - def _load_model_from_pretrained(self, model_class, pretrain): with self._strategy.model_init_context(): self._model = model_class(pretrain).to(torch.cuda.current_device()) @@ -163,38 +161,39 @@ class TrainablePPORole(BasePPORole): @ray.remote(num_gpus=1) class RayPPOActor(TrainablePPORole): - def set_loss_function(self, eps_clip: float): self._actor_loss_fn = PolicyLoss(eps_clip) def load_tokenizer_from_pretrained(self, model_type: str, pretrained): - if model_type == 'gpt2': + if model_type == "gpt2": self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained) self._model_tokenizer.pad_token = self._model_tokenizer.eos_token - elif model_type == 'bloom': + elif model_type == "bloom": self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained) self._model_tokenizer.pad_token = self._model_tokenizer.eos_token - elif model_type == 'opt': + elif model_type == "opt": self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained) else: raise ValueError(f'Unsupported model "{model_type}"') # Set tokenize function for sequence generation def _text_input_tokenize_fn(texts): - batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True) + batch = self._model_tokenizer(texts, return_tensors="pt", max_length=96, padding=True, truncation=True) return {k: v.cuda() for k, v in batch.items()} self._sample_tokenize_function = _text_input_tokenize_fn def setup_generate_kwargs(self, generate_kwargs: dict): from coati.trainer.ppo import _set_default_generate_kwargs + self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model) - self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id - self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id + self._generate_kwargs["pad_token_id"] = self._model_tokenizer.pad_token_id + self._generate_kwargs["eos_token_id"] = self._model_tokenizer.eos_token_id def load_csv_prompt_file_from_url_to_sampler(self, prompt_url): import pandas as pd - prompts = pd.read_csv(prompt_url)['prompt'] + + prompts = pd.read_csv(prompt_url)["prompt"] self._sampler = self._strategy.setup_sampler(prompts) def _generate(self, input_ids, **generate_kwargs): @@ -216,10 +215,9 @@ class RayPPOActor(TrainablePPORole): def _training_step(self, experience): num_actions = experience.action_mask.size(1) action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask) - actor_loss = self._actor_loss_fn(action_log_probs, - experience.action_log_probs, - experience.advantages, - action_mask=experience.action_mask) + actor_loss = self._actor_loss_fn( + action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + ) self._strategy.backward(actor_loss, self._model, self._optimizer) self._strategy.optimizer_step(self._optimizer) self._optimizer.zero_grad() @@ -231,17 +229,18 @@ class RayPPOActor(TrainablePPORole): self._strategy.save_model(self._model, save_path, only_rank0=True) # save optimizer checkpoint on all ranks if should_save_optimizer: - self._strategy.save_optimizer(self._optimizer, - 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + self._strategy.save_optimizer( + self._optimizer, + "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), + only_rank0=False, + ) def generate_answer(self, prompt, max_length=30, num_return_sequences=5): - encoded_input = self._model_tokenizer(prompt, return_tensors='pt') + encoded_input = self._model_tokenizer(prompt, return_tensors="pt") input_ids = {k: v.cuda() for k, v in encoded_input.items()} - sequence, _ = self._model.generate(**input_ids, - max_length=max_length, - return_action_mask=False, - num_return_sequences=num_return_sequences) + sequence, _ = self._model.generate( + **input_ids, max_length=max_length, return_action_mask=False, num_return_sequences=num_return_sequences + ) token_list = list(sequence.data[0]) output = " ".join([self._model_tokenizer.decode(token) for token in token_list]) return output @@ -249,18 +248,16 @@ class RayPPOActor(TrainablePPORole): @ray.remote(num_gpus=1) class RayPPOCritic(TrainablePPORole): - def set_loss_function(self, value_clip: float): self._critic_loss_fn = ValueLoss(value_clip) def _training_step(self, experience): - values = self._model(experience.sequences, - action_mask=experience.action_mask, - attention_mask=experience.attention_mask) - critic_loss = self._critic_loss_fn(values, - experience.values, - experience.reward, - action_mask=experience.action_mask) + values = self._model( + experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask + ) + critic_loss = self._critic_loss_fn( + values, experience.values, experience.reward, action_mask=experience.action_mask + ) self._strategy.backward(critic_loss, self._model, self._optimizer) self._strategy.optimizer_step(self._optimizer) self._optimizer.zero_grad() @@ -274,12 +271,12 @@ class RayPPOCritic(TrainablePPORole): @ray.remote(num_gpus=1) class RayPPORewardModel(BasePPORole): - def _load_model_from_pretrained(self, model_class, pretrain): with self._strategy.model_init_context(): critic = model_class(pretrained=pretrain).to(torch.cuda.current_device()) - self._model = RewardModel(deepcopy(critic.model), - deepcopy(critic.value_head)).to(torch.cuda.current_device()) + self._model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to( + torch.cuda.current_device() + ) @torch.no_grad() def calculate_r(self, sequence_attention_action_mask): @@ -289,7 +286,6 @@ class RayPPORewardModel(BasePPORole): @ray.remote(num_gpus=1) class RayPPOInitialModel(BasePPORole): - def _load_model_from_pretrained(self, model_class, pretrain): with self._strategy.model_init_context(): self._model = model_class(pretrain).to(torch.cuda.current_device()) @@ -302,8 +298,8 @@ class RayPPOInitialModel(BasePPORole): class PPORayActorGroup: """ - A group of ray actors - Functions start with 'async' should return list of object refs + A group of ray actors + Functions start with 'async' should return list of object refs """ def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None: @@ -321,8 +317,9 @@ class PPORayActorGroup: pg = placement_group(bundles, strategy="STRICT_SPREAD") ray.get(pg.ready()) if pg: - master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None) + master_actor = self.ray_actor_type.options( + scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0) + ).remote(world_size, 0, 0, None, None) else: master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None) self._actor_handlers = [master_actor] @@ -333,16 +330,20 @@ class PPORayActorGroup: for rank in range(1, world_size): local_rank = rank % self._num_gpus_per_node if pg: - worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote( - world_size, rank, local_rank, master_addr, master_port) + worker_actor = self.ray_actor_type.options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node + ) + ).remote(world_size, rank, local_rank, master_addr, master_port) else: - worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank, - master_addr, master_port) + worker_actor = self.ray_actor_type.options(num_gpus=1).remote( + world_size, rank, local_rank, master_addr, master_port + ) self._actor_handlers.append(worker_actor) - def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str, - has_optimizer: bool): + def async_init_model_from_pretrained( + self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer: bool + ): return [ actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer) for actor in self._actor_handlers @@ -350,7 +351,6 @@ class PPORayActorGroup: class TrainableModelRayActorGroup(PPORayActorGroup): - def async_learn_on_experiences(self, experience_refs): num_actors = len(self._actor_handlers) learn_result_refs = [] @@ -361,7 +361,6 @@ class TrainableModelRayActorGroup(PPORayActorGroup): class PPOActorRayActorGroup(TrainableModelRayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) -> None: super().__init__(num_nodes, num_gpus_per_node, RayPPOActor) @@ -383,7 +382,8 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup): action_log_probs_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) action_log_probs_refs.append(action_log_probs_ref) return action_log_probs_refs @@ -395,7 +395,6 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup): class PPOCriticRayActorGroup(TrainableModelRayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) -> None: super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic) @@ -404,7 +403,8 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup): value_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): value_ref = self._actor_handlers[i % num_actors].calculate_value.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) value_refs.append(value_ref) return value_refs @@ -413,7 +413,6 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup): class PPOInitialRayActorGroup(PPORayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) -> None: super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel) @@ -422,13 +421,13 @@ class PPOInitialRayActorGroup(PPORayActorGroup): base_action_log_probs_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) base_action_log_probs_refs.append(base_action_log_probs_ref) return base_action_log_probs_refs class PPORewardRayActorGroup(PPORayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) -> None: super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel) @@ -437,20 +436,21 @@ class PPORewardRayActorGroup(PPORayActorGroup): r_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): r_ref = self._actor_handlers[i % num_actors].calculate_r.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) r_refs.append(r_ref) return r_refs def main(args): - logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', - level=logging.INFO, - datefmt='%Y-%m-%d %H:%M:%S') - if args.model == 'gpt2': + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" + ) + if args.model == "gpt2": actor_model_class, critic_model_class = GPTActor, GPTCritic - elif args.model == 'bloom': + elif args.model == "bloom": actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic - elif args.model == 'opt': + elif args.model == "opt": actor_model_class, critic_model_class = OPTActor, OPTCritic else: raise ValueError(f'Unsupported model "{args.model}"') @@ -464,13 +464,14 @@ def main(args): logging.info("Actors created") # Prepare model for training - generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50} + generate_kwargs = {"max_length": 128, "do_sample": True, "temperature": 1.0, "top_k": 50} ray.get( - actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) + - critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) + - initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) + - reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) + - actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)) + actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) + + critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) + + initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) + + reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) + + actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs) + ) logging.info("Models prepared for training") # Prepare models for training @@ -485,8 +486,12 @@ def main(args): # Start training logging.info("Training start") # Set all models to eval and add experience maker - all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \ - initial_group._actor_handlers + reward_group._actor_handlers + all_ray_actors = ( + actor_group._actor_handlers + + critic_group._actor_handlers + + initial_group._actor_handlers + + reward_group._actor_handlers + ) num_ray_actors = len(all_ray_actors) ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors]) ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors]) @@ -499,18 +504,28 @@ def main(args): time += 1 # Experience queueing stage sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence( - experience_batch_size) + experience_batch_size + ) base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs( - sequences_attention_mask_action_mask_refs) + sequences_attention_mask_action_mask_refs + ) values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs) r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs) action_log_probs_refs = actor_group.async_calculate_action_log_probs( - sequences_attention_mask_action_mask_refs) - experience_composition_refs.extend([ - ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i], - base_action_log_probs_refs[i], values_refs[i], r_refs[i]) - for i in range(len(sequences_attention_mask_action_mask_refs)) - ]) + sequences_attention_mask_action_mask_refs + ) + experience_composition_refs.extend( + [ + ExperienceCompositionRefs( + sequences_attention_mask_action_mask_refs[i], + action_log_probs_refs[i], + base_action_log_probs_refs[i], + values_refs[i], + r_refs[i], + ) + for i in range(len(sequences_attention_mask_action_mask_refs)) + ] + ) # Learning stage if time % update_timesteps == 0: experience_refs = [] @@ -521,8 +536,9 @@ def main(args): experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref)) # backward ray.get( - actor_group.async_learn_on_experiences(experience_refs) + - critic_group.async_learn_on_experiences(experience_refs)) + actor_group.async_learn_on_experiences(experience_refs) + + critic_group.async_learn_on_experiences(experience_refs) + ) # clear refs queue experience_composition_refs.clear() logging.info("Training finished") @@ -530,26 +546,24 @@ def main(args): actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_csv_url', type=str) - parser.add_argument('--strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) - parser.add_argument('--pretrain', type=str, default='gpt2') - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--max_timesteps', type=int, default=10) - parser.add_argument('--update_timesteps', type=int, default=10) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1) - parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1) - parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1) - parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1) - parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1) + parser.add_argument("--prompt_csv_url", type=str) + parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt"]) + parser.add_argument("--pretrain", type=str, default="gpt2") + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts.pt") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--num_episodes", type=int, default=10) + parser.add_argument("--max_timesteps", type=int, default=10) + parser.add_argument("--update_timesteps", type=int, default=10) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--num_actor_nodes", type=int, help="num of nodes to use to host actor model", default=1) + parser.add_argument("--num_critic_nodes", type=int, help="num of nodes to use to host critic model", default=1) + parser.add_argument("--num_initial_nodes", type=int, help="num of nodes to use to host initial model", default=1) + parser.add_argument("--num_reward_nodes", type=int, help="num of nodes to use to host reward model", default=1) + parser.add_argument("--num_gpus_per_node", type=int, help="num of gpus on a ray node", default=1) args = parser.parse_args() ray.init() main(args) diff --git a/applications/Chat/examples/download_model.py b/applications/Chat/examples/download_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ec3482b5f7895aa4e06b0924880dadeb3f0413e7 --- /dev/null +++ b/applications/Chat/examples/download_model.py @@ -0,0 +1,79 @@ +import argparse +import dataclasses +import os +import parser +from typing import List + +import tqdm +from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic +from coati.models.gpt import GPTRM, GPTActor, GPTCritic +from coati.models.opt import OPTRM, OPTActor, OPTCritic +from huggingface_hub import hf_hub_download, snapshot_download +from transformers import AutoConfig, AutoTokenizer, BloomConfig, BloomTokenizerFast, GPT2Config, GPT2Tokenizer + + +@dataclasses.dataclass +class HFRepoFiles: + repo_id: str + files: List[str] + + def download(self, dir_path: str): + for file in self.files: + file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path) + + def download_all(self): + snapshot_download(self.repo_id) + + +def test_init(model: str, dir_path: str): + if model == "gpt2": + config = GPT2Config.from_pretrained(dir_path) + actor = GPTActor(config=config) + critic = GPTCritic(config=config) + reward_model = GPTRM(config=config) + GPT2Tokenizer.from_pretrained(dir_path) + elif model == "bloom": + config = BloomConfig.from_pretrained(dir_path) + actor = BLOOMActor(config=config) + critic = BLOOMCritic(config=config) + reward_model = BLOOMRM(config=config) + BloomTokenizerFast.from_pretrained(dir_path) + elif model == "opt": + config = AutoConfig.from_pretrained(dir_path) + actor = OPTActor(config=config) + critic = OPTCritic(config=config) + reward_model = OPTRM(config=config) + AutoTokenizer.from_pretrained(dir_path) + else: + raise NotImplementedError(f"Model {model} not implemented") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-dir", type=str, default="test_models") + parser.add_argument("--config-only", default=False, action="store_true") + args = parser.parse_args() + + if os.path.exists(args.model_dir): + print(f"[INFO]: {args.model_dir} already exists") + exit(0) + + repo_list = { + "gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]), + "bloom": HFRepoFiles( + repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"] + ), + "opt": HFRepoFiles( + repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"] + ), + } + + os.mkdir(args.model_dir) + for model_name in tqdm.tqdm(repo_list): + dir_path = os.path.join(args.model_dir, model_name) + if args.config_only: + os.mkdir(dir_path) + repo_list[model_name].download(dir_path) + else: + repo_list[model_name].download_all() + test_init(model_name, dir_path) diff --git a/applications/Chat/examples/example_data_reformat.py b/applications/Chat/examples/example_data_reformat.py deleted file mode 100644 index dc83b29b525b16ff322126b63042eb32f32ed21e..0000000000000000000000000000000000000000 --- a/applications/Chat/examples/example_data_reformat.py +++ /dev/null @@ -1,12 +0,0 @@ -jsonl_file = 'seed_prompts_xx.jsonl' # seed_prompts_en.jsonl or seed_prompts_ch.json from InstructionWild -reformat_file = 'prompts_xx.jsonl' # reformat jsonl file used as Prompt dataset in Stage3 - -data = '' -with open(jsonl_file, 'r', encoding="utf-8") as f1: - for jsonstr in f1.readlines(): - jsonstr = '\t' + jsonstr.strip('\n') + ',\n' - data = data + jsonstr - data = '[\n' + data + ']' - -with open(reformat_file, 'w') as f2: - f2.write(data) \ No newline at end of file diff --git a/applications/Chat/examples/generate_conversation_dataset.py b/applications/Chat/examples/generate_conversation_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7e03b2d54260b3a0dc20324b596df0f561b2a581 --- /dev/null +++ b/applications/Chat/examples/generate_conversation_dataset.py @@ -0,0 +1,82 @@ +import argparse +import json + +from datasets import load_dataset + + +def generate_alpaca(): + # We can convert dataset with the same format("instruction", "input", "output") as Alpaca into a one-round conversation. + conversation_dataset = [] + dataset = load_dataset("tatsu-lab/alpaca", split="train") + + instructions = dataset["instruction"] + inputs = dataset["input"] + outputs = dataset["output"] + + assert len(instructions) == len(inputs) == len(outputs) + + for idx in range(len(instructions)): + human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx] + human = {"from": "human", "value": human_utterance} + + gpt_utterance = outputs[idx] + gpt = {"from": "gpt", "value": gpt_utterance} + + conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt]) + conversation_dataset.append(conversation) + + return conversation_dataset + + +def generate_sharegpt(): + # ShareGPT data requires less processing. + conversation_dataset = [] + dataset = load_dataset( + "anon8231489123/ShareGPT_Vicuna_unfiltered", + data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json", + split="train", + ) + + conversations = dataset["conversations"] + + for idx in range(len(conversations)): + for conv in conversations[idx]: + # We don't need markdown and text value. + del conv["markdown"] + del conv["text"] + + conversation = dict( + type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx] + ) + conversation_dataset.append(conversation) + + return conversation_dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=str, + default="All", + choices=["Alpaca", "ShareGPT", "All"], + help="which dataset to convert, All will combine Alpaca and ShareGPT", + ) + parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset") + args = parser.parse_args() + + conversation_dataset = [] + + if args.dataset == "Alpaca": + conversation_dataset.extend(generate_alpaca()) + elif args.dataset == "ShareGPT": + conversation_dataset.extend(generate_sharegpt()) + else: + conversation_dataset.extend(generate_alpaca()) + conversation_dataset.extend(generate_sharegpt()) + + for idx, sample in enumerate(conversation_dataset): + sample["id"] = idx + 1 + + with open(args.save_path, mode="w") as f: + json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False) diff --git a/applications/Chat/examples/generate_prompt_dataset.py b/applications/Chat/examples/generate_prompt_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4eec6feae505887331d88c204b5ad514438a6a31 --- /dev/null +++ b/applications/Chat/examples/generate_prompt_dataset.py @@ -0,0 +1,27 @@ +import argparse +import json +import random + +random.seed(42) + + +def sample(args): + with open(args.dataset_path, mode="r") as f: + dataset_list = json.load(f) + + sampled_dataset = [ + {"instruction": sample["instruction"], "id": idx} + for idx, sample in enumerate(random.sample(dataset_list, args.sample_size)) + ] + + with open(args.save_path, mode="w") as f: + json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset") + parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset") + parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset") + args = parser.parse_args() + sample(args) diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py index ae59d91c1822825924e87401ee5f5064cf41fbb6..62e06bf7b3bb75640fc4875ae984ef3b248faacd 100644 --- a/applications/Chat/examples/inference.py +++ b/applications/Chat/examples/inference.py @@ -2,63 +2,72 @@ import argparse import torch from coati.models.bloom import BLOOMActor +from coati.models.generation import generate from coati.models.gpt import GPTActor +from coati.models.llama import LlamaActor from coati.models.opt import OPTActor -from coati.models.roberta import RoBERTaActor -from transformers import AutoTokenizer, RobertaTokenizer -from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer def eval(args): # configure model - if args.model == 'gpt2': - actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device()) - elif args.model == 'bloom': - actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device()) - elif args.model == 'opt': - actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device()) - elif args.model == 'roberta': - actor = RoBERTaActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + if args.model == "gpt2": + actor = GPTActor(pretrained=args.pretrain) + elif args.model == "bloom": + actor = BLOOMActor(pretrained=args.pretrain) + elif args.model == "opt": + actor = OPTActor(pretrained=args.pretrain) + elif args.model == "llama": + actor = LlamaActor(pretrained=args.pretrain) else: raise ValueError(f'Unsupported model "{args.model}"') - state_dict = torch.load(args.model_path) - actor.model.load_state_dict(state_dict) + actor.to(torch.cuda.current_device()) + if args.model_path is not None: + state_dict = torch.load(args.model_path) + actor.load_state_dict(state_dict) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': - tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') + elif args.model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') - elif args.model == 'roberta': - tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + elif args.model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "llama": + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + tokenizer.eos_token = "<\s>" + tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') actor.eval() - input = args.input - input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device()) - outputs = actor.generate(input_ids, - max_length=args.max_length, - do_sample=True, - top_k=50, - top_p=0.95, - num_return_sequences=1) + tokenizer.padding_side = "left" + input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device()) + outputs = generate( + actor, + input_ids, + tokenizer=tokenizer, + max_length=args.max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True) - print(output) + print(f"[Output]: {''.join(output)}") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta']) + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--model_path', type=str, default=None) - parser.add_argument('--input', type=str, default='Question: How are you ? Answer:') - parser.add_argument('--max_length', type=int, default=100) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--model_path", type=str, default=None) + parser.add_argument("--input", type=str, default="Question: How are you ? Answer:") + parser.add_argument("--max_length", type=int, default=100) args = parser.parse_args() eval(args) diff --git a/applications/Chat/examples/ray/1mmt_prompt.py b/applications/Chat/examples/ray/1mmt_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..8de6219ec4e9fbdc471d0ac5f2a1c8f67efc41dc --- /dev/null +++ b/applications/Chat/examples/ray/1mmt_prompt.py @@ -0,0 +1,181 @@ +import argparse +import os +import socket +from functools import partial + +import pandas as pd +import ray +from coati.quant import llama_load_quant, low_resource_init +from coati.ray.detached_trainer_ppo import DetachedPPOTrainer +from coati.ray.experience_maker_holder import ExperienceMakerHolder +from coati.ray.utils import ( + get_actor_from_args, + get_critic_from_args, + get_reward_model_from_args, + get_strategy_from_args, + get_tokenizer_from_args, +) +from torch.utils.data import DataLoader +from transformers import AutoConfig +from transformers.modeling_utils import no_init_weights + + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def get_local_ip(): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + + +def main(args): + master_addr = str(get_local_ip()) + # trainer_env_info + trainer_port = str(get_free_port()) + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] + + # maker_env_info + maker_port = str(get_free_port()) + env_info_maker = { + "local_rank": "0", + "rank": "0", + "world_size": "1", + "master_port": maker_port, + "master_addr": master_addr, + } + + # configure tokenizer + tokenizer = get_tokenizer_from_args(args.model) + + def trainer_model_fn(): + actor = get_actor_from_args(args.model, args.pretrain).half().cuda() + critic = get_critic_from_args(args.model, args.critic_pretrain).half().cuda() + return actor, critic + + # configure Trainer + trainer_refs = [ + DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1"], + strategy_fn=partial(get_strategy_from_args, args.trainer_strategy), + model_fn=trainer_model_fn, + env_info=env_info_trainer, + train_batch_size=args.train_batch_size, + buffer_limit=16, + eval_performance=True, + debug=args.debug, + update_lora_weights=not (args.lora_rank == 0), + ) + for i, env_info_trainer in enumerate(env_info_trainers) + ] + + def model_fn(): + actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() + critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() + reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() + if args.initial_model_quant_ckpt is not None and args.model == "llama": + # quantize initial model + actor_cfg = AutoConfig.from_pretrained(args.pretrain) + with low_resource_init(), no_init_weights(): + initial_model = get_actor_from_args(args.model, config=actor_cfg) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) + else: + initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() + return actor, critic, reward_model, initial_model + + # configure Experience Maker + experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)], + strategy_fn=partial(get_strategy_from_args, args.maker_strategy), + model_fn=model_fn, + env_info=env_info_maker, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + debug=args.debug, + update_lora_weights=not (args.lora_rank == 0), + # sync_models_from_trainers=True, + # generation kwargs: + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + eval_performance=True, + use_cache=True, + ) + + # uncomment this function if sync_models_from_trainers is True + # ray.get([ + # trainer_ref.sync_models_to_remote_makers.remote() + # for trainer_ref in trainer_refs + # ]) + + wait_tasks = [] + + total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size) + for trainer_ref in trainer_refs: + wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs)) + + dataset_size = args.experience_batch_size * 4 + + def build_dataloader(): + def tokenize_fn(texts): + batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + dataset = pd.read_csv(args.prompt_path)["prompt"] + dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn) + return dataloader + + wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps)) + + ray.get(wait_tasks) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--prompt_path", type=str, default=None) + parser.add_argument("--num_trainers", type=int, default=1) + parser.add_argument( + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) + main(args) diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/Chat/examples/ray/mmmt_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..7c03a0468b020921872ff726dde7db9db9a828c7 --- /dev/null +++ b/applications/Chat/examples/ray/mmmt_prompt.py @@ -0,0 +1,201 @@ +import argparse +import os +import socket +from functools import partial + +import pandas as pd +import ray +from coati.quant import llama_load_quant, low_resource_init +from coati.ray.detached_trainer_ppo import DetachedPPOTrainer +from coati.ray.experience_maker_holder import ExperienceMakerHolder +from coati.ray.utils import ( + get_actor_from_args, + get_critic_from_args, + get_receivers_per_sender, + get_reward_model_from_args, + get_strategy_from_args, +) +from torch.utils.data import DataLoader +from transformers import AutoConfig, AutoTokenizer +from transformers.modeling_utils import no_init_weights + + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def get_local_ip(): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + + +def main(args): + master_addr = str(get_local_ip()) + # trainer_env_info + trainer_port = str(get_free_port()) + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] + + # maker_env_info + maker_port = str(get_free_port()) + env_info_makers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_makers), + "master_port": maker_port, + "master_addr": master_addr, + } + for rank in range(args.num_makers) + ] + + # configure tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + + def model_fn(): + actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() + critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() + reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() + if args.initial_model_quant_ckpt is not None and args.model == "llama": + # quantize initial model + actor_cfg = AutoConfig.from_pretrained(args.pretrain) + with low_resource_init(), no_init_weights(): + initial_model = get_actor_from_args(args.model, config=actor_cfg) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) + else: + initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() + return actor, critic, reward_model, initial_model + + # configure Experience Maker + experience_holder_refs = [ + ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=[ + f"trainer{x}" + for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False) + ], + strategy_fn=partial(get_strategy_from_args, args.maker_strategy), + model_fn=model_fn, + env_info=env_info_maker, + kl_coef=0.1, + debug=args.debug, + update_lora_weights=not (args.lora_rank == 0), + # sync_models_from_trainers=True, + # generation kwargs: + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + eval_performance=True, + use_cache=True, + ) + for i, env_info_maker in enumerate(env_info_makers) + ] + + def trainer_model_fn(): + actor = get_actor_from_args(args.model, args.pretrain, lora_rank=args.lora_rank).half().cuda() + critic = get_critic_from_args(args.model, args.critic_pretrain, lora_rank=args.lora_rank).half().cuda() + return actor, critic + + # configure Trainer + trainer_refs = [ + DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=[ + f"maker{x}" + for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True) + ], + strategy_fn=partial(get_strategy_from_args, args.trainer_strategy), + model_fn=trainer_model_fn, + env_info=env_info_trainer, + train_batch_size=args.train_batch_size, + buffer_limit=16, + eval_performance=True, + debug=args.debug, + update_lora_weights=not (args.lora_rank == 0), + ) + for i, env_info_trainer in enumerate(env_info_trainers) + ] + + dataset_size = args.experience_batch_size * 4 + + def build_dataloader(): + def tokenize_fn(texts): + batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + dataset = pd.read_csv(args.prompt_path)["prompt"] + dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn) + return dataloader + + # uncomment this function if sync_models_from_trainers is True + # ray.get([ + # trainer_ref.sync_models_to_remote_makers.remote() + # for trainer_ref in trainer_refs + # ]) + + wait_tasks = [] + + for experience_holder_ref in experience_holder_refs: + wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps)) + + total_steps = ( + args.experience_batch_size + * args.experience_steps + * args.num_makers + // (args.num_trainers * args.train_batch_size) + ) + for trainer_ref in trainer_refs: + wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs)) + + ray.get(wait_tasks) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--prompt_path", type=str, default=None) + parser.add_argument("--num_makers", type=int, default=1) + parser.add_argument("--num_trainers", type=int, default=1) + parser.add_argument( + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + + ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) + main(args) diff --git a/applications/Chat/examples/ray/requirements.txt b/applications/Chat/examples/ray/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e0275631807fc1c22cd3593cd1c48d29d1bbaff9 --- /dev/null +++ b/applications/Chat/examples/ray/requirements.txt @@ -0,0 +1 @@ +ray diff --git a/applications/Chat/examples/ray/test_ci.sh b/applications/Chat/examples/ray/test_ci.sh new file mode 100755 index 0000000000000000000000000000000000000000..895f7de0fea94c5a58081a52983f2dc393ff9780 --- /dev/null +++ b/applications/Chat/examples/ray/test_ci.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +set -xe +BASE=$(realpath $(dirname $0)) + +export RAY_NAMESPACE=admin +export DATA=/data/scratch/chatgpt/prompts.csv + +# install requirements +pip install -r ${BASE}/requirements.txt + +python ${BASE}/mmmt_prompt.py --prompt_path $DATA --num_makers 2 --num_trainers 2 --trainer_strategy colossalai_gemini --model opt --critic_model opt --pretrain facebook/opt-350m --critic_pretrain facebook/opt-125m --experience_batch_size 4 --train_batch_size 2 diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt index 40e6edc7ea7303c516ededa9ecd360f0445f957d..5474dfa16b3ef8c88c2a8073163ec9d08e03bf86 100644 --- a/applications/Chat/examples/requirements.txt +++ b/applications/Chat/examples/requirements.txt @@ -1,2 +1,3 @@ pandas>=1.4.1 sentencepiece +colossalai==0.3.3 diff --git a/applications/Chat/examples/test_ci.sh b/applications/Chat/examples/test_ci.sh deleted file mode 100755 index 2b049163c8012f0d7954a805be41990ecab1d910..0000000000000000000000000000000000000000 --- a/applications/Chat/examples/test_ci.sh +++ /dev/null @@ -1,126 +0,0 @@ -#!/usr/bin/env bash - -set -xue - -if [ -z "$SFT_DATASET" ]; then - echo "Please set \$SFT_DATASET to the path to sft dataset." - exit 1 -fi - -if [ -z "$PROMPT_PATH" ]; then - echo "Please set \$PROMPT_PATH to the path to prompts csv." - exit 1 -fi - -if [ -z "$PRETRAIN_DATASET" ]; then - echo "Please set \$PRETRAIN_DATASET to the path to alpaca data." - exit 1 -fi - -BASE=$(realpath $(dirname $0)) - -export OMP_NUM_THREADS=8 - -# install requirements -pip install -r ${BASE}/requirements.txt - -wandb init -m offline - -# train sft -torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigscience/bloom-560m' \ - --model 'bloom' --strategy colossalai_zero2 --lora_rank 4\ - --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ - --save_path ${BASE}/output -rm -rf ${BASE}/output - -torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \ - --model 'gpt2' --strategy colossalai_zero2 \ - --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ - --save_path ${BASE}/output -rm -rf ${BASE}/output - -torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \ - --model 'opt' --strategy colossalai_zero2 --lora_rank 4\ - --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ - --save_path ${BASE}/output -rm -rf ${BASE}/output - -torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \ - --model 'gpt2' --strategy ddp --lora_rank 4\ - --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ - --save_path ${BASE}/output - -#torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \ -# --model 'opt' --strategy naive \ -# --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ -# --save_path ${BASE}/output - -rm -rf ${BASE}/output - -# train rm -torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ - --pretrain 'facebook/opt-350m' --model 'opt' \ - --strategy colossalai_zero2 --loss_fn 'log_sig'\ - --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \ - --test True --lora_rank 0 \ - --save_path ${BASE}/rm_ckpt_opt.pt - -torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ - --pretrain 'gpt2' --model 'gpt2' \ - --strategy colossalai_zero2 --loss_fn 'log_exp' \ - --dataset 'Dahoas/rm-static' \ - --test True --lora_rank 0 \ - --save_path ${BASE}/rm_ckpt_gpt.pt - -torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ - --pretrain 'gpt2' --model 'gpt2' \ - --strategy ddp --loss_fn 'log_exp' \ - --dataset 'Dahoas/rm-static' \ - --test True --lora_rank 4 \ - --save_path ${BASE}/rm_ckpt.pt -rm -rf ${BASE}/rm_ckpt.pt - -torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ - --pretrain 'bigscience/bloom-560m' --model 'bloom' \ - --strategy colossalai_zero2 --loss_fn 'log_sig' \ - --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \ - --test True --lora_rank 4 \ - --save_path ${BASE}/rm_ckpt.pt -rm -rf ${BASE}/rm_ckpt.pt - -torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ - --pretrain 'microsoft/deberta-v3-large' --model 'deberta' \ - --strategy colossalai_zero2 --loss_fn 'log_sig' \ - --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \ - --test True --lora_rank 4 \ - --save_path ${BASE}/rm_ckpt.pt -rm -rf ${BASE}/rm_ckpt.pt - -torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ - --pretrain 'roberta-base' --model 'roberta' \ - --strategy colossalai_zero2 --loss_fn 'log_exp'\ - --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\ - --test True --lora_rank 4 \ - --save_path ${BASE}/rm_ckpt.pt - -rm -rf ${BASE}/rm_ckpt.pt - -torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ - --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \ - --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \ - --pretrain 'facebook/opt-350m' --model opt \ - --rm_pretrain 'facebook/opt-350m' \ - --rm_path ${BASE}/rm_ckpt_opt.pt \ - --save_path ${BASE}/actor_checkpoint_prompts.pt -rm -rf ${BASE}/rm_ckpt_opt.pt - -torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ - --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \ - --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \ - --pretrain 'gpt2' --model gpt2 \ - --rm_pretrain 'gpt2' \ - --rm_path ${BASE}/rm_ckpt_gpt.pt \ - --save_path ${BASE}/actor_checkpoint_prompts.pt -rm -rf ${BASE}/rm_ckpt_gpt.pt - -rm -rf ${BASE}/actor_checkpoint_prompts.pt diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py index a584991cd34e00cfb9bf2c97a285e642b25268c5..8868e278d85e9fa2a5a5ce2c0a74b57d6103a1fd 100644 --- a/applications/Chat/examples/train_prompts.py +++ b/applications/Chat/examples/train_prompts.py @@ -1,169 +1,169 @@ import argparse +import warnings -import pandas as pd import torch import torch.distributed as dist -from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset +from coati.dataset import PromptDataset, SupervisedDataset from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic from coati.models.gpt import GPTRM, GPTActor, GPTCritic from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM from coati.models.opt import OPTRM, OPTActor, OPTCritic -from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM from coati.trainer import PPOTrainer -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy -from coati.utils import prepare_llama_tokenizer_and_embedding +from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer from colossalai.nn.optimizer import HybridAdam def main(args): # configure strategy - if args.strategy == 'naive': - strategy = NaiveStrategy() - elif args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) - elif args.strategy == 'colossalai_zero2': - strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5) + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') if args.rm_path is not None: - state_dict = torch.load(args.rm_path, map_location='cpu') - - # configure model - if args.model == 'gpt2': - initial_model = GPTActor(pretrained=args.pretrain) - elif args.model == 'bloom': - initial_model = BLOOMActor(pretrained=args.pretrain) - elif args.model == 'opt': - initial_model = OPTActor(pretrained=args.pretrain) - elif args.model == 'llama': - initial_model = LlamaActor(pretrained=args.pretrain) - elif args.model == 'roberta': - initial_model = RoBERTaActor(pretrained=args.pretrain) - else: - raise ValueError(f'Unsupported actor model "{args.model}"') + warnings.warn("LoRA weights should be merged with the model weights") + state_dict = torch.load(args.rm_path, map_location="cpu") - if args.rm_model == None: - rm_model_name = args.model - else: - rm_model_name = args.rm_model - - if rm_model_name == 'gpt2': - reward_model = GPTRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'bloom': - reward_model = BLOOMRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'opt': - reward_model = OPTRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'llama': - reward_model = LlamaRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'roberta': - reward_model = RoBERTaRM(pretrained=args.rm_pretrain) - else: - raise ValueError(f'Unsupported reward model "{rm_model_name}"') + if args.lora_rank > 0: + warnings.warn("Lora is not supported yet.") + args.lora_rank = 0 - if args.rm_path is not None: - reward_model.load_state_dict(state_dict) + with strategy.model_init_context(): + # configure model + if args.model == "gpt2": + initial_model = GPTActor(pretrained=args.pretrain) + elif args.model == "bloom": + initial_model = BLOOMActor(pretrained=args.pretrain) + elif args.model == "opt": + initial_model = OPTActor(pretrained=args.pretrain) + elif args.model == "llama": + initial_model = LlamaActor(pretrained=args.pretrain) + else: + raise ValueError(f'Unsupported actor model "{args.model}"') + + if args.rm_model is None: + rm_model_name = args.model + else: + rm_model_name = args.rm_model + + if rm_model_name == "gpt2": + reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) + elif rm_model_name == "bloom": + reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) + elif rm_model_name == "opt": + reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) + elif rm_model_name == "llama": + reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) + else: + raise ValueError(f'Unsupported reward model "{rm_model_name}"') + + if args.rm_path is not None: + reward_model.load_state_dict(state_dict, strict=False) - initial_model.to(torch.float16).to(torch.cuda.current_device()) - reward_model.to(torch.float16).to(torch.cuda.current_device()) + initial_model.to(torch.bfloat16).to(torch.cuda.current_device()) + reward_model.to(torch.bfloat16).to(torch.cuda.current_device()) - with strategy.model_init_context(): - if args.model == 'gpt2': + if args.model == "gpt2": actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'bloom': + elif args.model == "bloom": actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'opt': + elif args.model == "opt": actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'llama': + elif args.model == "llama": actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'roberta': - actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported actor model "{args.model}"') - if rm_model_name == 'gpt2': - critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'bloom': - critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'opt': - critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'llama': - critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'roberta': - critic = RoBERTaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + if rm_model_name == "gpt2": + critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) + elif rm_model_name == "bloom": + critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) + elif rm_model_name == "opt": + critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) + elif rm_model_name == "llama": + critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') if args.rm_path is not None: - critic.load_state_dict(state_dict) + critic.load_state_dict(state_dict, strict=False) del state_dict - if args.strategy != 'colossalai_gemini': - critic.to(torch.float16).to(torch.cuda.current_device()) - actor.to(torch.float16).to(torch.cuda.current_device()) + actor.to(torch.bfloat16).to(torch.cuda.current_device()) + critic.to(torch.bfloat16).to(torch.cuda.current_device()) # configure optimizer - if args.strategy.startswith('colossalai'): - actor_optim = HybridAdam(actor.parameters(), lr=1e-7) - critic_optim = HybridAdam(critic.parameters(), lr=1e-7) + if args.strategy.startswith("colossalai"): + actor_optim = HybridAdam(actor.parameters(), lr=args.lr) + critic_optim = HybridAdam(critic.parameters(), lr=args.lr) else: - actor_optim = Adam(actor.parameters(), lr=1e-7) - critic_optim = Adam(critic.parameters(), lr=1e-7) + actor_optim = Adam(actor.parameters(), lr=args.lr) + critic_optim = Adam(critic.parameters(), lr=args.lr) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - elif args.model == 'llama': - tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) - tokenizer.eos_token = '<\s>' - elif args.model == 'roberta': - tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained( + "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer + ) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "llama": + tokenizer = LlamaTokenizer.from_pretrained( + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer + ) + tokenizer.eos_token = "<\s>" + tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') - - if args.model == 'llama': - tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor) - else: - tokenizer.pad_token = tokenizer.eos_token - - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) - - prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384) + # NOTE: generate() requires padding_side to be "left" + tokenizer.padding_side = "left" + + prompt_dataset = PromptDataset( + tokenizer=tokenizer, + data_path=args.prompt_dataset, + max_datasets_size=args.max_datasets_size, + max_length=args.max_input_len, + ) if dist.is_initialized() and dist.get_world_size() > 1: prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) else: prompt_sampler = None - prompt_dataloader = DataLoader(prompt_dataset, - shuffle=(prompt_sampler is None), - sampler=prompt_sampler, - batch_size=args.experience_batch_size) - - pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, - data_path=args.pretrain_dataset, - max_datasets_size=16384, - max_length=args.max_input_len) + prompt_dataloader = DataLoader( + prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size + ) + + pretrain_dataset = SupervisedDataset( + tokenizer=tokenizer, + data_path=args.pretrain_dataset, + max_datasets_size=args.max_datasets_size, + max_length=args.max_input_len, + ) if dist.is_initialized() and dist.get_world_size() > 1: pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) else: pretrain_sampler = None - pretrain_dataloader = DataLoader(pretrain_dataset, - shuffle=(pretrain_sampler is None), - sampler=pretrain_sampler, - batch_size=args.ptx_batch_size, - collate_fn=data_collator) + pretrain_dataloader = DataLoader( + pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size + ) - (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) + # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized. + (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare( + (actor, actor_optim), (critic, critic_optim), reward_model, initial_model + ) # configure trainer trainer = PPOTrainer( @@ -174,60 +174,76 @@ def main(args): initial_model, actor_optim, critic_optim, + tokenizer=tokenizer, kl_coef=args.kl_coef, ptx_coef=args.ptx_coef, - max_epochs=args.max_epochs, train_batch_size=args.train_batch_size, max_length=args.max_seq_len, use_cache=True, do_sample=True, temperature=1.0, top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, + offload_inference_models=args.strategy != "colossalai_gemini", ) - trainer.fit(prompt_dataloader=prompt_dataloader, - pretrain_dataloader=pretrain_dataloader, - num_episodes=args.num_episodes, - max_timesteps=args.max_timesteps, - update_timesteps=args.update_timesteps) + trainer.fit( + num_episodes=args.num_episodes, + num_collect_steps=args.num_collect_steps, + num_update_steps=args.num_update_steps, + prompt_dataloader=prompt_dataloader, + pretrain_dataloader=pretrain_dataloader, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) + + if args.lora_rank > 0 and args.merge_lora_weights: + from coati.models.lora import LORA_MANAGER + # NOTE: set model to eval to merge LoRA weights + LORA_MANAGER.merge_weights = True + actor.eval() # save model checkpoint after fitting - strategy.save_model(actor, args.save_path, only_rank0=True) + strategy.save_pretrained(actor, path=args.save_path) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(actor_optim, - 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset') - parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') - parser.add_argument('--strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='colossalai_zero2', - help='strategy to use') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta']) - parser.add_argument('--rm_path', type=str, default=None) - parser.add_argument('--rm_pretrain', type=str, default=None) - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--max_timesteps', type=int, default=10) - parser.add_argument('--update_timesteps', type=int, default=10) - parser.add_argument('--max_epochs', type=int, default=5) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--ptx_batch_size', type=int, default=1) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--kl_coef', type=float, default=0.1) - parser.add_argument('--ptx_coef', type=float, default=0.9) - parser.add_argument('--max_input_len', type=int, default=96) - parser.add_argument('--max_seq_len', type=int, default=128) + parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset") + parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset") + parser.add_argument("--max_datasets_size", type=int, default=50000) + parser.add_argument( + "--strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2"], + default="colossalai_zero2", + help="strategy to use", + ) + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--rm_path", type=str, default=None) + parser.add_argument("--rm_pretrain", type=str, default=None) + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--num_episodes", type=int, default=10) + parser.add_argument("--num_collect_steps", type=int, default=10) + parser.add_argument("--num_update_steps", type=int, default=5) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--ptx_batch_size", type=int, default=1) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--merge_lora_weights", type=bool, default=True) + parser.add_argument("--lr", type=float, default=1e-7) + parser.add_argument("--kl_coef", type=float, default=0.1) + parser.add_argument("--ptx_coef", type=float, default=0.9) + parser.add_argument("--max_input_len", type=int, default=96) + parser.add_argument("--max_seq_len", type=int, default=128) + parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") args = parser.parse_args() main(args) diff --git a/applications/Chat/examples/train_prompts.sh b/applications/Chat/examples/train_prompts.sh index 7f3b2636ca32862d03a260c44bfa4765f6f9990e..d04c416015b1566ee197941328cbc44f376f4545 100755 --- a/applications/Chat/examples/train_prompts.sh +++ b/applications/Chat/examples/train_prompts.sh @@ -1,13 +1,13 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { local n=${1:-"9999"} echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ - | tail -n +2 \ - | nl -v 0 \ - | tee /dev/tty \ - | sort -g -k 2 \ - | awk '{print $1}' \ - | head -n $n) + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') echo "Now CUDA_VISIBLE_DEVICES is set to:" echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" @@ -17,4 +17,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 2 # torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2 -torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2 +torchrun --standalone --nproc_per_node=2 train_prompts.py \ + --pretrain_dataset /path/to/data.json \ + --prompt_dataset /path/to/data.json \ + --strategy colossalai_zero2 \ + --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \ + --train_batch_size 2 diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py index 48b12336fa6743714add52ee52c8a518c155f2ab..df6e8b6bdc26209dbbecf9023e9230aa34834393 100644 --- a/applications/Chat/examples/train_reward_model.py +++ b/applications/Chat/examples/train_reward_model.py @@ -1,26 +1,22 @@ import argparse -from random import randint +import warnings -import loralib as lora import torch import torch.distributed as dist from coati.dataset import HhRlhfDataset, RmStaticDataset from coati.models import LogExpLoss, LogSigLoss -from coati.models.base import RewardModel from coati.models.bloom import BLOOMRM -from coati.models.deberta import DebertaRM from coati.models.gpt import GPTRM from coati.models.llama import LlamaRM from coati.models.opt import OPTRM -from coati.models.roberta import RoBERTaRM from coati.trainer import RewardModelTrainer -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy -from coati.utils import prepare_llama_tokenizer_and_embedding +from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from datasets import load_dataset from torch.optim import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer +from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from colossalai.nn.optimizer import HybridAdam @@ -28,72 +24,69 @@ from colossalai.nn.optimizer import HybridAdam def train(args): # configure strategy - if args.strategy == 'naive': - strategy = NaiveStrategy() - elif args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') - elif args.strategy == 'colossalai_zero2': - strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="auto") + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') # configure model + if args.lora_rank > 0: + warnings.warn("Lora is not supported yet.") + args.lora_rank = 0 + with strategy.model_init_context(): - if args.model == 'bloom': - model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) - elif args.model == 'opt': - model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) - elif args.model == 'gpt2': - model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) - elif args.model == 'deberta': - model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) - elif args.model == 'llama': - model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) - elif args.model == 'roberta': - model = RoBERTaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + if args.model == "bloom": + model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank) + elif args.model == "opt": + model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) + elif args.model == "gpt2": + model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) + elif args.model == "llama": + model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported model "{args.model}"') + model.to(torch.bfloat16).to(torch.cuda.current_device()) + if args.model_path is not None: state_dict = torch.load(args.model_path) model.load_state_dict(state_dict) - model = model.to(torch.float16) - # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - elif args.model == 'deberta': - tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large') - elif args.model == 'llama': - tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) - elif args.model == 'roberta': - tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained( + "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer + ) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "llama": + tokenizer = LlamaTokenizer.from_pretrained( + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer + ) + tokenizer.eos_token = "<\s>" + tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') - max_len = args.max_len - - if args.model == 'llama': - tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model) - else: - tokenizer.pad_token = tokenizer.eos_token # configure optimizer - if args.strategy.startswith('colossalai'): - optim = HybridAdam(model.parameters(), lr=5e-6) + if args.strategy.startswith("colossalai"): + optim = HybridAdam(model.parameters(), lr=args.lr) else: - optim = Adam(model.parameters(), lr=5e-6) + optim = Adam(model.parameters(), lr=args.lr) # configure loss function - if args.loss_fn == 'log_sig': + if args.loss_fn == "log_sig": loss_fn = LogSigLoss() - elif args.loss_fn == 'log_exp': + elif args.loss_fn == "log_exp": loss_fn = LogExpLoss() else: raise ValueError(f'Unsupported loss function "{args.loss_fn}"') @@ -104,107 +97,112 @@ def train(args): else: data = load_dataset(args.dataset) - if args.test: - train_data = data['train'].select(range(100)) - eval_data = data['test'].select(range(10)) - else: - train_data = data['train'] - eval_data = data['test'] - valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5))) - - if args.dataset == 'Dahoas/rm-static': - train_dataset = RmStaticDataset(train_data, tokenizer, max_len) - valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len) - eval_dataset = RmStaticDataset(eval_data, tokenizer, max_len) - elif args.dataset == 'Anthropic/hh-rlhf': - train_dataset = HhRlhfDataset(train_data, tokenizer, max_len) - valid_dataset = HhRlhfDataset(valid_data, tokenizer, max_len) - eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len) + train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"])))) + eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"])))) + + if args.dataset == "Dahoas/rm-static": + train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len) + eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len) + elif args.dataset == "Anthropic/hh-rlhf": + train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len) + eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len) else: raise ValueError(f'Unsupported dataset "{args.dataset}"') if dist.is_initialized() and dist.get_world_size() > 1: - train_sampler = DistributedSampler(train_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) - valid_sampler = DistributedSampler(valid_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) - eval_sampler = DistributedSampler(eval_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + train_sampler = DistributedSampler( + train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) + eval_sampler = DistributedSampler( + eval_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) else: train_sampler = None - valid_sampler = None eval_sampler = None - train_dataloader = DataLoader(train_dataset, - shuffle=(train_sampler is None), - sampler=train_sampler, - batch_size=args.batch_size, - pin_memory=True) - - valid_dataloader = DataLoader(valid_dataset, - shuffle=(valid_sampler is None), - sampler=valid_sampler, - batch_size=args.batch_size, - pin_memory=True) - - eval_dataloader = DataLoader(eval_dataset, - shuffle=(eval_sampler is None), - sampler=eval_sampler, - batch_size=args.batch_size, - pin_memory=True) - - (model, optim) = strategy.prepare((model, optim)) - trainer = RewardModelTrainer(model=model, - strategy=strategy, - optim=optim, - loss_fn=loss_fn, - train_dataloader=train_dataloader, - valid_dataloader=valid_dataloader, - eval_dataloader=eval_dataloader, - max_epochs=args.max_epochs) - - trainer.fit() + train_dataloader = DataLoader( + train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + pin_memory=True, + ) + + eval_dataloader = DataLoader( + eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True + ) + + lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100) + strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)) + model = strategy_dict["model"] + optim = strategy_dict["optimizer"] + lr_scheduler = strategy_dict["lr_scheduler"] + trainer = RewardModelTrainer( + model=model, + strategy=strategy, + optim=optim, + lr_scheduler=lr_scheduler, + loss_fn=loss_fn, + max_epochs=args.max_epochs, + ) + + trainer.fit( + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) + + if args.lora_rank > 0 and args.merge_lora_weights: + from coati.models.lora import LORA_MANAGER + + # NOTE: set model to eval to merge LoRA weights + LORA_MANAGER.merge_weights = True + model.eval() # save model checkpoint after fitting on only rank0 - strategy.save_model(model, args.save_path, only_rank0=True) + state_dict = model.state_dict() + torch.save(state_dict, args.save_path) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(trainer.optimizer, - 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='colossalai_zero2') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama', 'roberta'], default='bloom') - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--model_path', type=str, default=None) - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--dataset', - type=str, - choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'], - default='Dahoas/rm-static') - parser.add_argument('--subset', type=str, default=None) - parser.add_argument('--save_path', type=str, default='rm_ckpt') - parser.add_argument('--max_epochs', type=int, default=1) - parser.add_argument('--batch_size', type=int, default=1) - parser.add_argument('--max_len', type=int, default=512) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp']) - parser.add_argument('--test', type=bool, default=False) + parser.add_argument( + "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2" + ) + parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--model_path", type=str, default=None) + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument( + "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static" + ) + parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None) + parser.add_argument("--max_datasets_size", type=int, default=1000000) + parser.add_argument("--save_path", type=str, default="rm_ckpt") + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--max_len", type=int, default=512) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--merge_lora_weights", type=bool, default=True) + parser.add_argument("--lr", type=float, default=9e-6) + parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"]) + parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") args = parser.parse_args() train(args) diff --git a/applications/Chat/examples/train_rm.sh b/applications/Chat/examples/train_rm.sh index 80abe62d2a3fe9d70c0ab8be1b2e8e3b8afc5e03..c5ebaf708ddca9ddce2c39fda4e2cfc431ac80f9 100755 --- a/applications/Chat/examples/train_rm.sh +++ b/applications/Chat/examples/train_rm.sh @@ -1,13 +1,13 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { local n=${1:-"9999"} echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ - | tail -n +2 \ - | nl -v 0 \ - | tee /dev/tty \ - | sort -g -k 2 \ - | awk '{print $1}' \ - | head -n $n) + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') echo "Now CUDA_VISIBLE_DEVICES is set to:" echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" @@ -16,9 +16,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { set_n_least_used_CUDA_VISIBLE_DEVICES 2 torchrun --standalone --nproc_per_node=2 train_reward_model.py \ - --pretrain \ - --model 'bloom' \ - --strategy colossalai_zero2 \ - --loss_fn 'log_sig'\ - --save_path \ - --dataset 'Anthropic/hh-rlhf'\ + --pretrain 'gpt2' \ + --model 'gpt2' \ + --strategy colossalai_zero2 \ + --loss_fn 'log_exp' \ + --dataset 'Anthropic/hh-rlhf' \ + --batch_size 16 \ + --max_epochs 10 diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index da499f068b17885ac468ecd1dcb9de49100f667c..66d08da3012015cfd8438749fe7cc5b2927f4285 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -1,196 +1,221 @@ import argparse -import os +import math +import warnings -import loralib as lora import torch import torch.distributed as dist -from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset -from coati.models import convert_to_lora_module +from coati.dataset import SFTDataset, SupervisedDataset +from coati.models.bloom import BLOOMActor +from coati.models.chatglm import ChatGLMActor +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer +from coati.models.gpt import GPTActor +from coati.models.llama import LlamaActor +from coati.models.opt import OPTActor from coati.trainer import SFTTrainer -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy -from coati.utils import prepare_llama_tokenizer_and_embedding +from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from datasets import load_dataset from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, BloomTokenizerFast, LlamaConfig, LlamaForCausalLM -from transformers.models.gpt2.configuration_gpt2 import GPT2Config -from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel +from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer -from transformers.models.opt.configuration_opt import OPTConfig -from transformers.models.opt.modeling_opt import OPTForCausalLM +from transformers.trainer import get_scheduler from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ColoParameter def train(args): # configure strategy - if args.strategy == 'naive': - strategy = NaiveStrategy() - elif args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - raise NotImplementedError( - 'Gemini is not supported .from_pretrained() yet. We will update this after checkpoint io is ready.') - strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') - elif args.strategy == 'colossalai_zero2': - strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') - elif args.strategy == 'colossalai_zero2_cpu': - strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="auto") + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") + elif args.strategy == "colossalai_zero2_cpu": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') # configure model + if args.lora_rank > 0: + warnings.warn("Lora is not supported yet.") + args.lora_rank = 0 + with strategy.model_init_context(): - if args.model == 'bloom': - model = convert_to_lora_module(BloomForCausalLM.from_pretrained(args.pretrain), - args.lora_rank).half().cuda() - elif args.model == 'opt': - model = convert_to_lora_module(OPTForCausalLM.from_pretrained(args.pretrain), args.lora_rank).half().cuda() - elif args.model == 'gpt2': - model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained(args.pretrain), args.lora_rank).half().cuda() - elif args.model == 'llama': - model = convert_to_lora_module(LlamaForCausalLM.from_pretrained(args.pretrain), - args.lora_rank).half().cuda() + if args.model == "bloom": + model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "opt": + model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "gpt2": + model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "llama": + model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "chatglm": + model = ChatGLMActor(pretrained=args.pretrain) else: raise ValueError(f'Unsupported model "{args.model}"') - if args.grad_checkpoint: - model.gradient_checkpointing_enable() + + model.to(torch.bfloat16).to(torch.cuda.current_device()) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained( + "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer + ) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + elif args.model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - elif args.model == 'llama': - tokenizer = AutoTokenizer.from_pretrained( - args.pretrain, - padding_side="right", - use_fast=False, + elif args.model == "llama": + tokenizer = LlamaTokenizer.from_pretrained( + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer + ) + tokenizer.eos_token = "<\s>" + tokenizer.pad_token = tokenizer.unk_token + elif args.model == "chatglm": + tokenizer = ChatGLMTokenizer.from_pretrained( + "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True ) - tokenizer.eos_token = '<\s>' else: raise ValueError(f'Unsupported model "{args.model}"') - tokenizer.pad_token = tokenizer.eos_token - max_len = args.max_len - if args.model == 'llama': - tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model) - - if args.strategy == 'colossalai_gemini': - # this is a hack to deal with the resized embedding - # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatiblity - for name, param in model.named_parameters(): - if not isinstance(param, ColoParameter): - sub_module_name = '.'.join(name.split('.')[:-1]) - weight_name = name.split('.')[-1] - sub_module = model.get_submodule(sub_module_name) - setattr(sub_module, weight_name, ColoParameter(param)) - else: - tokenizer.pad_token = tokenizer.eos_token # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) else: optim = Adam(model.parameters(), lr=args.lr) - logger = get_dist_logger() - # configure dataset - if args.dataset == 'yizhongw/self_instruct': - train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train') - eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test') + if args.dataset == "yizhongw/self_instruct": + train_data = load_dataset(args.dataset, "super_natural_instructions", split="train") + eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test") + + if args.max_datasets_size is not None: + train_data = train_data.select(range(min(args.max_datasets_size, len(train_data)))) + eval_data = eval_data.select(range(min(args.max_datasets_size, len(eval_data)))) - train_dataset = SFTDataset(train_data, tokenizer, max_len) - eval_dataset = SFTDataset(eval_data, tokenizer, max_len) + train_dataset = SFTDataset(train_data, tokenizer, args.max_len) + eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len) else: - train_dataset = SupervisedDataset(tokenizer=tokenizer, - data_path=args.dataset, - max_datasets_size=args.max_datasets_size, - max_length=max_len) + train_dataset = SupervisedDataset( + tokenizer=tokenizer, + data_path=args.dataset, + max_datasets_size=args.max_datasets_size, + max_length=args.max_len, + ) eval_dataset = None - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) if dist.is_initialized() and dist.get_world_size() > 1: - train_sampler = DistributedSampler(train_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + train_sampler = DistributedSampler( + train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) if eval_dataset is not None: - eval_sampler = DistributedSampler(eval_dataset, - shuffle=False, - seed=42, - drop_last=False, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + eval_sampler = DistributedSampler( + eval_dataset, + shuffle=False, + seed=42, + drop_last=False, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) else: train_sampler = None eval_sampler = None - train_dataloader = DataLoader(train_dataset, - shuffle=(train_sampler is None), - sampler=train_sampler, - batch_size=args.batch_size, - collate_fn=data_collator, - pin_memory=True) + train_dataloader = DataLoader( + train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + pin_memory=True, + ) if eval_dataset is not None: - eval_dataloader = DataLoader(eval_dataset, - shuffle=(eval_sampler is None), - sampler=eval_sampler, - batch_size=args.batch_size, - collate_fn=data_collator, - pin_memory=True) + eval_dataloader = DataLoader( + eval_dataset, + shuffle=(eval_sampler is None), + sampler=eval_sampler, + batch_size=args.batch_size, + pin_memory=True, + ) else: eval_dataloader = None - (model, optim) = strategy.prepare((model, optim)) - trainer = SFTTrainer(model=model, - strategy=strategy, - optim=optim, - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - max_epochs=args.max_epochs, - accumulation_steps=args.accumulation_steps) - - trainer.fit(logger=logger, use_wandb=args.use_wandb) + num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + lr_scheduler = get_scheduler( + "cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps + ) + strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)) + model = strategy_dict["model"] + optim = strategy_dict["optimizer"] + lr_scheduler = strategy_dict["lr_scheduler"] + trainer = SFTTrainer( + model=model, + strategy=strategy, + optim=optim, + lr_scheduler=lr_scheduler, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps, + ) + logger = get_dist_logger() + trainer.fit( + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + logger=logger, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) + + if args.lora_rank > 0 and args.merge_lora_weights: + from coati.models.lora import LORA_MANAGER + + # NOTE: set model to eval to merge LoRA weights + LORA_MANAGER.merge_weights = True + model.eval() # save model checkpoint after fitting on only rank0 - strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer) + strategy.save_pretrained(model, path=args.save_path, tokenizer=tokenizer) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(trainer.optimizer, - 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'], - default='colossalai_zero2') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--dataset', type=str, default=None) - parser.add_argument('--max_datasets_size', type=int, default=None) - parser.add_argument('--save_path', type=str, default='output') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--max_epochs', type=int, default=3) - parser.add_argument('--batch_size', type=int, default=4) - parser.add_argument('--max_len', type=int, default=512) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") - parser.add_argument('--lr', type=float, default=5e-6) - parser.add_argument('--accumulation_steps', type=int, default=8) - parser.add_argument('--use_wandb', default=False, action='store_true') - parser.add_argument('--grad_checkpoint', default=False, action='store_true') + parser.add_argument( + "--strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"], + default="colossalai_zero2", + ) + parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--max_datasets_size", type=int, default=None) + parser.add_argument("--save_path", type=str, default="output") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--max_epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--max_len", type=int, default=512) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--merge_lora_weights", type=bool, default=True) + parser.add_argument("--lr", type=float, default=5e-6) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--grad_checkpoint", default=False, action="store_true") args = parser.parse_args() train(args) diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh index c880f85825a77a98ea49ce691bb5cf4fcabca857..0fb4da3d3ce8bc6091374ca98e8bbed0738f8c1d 100755 --- a/applications/Chat/examples/train_sft.sh +++ b/applications/Chat/examples/train_sft.sh @@ -1,12 +1,28 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 4 + torchrun --standalone --nproc_per_node=4 train_sft.py \ --pretrain "/path/to/LLaMa-7B/" \ --model 'llama' \ --strategy colossalai_zero2 \ - --log_interval 10 \ - --save_path /path/to/Coati-7B \ + --save_path /path/to/Coati-7B \ --dataset /path/to/data.json \ --batch_size 4 \ --accumulation_steps 8 \ --lr 2e-5 \ --max_datasets_size 512 \ - --max_epochs 1 \ + --max_epochs 1 diff --git a/applications/Chat/inference/README.md b/applications/Chat/inference/README.md index 434677c98fa58f7050098671c6243c6f70d023a4..eea4ef5b86ca99106a813ca0a70f49bdf5b65c5e 100644 --- a/applications/Chat/inference/README.md +++ b/applications/Chat/inference/README.md @@ -20,21 +20,21 @@ Tha data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tar ### 8-bit -| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples | -| :---: | :---: | :---: | :---: | :---: | -| LLaMA-7B | 9.2GB | 10GB | 24GB | 3060 12GB, RTX 3080 10GB, RTX 3090 | -| LLaMA-13B | 16.3GB | 20GB | 32GB | RTX 3090 Ti, RTX 4090 | -| LLaMA-30B | 36GB | 40GB | 64GB | A6000 48GB, A100 40GB | -| LLaMA-65B | 74GB | 80GB | 128GB | A100 80GB | +| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples | +| :-------: | :---------: | :-----------------: | :----------: | :--------------------------------: | +| LLaMA-7B | 9.2GB | 10GB | 24GB | 3060 12GB, RTX 3080 10GB, RTX 3090 | +| LLaMA-13B | 16.3GB | 20GB | 32GB | RTX 3090 Ti, RTX 4090 | +| LLaMA-30B | 36GB | 40GB | 64GB | A6000 48GB, A100 40GB | +| LLaMA-65B | 74GB | 80GB | 128GB | A100 80GB | ### 4-bit -| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples | -| :---: | :---: | :---: | :---: | :---: | -| LLaMA-7B | 3.5GB | 6GB | 16GB | RTX 1660, 2060, AMD 5700xt, RTX 3050, 3060 | -| LLaMA-13B | 6.5GB | 10GB | 32GB | AMD 6900xt, RTX 2060 12GB, 3060 12GB, 3080, A2000 | -| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 | -| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada | +| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples | +| :-------: | :---------: | :-----------------: | :----------: | :--------------------------------------------------------: | +| LLaMA-7B | 3.5GB | 6GB | 16GB | RTX 1660, 2060, AMD 5700xt, RTX 3050, 3060 | +| LLaMA-13B | 6.5GB | 10GB | 32GB | AMD 6900xt, RTX 2060 12GB, 3060 12GB, 3080, A2000 | +| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 | +| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada | ## General setup @@ -75,7 +75,7 @@ E.g. you can set `export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH`. Please ensure you have downloaded HF-format model weights of LLaMA models first. -Then you can follow [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). This lib provides efficient CUDA kernels and weight convertion script. +Then you can follow [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). This lib provides efficient CUDA kernels and weight conversion script. After installing this lib, we may convert the original HF-format LLaMA model weights to 4-bit version. diff --git a/applications/Chat/inference/benchmark.py b/applications/Chat/inference/benchmark.py index 59cd1eeea2aa841ae805d91adf279f668fdb3dd0..dbb5490a63dc2572957a02d39874376a1079f904 100644 --- a/applications/Chat/inference/benchmark.py +++ b/applications/Chat/inference/benchmark.py @@ -4,8 +4,8 @@ import argparse from time import time import torch -from llama_gptq import load_quant -from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM +from coati.quant import llama_load_quant, low_resource_init +from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM def generate_prompt(instruction, input=None): @@ -84,49 +84,58 @@ inst = [instructions[0]] * 4 if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - 'pretrained', - help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.') - parser.add_argument('--quant', - choices=['8bit', '4bit'], - default=None, - help='Quantization mode. Default: None (no quantization, fp16).') + "pretrained", + help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.", + ) + parser.add_argument( + "--quant", + choices=["8bit", "4bit"], + default=None, + help="Quantization mode. Default: None (no quantization, fp16).", + ) parser.add_argument( - '--gptq_checkpoint', + "--gptq_checkpoint", default=None, - help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.') - parser.add_argument('--gptq_group_size', - type=int, - default=128, - help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.') + help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.", + ) + parser.add_argument( + "--gptq_group_size", + type=int, + default=128, + help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.", + ) args = parser.parse_args() - if args.quant == '4bit': - assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.' + if args.quant == "4bit": + assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint." tokenizer = AutoTokenizer.from_pretrained(args.pretrained) - if args.quant == '4bit': - model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size) + if args.quant == "4bit": + with low_resource_init(): + config = LlamaConfig.from_pretrained(args.pretrained) + model = LlamaForCausalLM(config) + model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size) model.cuda() else: model = LlamaForCausalLM.from_pretrained( args.pretrained, - load_in_8bit=(args.quant == '8bit'), + load_in_8bit=(args.quant == "8bit"), torch_dtype=torch.float16, device_map="auto", ) - if args.quant != '8bit': - model.half() # seems to fix bugs for some users. + if args.quant != "8bit": + model.half() # seems to fix bugs for some users. model.eval() total_tokens = 0 start = time() for instruction in instructions: print(f"Instruction: {instruction}") - resp, tokens = evaluate(model, tokenizer, instruction, temparature=0.2, num_beams=1) + resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1) total_tokens += tokens print(f"Response: {resp}") - print('\n----------------------------\n') + print("\n----------------------------\n") duration = time() - start - print(f'Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s') - print(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB') + print(f"Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s") + print(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB") diff --git a/applications/Chat/inference/llama_gptq/__init__.py b/applications/Chat/inference/llama_gptq/__init__.py deleted file mode 100644 index 51c8d6316290fe2fcef7d972803017c830d3e1b4..0000000000000000000000000000000000000000 --- a/applications/Chat/inference/llama_gptq/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .loader import load_quant - -__all__ = [ - 'load_quant', -] diff --git a/applications/Chat/inference/llama_gptq/loader.py b/applications/Chat/inference/llama_gptq/loader.py deleted file mode 100644 index a5c6ac7d1589aa1873918b9c8b02edcfe13ed59f..0000000000000000000000000000000000000000 --- a/applications/Chat/inference/llama_gptq/loader.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -import torch.nn as nn -import transformers -from transformers import LlamaConfig, LlamaForCausalLM - -from .model_utils import find_layers -from .quant import make_quant - - -def load_quant(pretrained: str, checkpoint: str, wbits: int, groupsize: int): - config = LlamaConfig.from_pretrained(pretrained) - - def noop(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = noop - torch.nn.init.uniform_ = noop - torch.nn.init.normal_ = noop - - torch.set_default_dtype(torch.half) - transformers.modeling_utils._init_weights = False - torch.set_default_dtype(torch.half) - model = LlamaForCausalLM(config) - torch.set_default_dtype(torch.float) - model = model.eval() - layers = find_layers(model) - for name in ['lm_head']: - if name in layers: - del layers[name] - make_quant(model, layers, wbits, groupsize) - - print(f'Loading model with {wbits} bits...') - if checkpoint.endswith('.safetensors'): - from safetensors.torch import load_file as safe_load - model.load_state_dict(safe_load(checkpoint)) - else: - model.load_state_dict(torch.load(checkpoint)) - model.seqlen = 2048 - print('Done.') - - return model diff --git a/applications/Chat/inference/llama_gptq/model_utils.py b/applications/Chat/inference/llama_gptq/model_utils.py deleted file mode 100644 index 62db171abb52cb88799a8b73d608f2617208cefe..0000000000000000000000000000000000000000 --- a/applications/Chat/inference/llama_gptq/model_utils.py +++ /dev/null @@ -1,13 +0,0 @@ -# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py - -import torch -import torch.nn as nn - - -def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): - if type(module) in layers: - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) - return res diff --git a/applications/Chat/inference/locustfile.py b/applications/Chat/inference/locustfile.py index 51cdc68125bba42d29a91f03285847e0bde27ea8..333262e538ac46d75734f14e3b648b65c3a2ff3d 100644 --- a/applications/Chat/inference/locustfile.py +++ b/applications/Chat/inference/locustfile.py @@ -1,27 +1,26 @@ -from json import JSONDecodeError - from locust import HttpUser, task -samples = [[ - dict( - instruction='Who is the best player in the history of NBA?', - response= - 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - dict(instruction='continue this talk', response=''), -], [ - dict(instruction='Who is the best player in the history of NBA?', response=''), -]] +samples = [ + [ + dict( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + dict(instruction="continue this talk", response=""), + ], + [ + dict(instruction="Who is the best player in the history of NBA?", response=""), + ], +] class GenerationUser(HttpUser): - @task def generate(self): for sample in samples: - data = {'max_new_tokens': 64, 'history': sample} - with self.client.post('/generate', json=data, catch_response=True) as response: + data = {"max_new_tokens": 64, "history": sample} + with self.client.post("/generate", json=data, catch_response=True) as response: if response.status_code in (200, 406): response.success() else: - response.failure('Response wrong') + response.failure("Response wrong") diff --git a/applications/Chat/inference/requirements.txt b/applications/Chat/inference/requirements.txt index 511fe1a4f1f339b23f1a11162925703b94e15013..cb6275361736a730cb53a95d3d4090407fd3f6c4 100644 --- a/applications/Chat/inference/requirements.txt +++ b/applications/Chat/inference/requirements.txt @@ -10,4 +10,4 @@ uvicorn git+https://github.com/huggingface/transformers accelerate bitsandbytes -jieba \ No newline at end of file +jieba diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py index b4627299397e6949576318de3938ee9c19aa390c..7c6a61b9e7f2c37caafc278c8f7923d3ca7429e1 100644 --- a/applications/Chat/inference/server.py +++ b/applications/Chat/inference/server.py @@ -1,22 +1,22 @@ import argparse import os from threading import Lock -from typing import Dict, Generator, List, Optional +from typing import Generator, List, Optional import torch import uvicorn -from fastapi import FastAPI, HTTPException, Request +from coati.quant import llama_load_quant, low_resource_init +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from llama_gptq import load_quant from pydantic import BaseModel, Field from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded from slowapi.util import get_remote_address from sse_starlette.sse import EventSourceResponse -from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM -from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn, load_json +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM +from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn -CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' +CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions." MAX_LEN = 512 running_lock = Lock() @@ -36,11 +36,11 @@ app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # set CORS -origin_spec_from_env = os.environ.get('CORS_ORIGIN', None) +origin_spec_from_env = os.environ.get("CORS_ORIGIN", None) if origin_spec_from_env is not None: # allow CORS from the specified origins - origins = os.environ['CORS_ORIGIN'].split(',') + origins = os.environ["CORS_ORIGIN"].split(",") else: # allow CORS from all origins origins = ["*"] @@ -56,15 +56,15 @@ app.add_middleware( def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} - #TODO(ver217): streaming generation does not support repetition_penalty now + # TODO(ver217): streaming generation does not support repetition_penalty now model_kwargs = { - 'max_generate_tokens': max_new_tokens, - 'early_stopping': True, - 'top_k': top_k, - 'top_p': top_p, - 'temperature': temperature, - 'prepare_inputs_fn': model.prepare_inputs_for_generation, - 'update_model_kwargs_fn': update_model_kwargs_fn, + "max_generate_tokens": max_new_tokens, + "early_stopping": True, + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "prepare_inputs_fn": model.prepare_inputs_for_generation, + "update_model_kwargs_fn": update_model_kwargs_fn, } is_first_word = True generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock) @@ -81,9 +81,9 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): if is_first_word: out_string = out_string.lstrip() is_first_word = False - elif current_sub_tokens[0].startswith('▁'): + elif current_sub_tokens[0].startswith("▁"): # whitespace will be ignored by the frontend - out_string = ' ' + out_string + out_string = " " + out_string yield out_string @@ -92,32 +92,33 @@ async def event_generator(request: Request, generator: Generator): if await request.is_disconnected(): break try: - yield {'event': 'generate', 'data': next(generator)} + yield {"event": "generate", "data": next(generator)} except StopIteration: - yield {'event': 'end', 'data': ''} + yield {"event": "end", "data": ""} break -@app.post('/generate/stream') -@limiter.limit('1/second') +@app.post("/generate/stream") +@limiter.limit("1/second") def generate(data: GenerationTaskReq, request: Request): prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) event_source = event_generator( - request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)) + request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature) + ) return EventSourceResponse(event_source) -@app.post('/generate') -@limiter.limit('1/second') +@app.post("/generate") +@limiter.limit("1/second") def generate_no_stream(data: GenerationTaskReq, request: Request): prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) if prompt_processor.has_censored_words(prompt): return prompt_processor.SAFE_RESPONSE inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} with running_lock: - output = model.generate(**inputs, **data.dict(exclude={'history'})) + output = model.generate(**inputs, **data.dict(exclude={"history"})) output = output.cpu() - prompt_len = inputs['input_ids'].size(1) + prompt_len = inputs["input_ids"].size(1) response = output[0, prompt_len:] out_string = tokenizer.decode(response, skip_special_tokens=True) out_string = prompt_processor.postprocess_output(out_string) @@ -126,30 +127,40 @@ def generate_no_stream(data: GenerationTaskReq, request: Request): return out_string -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - 'pretrained', - help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.') - parser.add_argument('--quant', - choices=['8bit', '4bit'], - default=None, - help='Quantization mode. Default: None (no quantization, fp16).') + "pretrained", + help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.", + ) parser.add_argument( - '--gptq_checkpoint', + "--quant", + choices=["8bit", "4bit"], default=None, - help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.') - parser.add_argument('--gptq_group_size', - type=int, - default=128, - help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.') - parser.add_argument('--http_host', default='0.0.0.0') - parser.add_argument('--http_port', type=int, default=7070) - parser.add_argument('--profanity_file', default=None, help='Path to profanity words list. It should be a JSON file containing a list of words.') + help="Quantization mode. Default: None (no quantization, fp16).", + ) + parser.add_argument( + "--gptq_checkpoint", + default=None, + help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.", + ) + parser.add_argument( + "--gptq_group_size", + type=int, + default=128, + help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.", + ) + parser.add_argument("--http_host", default="0.0.0.0") + parser.add_argument("--http_port", type=int, default=7070) + parser.add_argument( + "--profanity_file", + default=None, + help="Path to profanity words list. It should be a JSON file containing a list of words.", + ) args = parser.parse_args() - if args.quant == '4bit': - assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.' + if args.quant == "4bit": + assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint." tokenizer = AutoTokenizer.from_pretrained(args.pretrained) @@ -159,18 +170,21 @@ if __name__ == '__main__': censored_words = [] prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words) - if args.quant == '4bit': - model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size) + if args.quant == "4bit": + with low_resource_init(): + config = LlamaConfig.from_pretrained(args.pretrained) + model = LlamaForCausalLM(config) + model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size) model.cuda() else: model = LlamaForCausalLM.from_pretrained( args.pretrained, - load_in_8bit=(args.quant == '8bit'), + load_in_8bit=(args.quant == "8bit"), torch_dtype=torch.float16, device_map="auto", ) - if args.quant != '8bit': - model.half() # seems to fix bugs for some users. + if args.quant != "8bit": + model.half() # seems to fix bugs for some users. model.eval() config = uvicorn.Config(app, host=args.http_host, port=args.http_port) diff --git a/applications/Chat/inference/tests/test_chat_prompt.py b/applications/Chat/inference/tests/test_chat_prompt.py index f5737ebe8c097d73073bb21195341b378e7fc2f1..9835e71894c66fdeffd0bb625259da936af4d776 100644 --- a/applications/Chat/inference/tests/test_chat_prompt.py +++ b/applications/Chat/inference/tests/test_chat_prompt.py @@ -3,44 +3,49 @@ import os from transformers import AutoTokenizer from utils import ChatPromptProcessor, Dialogue -CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' -tokenizer = AutoTokenizer.from_pretrained(os.environ['PRETRAINED_PATH']) +CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions." +tokenizer = AutoTokenizer.from_pretrained(os.environ["PRETRAINED_PATH"]) samples = [ - ([ - Dialogue( - instruction='Who is the best player in the history of NBA?', - response= - 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - Dialogue(instruction='continue this talk', response=''), - ], 128, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' + ( + [ + Dialogue( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + Dialogue(instruction="continue this talk", response=""), + ], + 128, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n", ), - ([ - Dialogue( - instruction='Who is the best player in the history of NBA?', - response= - 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - Dialogue(instruction='continue this talk', response=''), - ], 200, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' + ( + [ + Dialogue( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + Dialogue(instruction="continue this talk", response=""), + ], + 200, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n", ), - ([ - Dialogue( - instruction='Who is the best player in the history of NBA?', - response= - 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - Dialogue(instruction='continue this talk', response=''), - ], 211, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n' + ( + [ + Dialogue( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + Dialogue(instruction="continue this talk", response=""), + ], + 211, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n", ), - ([ - Dialogue(instruction='Who is the best player in the history of NBA?', response=''), - ], 128, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n' + ( + [ + Dialogue(instruction="Who is the best player in the history of NBA?", response=""), + ], + 128, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n", ), ] @@ -52,5 +57,5 @@ def test_chat_prompt_processor(): assert prompt == result -if __name__ == '__main__': +if __name__ == "__main__": test_chat_prompt_processor() diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py index 37944be70a3bf9631f0a995bd1753d71d6c8b5aa..af018adf6e9de84c846281c9aa26a5459da14bf1 100644 --- a/applications/Chat/inference/utils.py +++ b/applications/Chat/inference/utils.py @@ -1,9 +1,9 @@ +import json import re from threading import Lock from typing import Any, Callable, Generator, List, Optional -import json -import jieba +import jieba import torch import torch.distributed as dist import torch.nn as nn @@ -20,9 +20,9 @@ except ImportError: from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper -def prepare_logits_processor(top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None) -> LogitsProcessorList: +def prepare_logits_processor( + top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None +) -> LogitsProcessorList: processor_list = LogitsProcessorList() if temperature is not None and temperature != 1.0: processor_list.append(TemperatureLogitsWarper(temperature)) @@ -41,29 +41,30 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: return unfinished_sequences.max() == 0 -def sample_streamingly(model: nn.Module, - input_ids: torch.Tensor, - max_generate_tokens: int, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs) -> Generator: - +def sample_streamingly( + model: nn.Module, + input_ids: torch.Tensor, + max_generate_tokens: int, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs, +) -> Generator: logits_processor = prepare_logits_processor(top_k, top_p, temperature) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) for _ in range(max_generate_tokens): - model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else { - 'input_ids': input_ids - } + model_inputs = ( + prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids} + ) outputs = model(**model_inputs) - next_token_logits = outputs['logits'][:, -1, :] + next_token_logits = outputs["logits"][:, -1, :] # pre-process distribution next_token_logits = logits_processor(input_ids, next_token_logits) # sample @@ -107,27 +108,28 @@ def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict: if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) return model_kwargs class Dialogue(BaseModel): - instruction: str = Field(min_length=1, example='Count up from 1 to 500.') - response: str = Field(example='') + instruction: str = Field(min_length=1, example="Count up from 1 to 500.") + response: str = Field(example="") -def _format_dialogue(instruction: str, response: str = ''): - return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}' +def _format_dialogue(instruction: str, response: str = ""): + return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}" -STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S)) +STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S)) class ChatPromptProcessor: - SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.' + SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt." - def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]): + def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []): self.tokenizer = tokenizer self.context = context self.max_len = max_len @@ -138,42 +140,48 @@ class ChatPromptProcessor: def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str: if self.context_len is None: - self.context_len = len(self.tokenizer(self.context)['input_ids']) + self.context_len = len(self.tokenizer(self.context)["input_ids"]) if self.dialogue_placeholder_len is None: self.dialogue_placeholder_len = len( - self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids']) + self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"] + ) prompt = self.context # the last dialogue must be in the prompt last_dialogue = history.pop() # the response of the last dialogue is empty - assert last_dialogue.response == '' - if len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False) - ['input_ids']) + max_new_tokens + self.context_len >= self.max_len: + assert last_dialogue.response == "" + if ( + len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"]) + + max_new_tokens + + self.context_len + >= self.max_len + ): # to avoid truncate placeholder, apply truncate to the original instruction - instruction_truncated = self.tokenizer(last_dialogue.instruction, - add_special_tokens=False, - truncation=True, - max_length=(self.max_len - max_new_tokens - self.context_len - - self.dialogue_placeholder_len))['input_ids'] + instruction_truncated = self.tokenizer( + last_dialogue.instruction, + add_special_tokens=False, + truncation=True, + max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len), + )["input_ids"] instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip() prompt += _format_dialogue(instruction_truncated) return prompt - res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)['input_ids']) + res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"]) rows = [] for dialogue in history[::-1]: text = _format_dialogue(dialogue.instruction, dialogue.response) - cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids']) + cur_len = len(self.tokenizer(text, add_special_tokens=False)["input_ids"]) if res_len - cur_len < 0: break res_len -= cur_len rows.insert(0, text) - prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction) + prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction) return prompt def postprocess_output(self, output: str) -> str: - output = STOP_PAT.sub('', output) + output = STOP_PAT.sub("", output) return output.strip() def has_censored_words(self, text: str) -> bool: @@ -182,8 +190,8 @@ class ChatPromptProcessor: intersection = set(jieba.cut(text.lower())) & self.censored_words return len(intersection) > 0 -class LockedIterator: +class LockedIterator: def __init__(self, it, lock: Lock) -> None: self.lock = lock self.it = iter(it) @@ -195,6 +203,7 @@ class LockedIterator: with self.lock: return next(self.it) + def load_json(path: str): with open(path) as f: - return json.load(f) \ No newline at end of file + return json.load(f) diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt index e079f8a6038dd2dc8512967540f96ee0de172067..93d48bcb6f7928b9cd71944be03933c95ddb26a5 100644 --- a/applications/Chat/requirements-test.txt +++ b/applications/Chat/requirements-test.txt @@ -1 +1,2 @@ pytest +colossalai==0.3.3 diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt index af7ff67861eb73573489e9ae46f1a2d29eaa02b3..e56aaca0e7cb4344f52e9e1b43cc89110472b85c 100644 --- a/applications/Chat/requirements.txt +++ b/applications/Chat/requirements.txt @@ -2,7 +2,7 @@ transformers>=4.20.1 tqdm datasets loralib -colossalai>=0.2.4 +colossalai==0.3.3 torch<2.0.0, >=1.12.1 langchain tokenizers @@ -11,3 +11,4 @@ sse_starlette wandb sentencepiece gpustat +tensorboard diff --git a/applications/Chat/setup.py b/applications/Chat/setup.py index a285a6dff4bf9cfe6905494de83a0baacfd795cb..eb44b6203ef8177798830c0e51799163ee8146a8 100644 --- a/applications/Chat/setup.py +++ b/applications/Chat/setup.py @@ -2,40 +2,42 @@ from setuptools import find_packages, setup def fetch_requirements(path): - with open(path, 'r') as fd: + with open(path, "r") as fd: return [r.strip() for r in fd.readlines()] def fetch_readme(): - with open('README.md', encoding='utf-8') as f: + with open("README.md", encoding="utf-8") as f: return f.read() def fetch_version(): - with open('version.txt', 'r') as f: + with open("version.txt", "r") as f: return f.read().strip() setup( - name='coati', + name="coati", version=fetch_version(), - packages=find_packages(exclude=( - 'tests', - 'benchmarks', - '*.egg-info', - )), - description='Colossal-AI Talking Intelligence', + packages=find_packages( + exclude=( + "tests", + "benchmarks", + "*.egg-info", + ) + ), + description="Colossal-AI Talking Intelligence", long_description=fetch_readme(), - long_description_content_type='text/markdown', - license='Apache Software License 2.0', - url='https://github.com/hpcaitech/Coati', - install_requires=fetch_requirements('requirements.txt'), - python_requires='>=3.6', + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://github.com/hpcaitech/Coati", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.6", classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', - 'Environment :: GPU :: NVIDIA CUDA', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: System :: Distributed Computing', + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", ], ) diff --git a/applications/Chat/tests/test_benchmarks.sh b/applications/Chat/tests/test_benchmarks.sh new file mode 100755 index 0000000000000000000000000000000000000000..3fdb2518134231e2da1a8f4bbc4e5906b43122ef --- /dev/null +++ b/applications/Chat/tests/test_benchmarks.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +set -xue + +echo "Hint: You can run this script with 'verbose' as the first argument to run all strategies." + +if [[ $# -ne 0 && "$1" == "verbose" ]]; then + STRATEGIES=( + 'ddp' + 'colossalai_gemini' + 'colossalai_gemini_cpu' + 'colossalai_zero2' + 'colossalai_zero2_cpu' + 'colossalai_zero1' + 'colossalai_zero1_cpu' + ) +else + STRATEGIES=( + 'colossalai_zero2' + ) +fi + +BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) +BENCHMARKS_DIR=$BASE_DIR/benchmarks + +echo "[Test]: testing benchmarks ..." + +for strategy in ${STRATEGIES[@]}; do + torchrun --standalone --nproc_per_node 1 $BENCHMARKS_DIR/benchmark_opt_lora_dummy.py \ + --model 125m --critic_model 125m --strategy ${strategy} --lora_rank 4 \ + --num_episodes 2 --num_collect_steps 4 --num_update_steps 2 \ + --train_batch_size 2 --experience_batch_size 4 +done diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py index 4c05a343169905dc03d91b1f4c39530ec8834848..9c08aa36c9b40eba008ed6c3fb0d1bb6fdef0a4c 100644 --- a/applications/Chat/tests/test_checkpoint.py +++ b/applications/Chat/tests/test_checkpoint.py @@ -6,7 +6,8 @@ import pytest import torch import torch.distributed as dist from coati.models.gpt import GPTActor -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy +from coati.models.utils import calc_action_log_probs +from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy from transformers.models.gpt2.configuration_gpt2 import GPT2Config from colossalai.nn.optimizer import HybridAdam @@ -16,39 +17,37 @@ GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) def get_data(batch_size: int, seq_len: int = 10) -> dict: - input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda') + input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda") attention_mask = torch.ones_like(input_ids) return dict(input_ids=input_ids, attention_mask=attention_mask) -def run_test_checkpoint(strategy): - BATCH_SIZE = 2 +def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8): + data = get_data(batch_size) + action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool) + actor_logits = actor(data["input_ids"], data["attention_mask"])["logits"] + action_log_probs = calc_action_log_probs(actor_logits, data["input_ids"], action_mask.size(1)) + loss = action_log_probs.sum() + strategy.backward(loss, actor, actor_optim) + strategy.optimizer_step(actor_optim) - if strategy == 'ddp': + +def run_test_checkpoint(strategy_name: str, shard: bool): + if strategy_name == "ddp": strategy = DDPStrategy() - elif strategy == 'colossalai_gemini': - strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5) - elif strategy == 'colossalai_zero2': - strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif strategy_name == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5) + elif strategy_name == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: - raise ValueError(f'Unsupported strategy "{strategy}"') + raise ValueError(f"Unsupported strategy '{strategy_name}'") with strategy.model_init_context(): actor = GPTActor(config=GPT_CONFIG).cuda() - actor_optim = HybridAdam(actor.parameters()) - actor, actor_optim = strategy.prepare((actor, actor_optim)) - def run_step(): - data = get_data(BATCH_SIZE) - action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool) - action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask']) - loss = action_log_probs.sum() - strategy.backward(loss, actor, actor_optim) - strategy.optimizer_step(actor_optim) - - run_step() + train_step(strategy, actor, actor_optim) ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext() @@ -57,38 +56,36 @@ def run_test_checkpoint(strategy): dist.broadcast_object_list(rank0_dirname) rank0_dirname = rank0_dirname[0] - model_path = os.path.join(rank0_dirname, 'model.pt') - optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt') - - strategy.save_model(actor, model_path, only_rank0=True) - strategy.save_optimizer(actor_optim, optim_path, only_rank0=False) - + model_path = os.path.join(rank0_dirname, "model" if shard else f"model.pt") + strategy.save_model(actor, model_path) + optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt") + strategy.save_optimizer(actor_optim, optim_path) dist.barrier() strategy.load_model(actor, model_path, strict=False) strategy.load_optimizer(actor_optim, optim_path) - dist.barrier() - run_step() + train_step(strategy, actor, actor_optim) -def run_dist(rank, world_size, port, strategy): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = str(port) - run_test_checkpoint(strategy) +def run_dist(rank: int, world_size: int, port: int, strategy_name: str, shard: bool): + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + run_test_checkpoint(strategy_name, shard) @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini']) +@pytest.mark.parametrize("world_size", [4]) +@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"]) +@pytest.mark.parametrize("shard", [False, True]) @rerun_if_address_is_in_use() -def test_checkpoint(world_size, strategy): - spawn(run_dist, world_size, strategy=strategy) +def test_checkpoint(world_size: int, strategy_name: str, shard: bool): + spawn(run_dist, world_size, strategy_name=strategy_name, shard=shard) -if __name__ == '__main__': - test_checkpoint(2, 'colossalai_zero2') +if __name__ == "__main__": + test_checkpoint(2, "colossalai_gemini", shard=False) diff --git a/applications/Chat/tests/test_data.py b/applications/Chat/tests/test_data.py deleted file mode 100644 index 2e4d4ceac05fa603b98e4ad1c9b098e221345e83..0000000000000000000000000000000000000000 --- a/applications/Chat/tests/test_data.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -from copy import deepcopy - -import pytest -import torch -import torch.distributed as dist -from coati.experience_maker import NaiveExperienceMaker -from coati.models.base import RewardModel -from coati.models.gpt import GPTActor, GPTCritic -from coati.replay_buffer import NaiveReplayBuffer -from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy -from transformers.models.gpt2.configuration_gpt2 import GPT2Config - -from colossalai.testing import rerun_if_address_is_in_use, spawn - -GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) - - -def get_data(batch_size: int, seq_len: int = 10) -> dict: - input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda') - attention_mask = torch.ones_like(input_ids) - return dict(input_ids=input_ids, attention_mask=attention_mask) - - -def gather_and_equal(tensor: torch.Tensor) -> bool: - world_size = dist.get_world_size() - outputs = [torch.empty_like(tensor) for _ in range(world_size)] - dist.all_gather(outputs, tensor.contiguous()) - for t in outputs[1:]: - if not torch.equal(outputs[0], t): - return False - return True - - -def run_test_data(strategy): - EXPERINCE_BATCH_SIZE = 4 - SAMPLE_BATCH_SIZE = 2 - - if strategy == 'ddp': - strategy = DDPStrategy() - elif strategy == 'colossalai': - strategy = ColossalAIStrategy(placement_policy='cuda') - else: - raise ValueError(f'Unsupported strategy "{strategy}"') - - actor = GPTActor(config=GPT_CONFIG).cuda() - critic = GPTCritic(config=GPT_CONFIG).cuda() - - initial_model = deepcopy(actor) - reward_model = RewardModel(deepcopy(critic.model)).cuda() - - experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model) - replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False) - - # experience of all ranks should be the same - for _ in range(2): - data = get_data(EXPERINCE_BATCH_SIZE) - assert gather_and_equal(data['input_ids']) - assert gather_and_equal(data['attention_mask']) - experience = experience_maker.make_experience(**data, - do_sample=True, - max_length=16, - eos_token_id=50256, - pad_token_id=50256) - assert gather_and_equal(experience.sequences) - assert gather_and_equal(experience.action_log_probs) - assert gather_and_equal(experience.values) - assert gather_and_equal(experience.reward) - assert gather_and_equal(experience.advantages) - assert gather_and_equal(experience.action_mask) - assert gather_and_equal(experience.attention_mask) - replay_buffer.append(experience) - - # replay buffer's data should be the same - buffer_size = torch.tensor([len(replay_buffer)], device='cuda') - assert gather_and_equal(buffer_size) - for item in replay_buffer.items: - assert gather_and_equal(item.sequences) - assert gather_and_equal(item.action_log_probs) - assert gather_and_equal(item.values) - assert gather_and_equal(item.reward) - assert gather_and_equal(item.advantages) - assert gather_and_equal(item.action_mask) - assert gather_and_equal(item.attention_mask) - - # dataloader of each rank should have the same size and different batch - dataloader = strategy.setup_dataloader(replay_buffer) - dataloader_size = torch.tensor([len(dataloader)], device='cuda') - assert gather_and_equal(dataloader_size) - for experience in dataloader: - assert not gather_and_equal(experience.sequences) - assert not gather_and_equal(experience.action_log_probs) - assert not gather_and_equal(experience.values) - assert not gather_and_equal(experience.reward) - assert not gather_and_equal(experience.advantages) - # action mask and attention mask may be same - - -def run_dist(rank, world_size, port, strategy): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = str(port) - run_test_data(strategy) - - -@pytest.mark.skip -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@pytest.mark.parametrize('strategy', ['ddp', 'colossalai']) -@rerun_if_address_is_in_use() -def test_data(world_size, strategy): - spawn(run_dist, world_size, strategy=strategy) - - -if __name__ == '__main__': - test_data(2, 'colossalai') diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ec61bbb13fd7aa511208904c8bbef0c228af3d25 --- /dev/null +++ b/applications/Chat/tests/test_dataset.py @@ -0,0 +1,241 @@ +import json +import os +import tempfile +from typing import Optional + +import pytest +import torch +from coati.dataset.prompt_dataset import PromptDataset +from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset +from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer +from datasets import load_dataset +from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +SFT_DATASET = [ + { + "instruction": "Provide a list of the top 10 most popular mobile games in Asia", + "input": "", + "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", + "id": 0, + }, + { + "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level", + "input": "", + "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", + "id": 1, + }, + { + "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise", + "input": "", + "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", + "id": 2, + }, +] + +PROMPT_DATASET = [ + { + "instruction": 'Edit this paragraph to make it more concise: "Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends."', + "id": 0, + }, + {"instruction": "Write a descriptive paragraph about a memorable vacation you went on", "id": 1}, + {"instruction": "Write a persuasive essay arguing why homework should be banned in schools", "id": 2}, + {"instruction": "Create a chart comparing the statistics on student debt in the United States.", "id": 3}, +] + + +def make_tokenizer(model: str): + if model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + elif model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") + tokenizer.pad_token = tokenizer.eos_token + elif model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + tokenizer.pad_token = tokenizer.eos_token + elif model == "llama": + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + tokenizer.pad_token = tokenizer.unk_token + elif model == "chatglm": + tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) + else: + raise ValueError(f"Unsupported model '{model}'") + return tokenizer + + +def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str): + if model == "opt": + # NOTE: Contrary to GPT2, OPT adds the EOS token to the beginning of every prompt. + assert input_ids_stripped[0] == tokenizer.eos_token_id + input_ids_stripped = input_ids_stripped[1:] + elif model == "llama": + assert input_ids_stripped[0] == tokenizer.bos_token_id + input_ids_stripped = input_ids_stripped[1:] + elif model == "chatglm": + assert input_ids_stripped[0] == tokenizer.bos_token_id + assert input_ids_stripped[-1] == tokenizer.eos_token_id + input_ids_stripped = input_ids_stripped[1:-1] + assert torch.all(input_ids_stripped != tokenizer.pad_token_id) + assert torch.all(input_ids_stripped != tokenizer.bos_token_id) + assert torch.all(input_ids_stripped != tokenizer.eos_token_id) + assert input_ids_stripped != tokenizer.sep_token_id + assert input_ids_stripped != tokenizer.cls_token_id + if model == "chatglm": + assert torch.all(input_ids_stripped != tokenizer.mask_token_id) + else: + assert input_ids_stripped != tokenizer.mask_token_id + + +@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) +@pytest.mark.parametrize("max_length", [32, 1024]) +@pytest.mark.parametrize("max_datasets_size", [2]) +def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int): + with tempfile.TemporaryDirectory() as tmp_dir: + dataset_name = "prompt_dataset.json" + with open(os.path.join(tmp_dir, dataset_name), "w") as f: + json.dump(PROMPT_DATASET, f) + tokenizer = make_tokenizer(model) + assert tokenizer.padding_side in ("left", "right") + prompt_dataset = PromptDataset( + data_path=os.path.join(tmp_dir, dataset_name), + tokenizer=tokenizer, + max_datasets_size=max_datasets_size, + max_length=max_length, + ) + assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET)) + for i in range(len(prompt_dataset)): + assert isinstance(prompt_dataset[i], dict) + assert list(prompt_dataset[i].keys()) == ["input_ids", "attention_mask"] + input_ids = prompt_dataset[i]["input_ids"] + attention_mask = prompt_dataset[i]["attention_mask"] + attention_mask = attention_mask.bool() + assert input_ids.shape == attention_mask.shape == torch.Size([max_length]) + assert torch.all(input_ids[torch.logical_not(attention_mask)] == tokenizer.pad_token_id) + check_content(input_ids.masked_select(attention_mask), tokenizer, model) + + +@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) +@pytest.mark.parametrize( + ["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), ("Dahoas/rm-static", None)] +) +@pytest.mark.parametrize("max_datasets_size", [32]) +@pytest.mark.parametrize("max_length", [32, 1024]) +def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int): + data = load_dataset(dataset_path, data_dir=subset) + assert max_datasets_size <= len(data["train"]) and max_datasets_size <= len(data["test"]) + train_data = data["train"].select(range(max_datasets_size)) + test_data = data["test"].select(range(max_datasets_size)) + tokenizer = make_tokenizer(model) + assert tokenizer.padding_side in ("left", "right") + + if dataset_path == "Anthropic/hh-rlhf": + train_dataset = HhRlhfDataset(train_data, tokenizer, max_length) + test_dataset = HhRlhfDataset(test_data, tokenizer, max_length) + elif dataset_path == "Dahoas/rm-static": + train_dataset = RmStaticDataset(train_data, tokenizer, max_length) + test_dataset = RmStaticDataset(test_data, tokenizer, max_length) + else: + raise ValueError(f'Unsupported dataset "{dataset_path}"') + + assert len(train_dataset) == len(test_dataset) == max_datasets_size + for i in range(max_datasets_size): + chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i] + assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length]) + c_mask = c_mask.to(torch.bool) + r_mask = r_mask.to(torch.bool) + if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id: + check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model) + assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id) + else: + check_content(chosen_ids.masked_select(c_mask), tokenizer, model) + assert torch.all(c_mask) + if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id: + check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model) + assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id) + else: + check_content(reject_ids.masked_select(r_mask), tokenizer, model) + assert torch.all(r_mask) + + chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i] + assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length]) + c_mask = c_mask.to(torch.bool) + r_mask = r_mask.to(torch.bool) + if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id: + check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model) + assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id) + else: + check_content(chosen_ids.masked_select(c_mask), tokenizer, model) + assert torch.all(c_mask) + if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id: + check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model) + assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id) + else: + check_content(reject_ids.masked_select(r_mask), tokenizer, model) + assert torch.all(r_mask) + + +@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"]) +@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None]) +@pytest.mark.parametrize("max_dataset_size", [2]) +@pytest.mark.parametrize("max_length", [32, 1024]) +def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int): + tokenizer = make_tokenizer(model) + if dataset_path == "yizhongw/self_instruct": + data = load_dataset(dataset_path, "super_natural_instructions") + train_data = data["train"].select(range(max_dataset_size)) + sft_dataset = SFTDataset(train_data, tokenizer, max_length) + else: + with tempfile.TemporaryDirectory() as tmp_dir: + dataset_name = "sft_dataset.json" + with open(os.path.join(tmp_dir, dataset_name), "w") as f: + json.dump(SFT_DATASET, f) + sft_dataset = SupervisedDataset( + tokenizer=tokenizer, + data_path=os.path.join(tmp_dir, dataset_name), + max_datasets_size=max_dataset_size, + max_length=max_length, + ) + assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET)) + + if isinstance(tokenizer, ChatGLMTokenizer): + for i in range(max_dataset_size): + assert isinstance(sft_dataset[i], dict) + assert list(sft_dataset[i].keys()) == ["input_ids", "labels"] + input_ids = sft_dataset[i]["input_ids"] + labels = sft_dataset[i]["labels"] + assert input_ids.shape == labels.shape == torch.Size([max_length]) + + ignore_mask = labels == IGNORE_INDEX + assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id + check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model) + return + + for i in range(max_dataset_size): + assert isinstance(sft_dataset[i], dict) + assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"] + input_ids = sft_dataset[i]["input_ids"] + labels = sft_dataset[i]["labels"] + attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool) + assert input_ids.shape == labels.shape == attention_mask.shape == torch.Size([max_length]) + if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id: + check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model) + assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id) + else: + check_content(input_ids.masked_select(attention_mask), tokenizer, model) + assert torch.all(attention_mask) + ignore_mask = labels == IGNORE_INDEX + prompt_mask = torch.logical_and(ignore_mask, attention_mask) + check_content(input_ids.masked_select(prompt_mask), tokenizer, model) + assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == tokenizer.pad_token_id) + + +if __name__ == "__main__": + test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256) + + test_reward_dataset( + model="gpt2", dataset_path="Anthropic/hh-rlhf", subset="harmless-base", max_datasets_size=8, max_length=256 + ) + + test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128) diff --git a/applications/Chat/tests/test_experience.py b/applications/Chat/tests/test_experience.py new file mode 100644 index 0000000000000000000000000000000000000000..a9591259800d94e5758cbf31104ab2b114aa5a19 --- /dev/null +++ b/applications/Chat/tests/test_experience.py @@ -0,0 +1,130 @@ +import copy +import os + +import pytest +import torch +import torch.distributed as dist +from coati.experience_buffer import NaiveExperienceBuffer +from coati.experience_maker import NaiveExperienceMaker +from coati.models.base import RewardModel +from coati.models.gpt import GPTActor, GPTCritic +from coati.trainer.ppo import _set_default_generate_kwargs +from coati.trainer.strategies import DDPStrategy, GeminiStrategy +from coati.trainer.strategies.colossalai import LowLevelZeroStrategy +from transformers.models.gpt2.configuration_gpt2 import GPT2Config + +from colossalai.testing import rerun_if_address_is_in_use, spawn + +GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) + + +def get_data(batch_size: int, seq_len: int = 10) -> dict: + input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda") + attention_mask = torch.ones_like(input_ids) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def gather_and_equal(tensor: torch.Tensor) -> bool: + world_size = dist.get_world_size() + outputs = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(outputs, tensor.contiguous()) + for t in outputs[1:]: + if not torch.equal(outputs[0], t): + return False + return True + + +def make_and_consume_experience(strategy): + EXPERIENCE_BATCH_SIZE = 4 + SAMPLE_BATCH_SIZE = 2 + + if strategy == "ddp": + strategy = DDPStrategy() + elif strategy == "colossalai-zero2": + strategy = LowLevelZeroStrategy() + elif strategy == "colossalai-gemini": + strategy = GeminiStrategy(placement_policy="static") + else: + raise ValueError(f'Unsupported strategy "{strategy}"') + + with strategy.model_init_context(): + actor = GPTActor(config=GPT_CONFIG).cuda() + critic = GPTCritic(config=GPT_CONFIG).cuda() + + initial_model = GPTActor(config=GPT_CONFIG).cuda() + reward_model = RewardModel(model=copy.deepcopy(critic.model)).cuda() + + actor, critic, initial_model, reward_model = strategy.prepare(actor, critic, initial_model, reward_model) + + class MockTokenizer: + def __init__(self): + self.padding_side = "left" + self.eos_token_id = 0 + self.pad_token_id = 0 + + tokenizer = MockTokenizer() + experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer) + data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False) + + generate_kwargs = dict(do_sample=True, max_length=16) + generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) + + # experience of all ranks should be the same + for _ in range(2): + data = get_data(EXPERIENCE_BATCH_SIZE) + assert gather_and_equal(data["input_ids"]) + assert gather_and_equal(data["attention_mask"]) + experience = experience_maker.make_experience(**data, do_sample=True, max_length=16) + assert gather_and_equal(experience.sequences) + assert gather_and_equal(experience.action_log_probs) + assert gather_and_equal(experience.values) + assert gather_and_equal(experience.reward) + assert gather_and_equal(experience.advantages) + assert gather_and_equal(experience.action_mask) + assert gather_and_equal(experience.attention_mask) + data_buffer.append(experience) + + # data buffer's data should be the same + buffer_size = torch.tensor([len(data_buffer)], device="cuda") + assert gather_and_equal(buffer_size) + for item in data_buffer.items: + assert gather_and_equal(item.sequences) + assert gather_and_equal(item.action_log_probs) + assert gather_and_equal(item.values) + assert gather_and_equal(item.reward) + assert gather_and_equal(item.advantages) + assert gather_and_equal(item.action_mask) + assert gather_and_equal(item.attention_mask) + + # dataloader of each rank should have the same size and different batch + dataloader = strategy.setup_dataloader(data_buffer) + dataloader_size = torch.tensor([len(dataloader)], device="cuda") + assert gather_and_equal(dataloader_size) + for experience in dataloader: + assert not gather_and_equal(experience.sequences) + assert not gather_and_equal(experience.action_log_probs) + assert not gather_and_equal(experience.values) + assert not gather_and_equal(experience.reward) + assert not gather_and_equal(experience.advantages) + # action mask and attention mask may be same + + +def run_dist(rank, world_size, port, strategy): + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + make_and_consume_experience(strategy) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"]) +@rerun_if_address_is_in_use() +def test_experience(world_size, strategy): + spawn(run_dist, world_size, strategy=strategy) + + +if __name__ == "__main__": + test_experience(2, "colossalai-zero2") diff --git a/applications/Chat/tests/test_inference.sh b/applications/Chat/tests/test_inference.sh new file mode 100755 index 0000000000000000000000000000000000000000..849db06e58abdf9893a544d70ce9312bc5129b7a --- /dev/null +++ b/applications/Chat/tests/test_inference.sh @@ -0,0 +1,11 @@ +set -xue + +BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) +EXAMPLES_DIR=$BASE_DIR/examples + +echo "[Test]: testing inference ..." + +# HACK: skip llama due to oom +for model in 'gpt2' 'bloom' 'opt'; do + python $EXAMPLES_DIR/inference.py --model $model +done diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c22ac6a3b986f5d05f0dc1c54088a021c6d838 --- /dev/null +++ b/applications/Chat/tests/test_models.py @@ -0,0 +1,245 @@ +import copy +from typing import Any, Callable, Dict, Tuple + +import pytest +import torch +import torch.nn as nn +from coati.models.base import Actor, Critic, RewardModel, get_base_model +from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic +from coati.models.chatglm import ChatGLMActor +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer +from coati.models.generation import generate +from coati.models.gpt import GPTRM, GPTActor, GPTCritic +from coati.models.llama import LlamaActor +from coati.models.lora import LoraLinear, convert_to_lora_module +from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss +from coati.models.opt import OPTRM, OPTActor, OPTCritic +from coati.models.utils import calc_action_log_probs, masked_mean + + +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seq_len", [32]) +@pytest.mark.parametrize( + "actor_maker", + [ + lambda: BLOOMActor(), + lambda: GPTActor(), + # HACK: skip llama due to long execution time + # lambda: LlamaActor(), + lambda: OPTActor(), + ], +) +@pytest.mark.parametrize( + "generate_kwargs", + [ + { + "max_length": 64, + "use_cache": True, + "do_sample": True, + "temperature": 1.0, + "top_k": 50, + } + ], +) +def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]): + class MockTokenizer: + def __init__(self): + self.padding_side = "left" + self.eos_token_id = 0 + self.pad_token_id = 0 + + actor = actor_maker() + input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() + tokenizer = MockTokenizer() + sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs) + assert sequences.shape == (batch_size, generate_kwargs["max_length"]) + + +def test_utils(): + fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))} + fn_output = masked_mean(dim=0, **fn_input) + assert fn_output.dim() == 0 + assert torch.allclose(fn_output, torch.tensor(1.0)) + + batch_size = 4 + seq_len = 32 + num_labels = 10 + num_actions = 2 + fn_input = { + "logits": torch.randn((batch_size, seq_len, num_labels)), + "sequences": torch.randint(0, num_labels, (batch_size, seq_len)), + "num_actions": num_actions, + } + fn_output = calc_action_log_probs(**fn_input) + assert fn_output.shape == (batch_size, num_actions) + + +@pytest.mark.parametrize("lora_rank", [4]) +@pytest.mark.parametrize("num_dim", [32]) +@pytest.mark.parametrize("num_layers", [4]) +def test_lora(lora_rank: int, num_dim: int, num_layers: int): + model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)]) + lora_model = convert_to_lora_module(model, lora_rank) + assert isinstance(lora_model, nn.ModuleList) + for i in range(num_layers): + assert isinstance(lora_model[i], LoraLinear) + assert lora_model[i].lora_A.shape == (lora_rank, num_dim) + assert lora_model[i].lora_B.shape == (num_dim, lora_rank) + + old_model = copy.deepcopy(lora_model) + for i in range(num_layers): + assert isinstance(lora_model[i], LoraLinear) + assert torch.allclose(old_model[i].weight, lora_model[i].weight) + assert torch.allclose(old_model[i].bias, lora_model[i].bias) + assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A) + optimizer = torch.optim.Adam(lora_model.parameters()) + x = torch.randn(8, num_dim) + for i in range(num_layers): + x = lora_model[i](x) + loss = x.sum() + loss.backward() + optimizer.step() + for i in range(num_layers): + assert isinstance(lora_model[i], LoraLinear) + assert torch.allclose(old_model[i].weight, lora_model[i].weight) + assert torch.allclose(old_model[i].bias, lora_model[i].bias) + assert not torch.allclose( + old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A + ) + + +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seq_len", [128]) +@pytest.mark.parametrize( + "models_maker", + [ + lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), + lambda: (GPTActor(), GPTCritic(), GPTRM()), + # HACK: skip llama due to long execution time + # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()), + lambda: (OPTActor(), OPTCritic(), OPTRM()), + lambda: (ChatGLMActor(), None, None), + ], +) +@torch.no_grad() +def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int): + actor_input = { + "input_ids": torch.randint(0, 100, (batch_size, seq_len)), + "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), + } + critic_input = { + "sequences": torch.randint(0, 100, (batch_size, seq_len)), + "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), + } + rm_input = { + "sequences": torch.randint(0, 100, (batch_size, seq_len)), + "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), + } + + actor, critic, rm = models_maker() + if isinstance(actor, ChatGLMActor): + actor = actor.float() + tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) + chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1) + actor_input = { + "input_ids": torch.cat( + ( + torch.randint(0, 100, (batch_size, seq_len // 2)), + chatglm_special_token, + torch.randint(0, 100, (batch_size, seq_len // 2 - 2)), + ), + dim=1, + ), + "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)), + } + assert isinstance(actor, Actor) + get_base_model(actor) + actor_output = actor(**actor_input) + assert actor_output.logits.shape[:2] == (batch_size, seq_len) + + if critic: + assert isinstance(critic, Critic) + get_base_model(critic) + critic_output = critic(**critic_input) + assert critic_output.shape == (batch_size,) + + if rm: + assert isinstance(rm, RewardModel) + get_base_model(rm) + rm_output = rm(**rm_input) + assert rm_output.shape == (batch_size,) + + +@pytest.mark.parametrize("batch_size", [16]) +@pytest.mark.parametrize("seq_len", [128]) +@pytest.mark.parametrize("num_labels", [100]) +def test_loss(batch_size: int, seq_len: int, num_labels: int): + loss = GPTLMLoss() + loss_input = { + "logits": torch.randn(batch_size, seq_len, num_labels), + "labels": torch.randint(0, num_labels, (batch_size, seq_len)), + } + loss(**loss_input) + + loss = PolicyLoss() + loss_input = { + "log_probs": torch.randn( + batch_size, + ), + "old_log_probs": torch.randn( + batch_size, + ), + "advantages": torch.randn( + batch_size, + ), + } + loss(**loss_input) + + loss = ValueLoss() + loss_input = { + "values": torch.randn( + batch_size, + ), + "old_values": torch.randn( + batch_size, + ), + "reward": torch.randn( + batch_size, + ), + } + loss(**loss_input) + + loss = LogSigLoss() + loss_input = { + "chosen_reward": torch.randn( + batch_size, + ), + "reject_reward": torch.randn( + batch_size, + ), + } + loss(**loss_input) + + loss = LogExpLoss() + loss_input = { + "chosen_reward": torch.randn( + batch_size, + ), + "reject_reward": torch.randn( + batch_size, + ), + } + loss(**loss_input) + + +if __name__ == "__main__": + generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50) + test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs) + + test_utils() + + test_lora(lora_rank=2, num_dim=8, num_layers=2) + + test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128) + + test_loss(batch_size=8, seq_len=128, num_labels=100) diff --git a/applications/Chat/tests/test_train.sh b/applications/Chat/tests/test_train.sh new file mode 100755 index 0000000000000000000000000000000000000000..68fca7fbf8c0434cd0883fd0810fc6be74d3fe38 --- /dev/null +++ b/applications/Chat/tests/test_train.sh @@ -0,0 +1,233 @@ +#!/usr/bin/env bash + +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 4 + +set -xu + +if [ -z "$SFT_DATASET" ]; then + echo "Please set \$SFT_DATASET to the path to sft dataset." + exit 1 +fi + +if [ -z "$PROMPT_DATASET" ]; then + echo "Please set \$PROMPT_DATASET to the path to prompts csv." + exit 1 +fi + +if [ -z "$PRETRAIN_DATASET" ]; then + echo "Please set \$PRETRAIN_DATASET to the path to alpaca data." + exit 1 +fi + +NUM_RETRY=3 +BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) +EXAMPLES_DIR=$BASE_DIR/examples +MODELS_DIR=$BASE_DIR/examples/models_config +MODELS=('gpt2' 'bloom' 'opt' 'llama') +STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2') + + +export OMP_NUM_THREADS=8 + +# install requirements +pip install -r $EXAMPLES_DIR/requirements.txt + +python $EXAMPLES_DIR/download_model.py --model-dir $MODELS_DIR --config-only + +get_pretrain() { + local model=$1 + if [[ $model == "gpt2" ]]; then + echo "gpt2" + elif [[ $model == "bloom" ]]; then + echo "bigscience/bloom-560m" + elif [[ $model == "opt" ]]; then + echo "facebook/opt-350m" + else + echo "Unknown model $model" + exit 1 + fi +} + +random_choice() { + local arr=("$@") + local len=${#arr[@]} + local idx=$((RANDOM % len)) + echo ${arr[$idx]} +} + +echo "[Test]: testing sft ..." + +# FIXME: This is a hack to skip tests that are not working +# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation +# - llama-*: These tests can be passed locally, skipped for long execution time +# - *-gemini: Gemini plugin does not support `from_pretrained` yet +SKIPPED_TESTS=( + "gpt2-ddp" + "llama-ddp" + "llama-colossalai_gemini" + "llama-colossalai_zero2" +) + +GRAD_CKPTS=('' '--grad_checkpoint') +for lora_rank in '0'; do + for model in ${MODELS[@]}; do + strategies=($(shuf -e "${STRATEGIES[@]}")) + for strategy in ${strategies[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$strategy-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then + echo "[Test]: Skipped $model-$strategy" + continue + fi + pretrain=$(get_pretrain $model) + pretrain_model="" + if [[ $lora_rank -gt 0 ]]; then + pretrain_model="--pretrain $pretrain" + fi + grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$strategy-$lora_rank, attempt $i" + torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_sft.py \ + $pretrain_model --tokenizer $MODELS_DIR/$model \ + --model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \ + --dataset $SFT_DATASET --max_datasets_size 8 \ + --max_epochs 1 --batch_size 1 --accumulation_steps 1 --lr 1e-8 \ + --save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} + passed=$? + if [ $passed -eq 0 ]; then + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$strategy-$lora_rank" + exit 1 + fi + done + done +done + +echo "[Test]: testing reward model ..." + +# FIXME: This is a hack to skip tests that are not working +# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation +# - llama-*: These tests can be passed locally, skipped for long execution time +# - *-gemini: Gemini plugin does not support `from_pretrained` yet +SKIPPED_TESTS=( + "gpt2-ddp" + "llama-ddp" + "llama-colossalai_gemini" + "llama-colossalai_zero2" +) + +LOSS_FNS=('log_sig' 'log_exp') +DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static') +for lora_rank in '0'; do + for model in ${MODELS[@]}; do + strategies=($(shuf -e "${STRATEGIES[@]}")) + for strategy in ${strategies[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$strategy-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then + echo "[Test]: Skipped $model-$strategy" + continue + fi + pretrain=$(get_pretrain $model) + pretrain_model="" + if [[ $lora_rank -gt 0 ]]; then + pretrain_model="--pretrain $pretrain" + fi + loss_fn=$(random_choice "${LOSS_FNS[@]}") + dataset=$(random_choice "${DATASETS[@]}") + subset=$(if [[ $dataset == "Dahoas/rm-static" ]]; then echo "None"; else echo "harmless-base"; fi) + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$strategy-$lora_rank, attempt $i" + torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \ + $pretrain_model --tokenizer $MODELS_DIR/$model \ + --dataset $dataset --subset $subset --max_datasets_size 8 \ + --model $model --strategy $strategy --lora_rank $lora_rank \ + --loss_fn $loss_fn --batch_size 1 --lr 1e-8 \ + --save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt + passed=$? + if [ $passed -eq 0 ]; then + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed to train reward model $model-$strategy-$lora_rank" + exit 1 + fi + done + done +done + +echo "[Test]: testing RLHF ..." + +# FIXME: This is a hack to skip tests that are not working +# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation +# - llama-*: These tests can be passed locally, skipped for long execution time +# - *-gemini: Gemini plugin does not support `from_pretrained` yet +SKIPPED_TESTS=( + "gpt2-ddp" + "llama-ddp" + "llama-colossalai_gemini" + "llama-colossalai_zero2" +) + +for model in ${MODELS[@]}; do + for lora_rank in '0'; do + strategies=($(shuf -e "${STRATEGIES[@]}")) + for strategy in ${strategies[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$strategy-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then + echo "[Test]: Skipped $model-$strategy" + continue + fi + rm_pretrain=$(get_pretrain $model) + rm_pretrain_model="" + if [[ $lora_rank -gt 0 ]]; then + rm_pretrain_model="--rm_pretrain $rm_pretrain" + fi + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$strategy-$lora_rank, attempt $i" + torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \ + --prompt_dataset $PROMPT_DATASET --pretrain_dataset $PRETRAIN_DATASET --max_datasets_size 32 \ + --strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \ + --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 --lr 1e-8 \ + --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \ + --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \ + $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \ + --save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts + passed=$? + if [ $passed -eq 0 ]; then + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed to train RLHF $model-$strategy-$lora_rank" + exit 1 + fi + done + rm -rf $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} + rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt + done +done +rm -rf $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..34967c04360c7a436b754d76d76e12fb4d0e9237 --- /dev/null +++ b/applications/Colossal-LLaMA-2/README.md @@ -0,0 +1,390 @@ +
        +

        + +

        +
        + +## Table of Contents +- [News](#news) +- [Colossal-LLaMA-2-7B](#colossal-llama-2-7b) + - [Performance Evaluation](#performance-evaluation) + - [Examples](#examples) + - [Training Logs](#training-logs) + - [Import from Transformers](#import-from-transformers) +- [Usage](#usage) + - [Install](#install) + - [How to run](#how-to-run) +- [Technical Insight](#technical-insights) + - [Data](#data) + - [Tokenizer](#tokenizer) + - [Training Strategy](#training-strategy) + - [Bridging Any Domain-specific Large Models](#bridging-any-domain-specific-large-models) +- [Citations](#citations) + +## News +* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) +[[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) +[[model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) + +## Colossal-LLaMA-2-7B +The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team has introduced the open-source model **Colossal-LLaMA-2-7B-base**. This model, a derivation of LLaMA-2, has undergone continual pre-training involving approximately 8.5 billion tokens over a duration of 15 hours with 64 A800 GPUs. At a cost of **less than $1,000**, you can achieve results **similar to those that cost millions of dollars to pretrain from scratch**. It is licensed under the LLaMA-2 license and [Apache 2.0 License](https://github.com/hpcaitech/ColossalAI/blob/main/LICENSE) **without any additional commercial use restrictions**. This solution can also be used to build models of specific domain knowledge or tasks. + +Colossal-LLaMA-2-7B-base is designed to accommodate both the Chinese and English languages, featuring an expansive context window spanning 4096 tokens. Remarkably, it has exhibited exceptional performance when benchmarked against models of equivalent scale in standard Chinese and English evaluation metrics, including C-Eval and MMLU, among others. + +❗️**Important notice**: +* All training data used for this project is collected from well-known public dataset. +* We do not use any testing data from the evaluation benchmarks for training. + +### Performance Evaluation +We conducted comprehensive evaluation on 4 dataset and compare our Colossal-Llama-2-7b-base model with various models. + +* We use 5-shot for MMLU and calculate scores based on the logits of first predicted token. +* We use 5-shot for CMMLU and calculate scores based on the logits of first predicted token. +* We use 5-shot for AGIEval and only calculate scores for 4-choice questions using a combination metric of exact match and the logits of first predicted token. If any of the exact match or logits of first predicted token is correct, the model will get the score. +* We use 0-shot for GAOKAO-Bench and only calculate scores for 4-choice questions based on the logits of first predicted token. +The generation config for all dataset is greedy search. +* We also provided CEval scores from its lastest leaderboard or the official repository of the model. + +| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval | +| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: | +| | | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot | +| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | +| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | +| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | +| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | +| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | +| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | +| InternLM-7B | - | 1.6T | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| Qwen-7B (original) | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| | | | | | | | | | +| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | +| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - | +| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - | +| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | +| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - | +| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - | +| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - | +| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - | +| | | | | | | | | | +| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | + +> The score in parentheses corresponds to the scores in the official repository of the model. +> +> We use zero-shot for ChatGLM models. +> +> Qwen-7B is now inaccessible in Hugging Face, we are using the latest version of it before it was made inaccessible. Only for dataset MMLU, the prompt would be "xxx Answer:"(remove the space after ":") and we calculate the logits over " A", " B", " C" and " D" for Qwen-7B. Qwen-7B tends to be much more deterministic than other models. For example, the logits over " A" can be `-inf` and softmax would be exact `0`. +> +> For other models and other dataset, we calculate logits over "A", "B", "C" and "D". + +❗️ More details of the evaluation methods and reproduction of the results, please refer to [ColossalEval](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalEval). + +### Examples +| Question Type | Question |
        Colossal-LLaMA-2-7b-base
        | +| :------: | :-----------------------: | :----------- | +| Knowledge Quiz
        (Tourism) | 请问武汉的最佳的旅行季节是什么时候? | 3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉的最佳的旅行季节是春秋,3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉,简称“汉”别称江城,是湖北省省会,中部六省唯一的副省级市,特大城市,中国中部地区的中心城市,全国重要的工业基地、科教基地和综合交通枢纽。 | +| Knowledge Quiz
        (Computer Science) | 在计算机中,什么是RAM?与ROM有什么区别? | RAM是随机存取内存,可以随时读写,但是断电之后会丢失数据。存储数据比较快,但是需要不断的进行通电供电。ROM是只读内存,只能读取数据,不能写入数据。断电之后,存储的数据也不会丢失,但是需要不断进行通电。两者的价格和性能上会有区别,而且在使用的时候,RAM需要不断的供电,而ROM断电数据不丢失。 | +| Knowledge Quiz
        (Literature) | 古代文学与现代文学有哪些区别和共通之处? | 古代文学与现代文学有三点区别,现代文学更加通俗易懂,具有可读性,而古代文学相对来说更加注重形式,缺乏可读性。现代文学和古代文学,同属于中华文学大花园中的一枝花蕾。但是它们由于历史的原因,有着不同的特点,具体表现在,现代文学有更丰富的描写方法,更强烈的主观色彩和更突出的抒情色彩,而古代文学的主要特征是在创作风格上的多样化。 | +| Emotional Analysis | 请判断这个句子的情绪是否正面:我迫不及待想见到她!!!”。 | 这一句是情绪是正向的。是在期待见面的喜悦情绪。 | +| Translation | 将以下句子翻译成英语。 我喜欢看电影和读书。 | I like watching movies and reading books. | +| Logical Reasoning | All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books? | To find the weight of 2 books, we can use the following steps:

        1. Determine the weight of 1 book.
        2. Multiply the weight of 1 book by the number of books.

        Step 1: Determine the weight of 1 book.
        Since 10 books weigh 5kg, we can divide 5kg by 10 to find the weight of 1 book.
        5kg ÷ 10 = 0.5kg

        Step 2: Multiply the weight of 1 book by the number of books.
        To find the weight of 2 books, we multiply the weight of 1 book by 2.
        0.5kg × 2 = 1kg

        So, the weight of 2 books is 1kg. | +| Information Extraction | The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence. | The date mentioned in the sentence is October 20th to October 23rd, 2021. | +| Error Correction | Please correct the syntax error and restructure the sentence: "Went to the store me and my friend." | The given sentence "Went to the store me and my friend." is incorrect due to a syntax error and it lacks a subject. The corrected version with proper grammar and sentence structure would be: "My friend and I went to the store." | + +❗️ More examples of question answering, please refer to [Colossal-LLaMA-2-7B-base Examples](docs/example.md). + +### Training Logs +We also recorded the training logs for the experiment + +

        + +

        + +

        + +

        + +### Import from Transformers (Inference) +To load Colossal-LLaMA-2-7B-base model using Transformers, use the following code: +```Python +from transformers import AutoModelForCausalLM, AutoTokenizer +model = AutoModelForCausalLM.from_pretrained("hpcai-tech/Colossal-LLaMA-2-7b-base", device_map="auto", trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/Colossal-LLaMA-2-7b-base", trust_remote_code=True) +input = "离离原上草," +inputs = tokenizer(input, return_tensors='pt') +inputs = inputs.to('cuda:0') +pred = model.generate(**inputs, + max_new_tokens=256, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1) +print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(input):]) +``` + +You can also download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base). + +## Usage +### Install + +#### 0. Pre-requisite +1. This experiment was performed on 8 computing nodes with 64 A800 GPUs in total for LLaMA-2-7B (**about 1000 USD cost**). The nodes are connected with RDMA and GPUs within one node are fully connected with NVLink. The script was tested with CUDA 11.7, CUDA version requires 11.7 or higher. You can also complete it in about 5 days on a 8*A100/A800 server. + +2. PyTorch. The PyTorch version should be less than 2.0.0 and greater than 1.12.1. + + +#### 1. Install required packages +``` +cd Colossal-LLaMA-2 +pip install -r requirements.txt +``` +#### 2. Install `xentropy`, `layer_norm` and `rotary` +```bash +git clone git@github.com:Dao-AILab/flash-attention.git +# At the root folder +cd csrc/xentropy && pip install . +# At the root folder +cd csrc/layer_norm && pip install . +# At the root folder +cd csrc/rotary && pip install . +``` + +### How to run + +#### 1. Init Tokenizer Preparation +Initialize new tokenizer with additional Chinese tokens. Additional Chinese tokens are stored in `jsonl` format as follows: +```json +{"piece": "你好"} +{"piece": "人工智能"} +``` +Command to initialize new tokenizer: +```bash +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python' +python colossal_llama2/tokenizer/init_tokenizer.py \ + --source_tokenizer_dir "" \ + --target_tokenizer_dir "" \ + --expand_tokens_file ".jsonl" +``` +Here is details about CLI arguments: +* Source tokenizer directory: `--source_tokenizer_dir`. Directory to the source tokenizer. It should at least contain three files: `special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`. +* Target tokenizer directory: `--target_tokenizer_dir`. Directory to the target tokenizer. +* Tokens to be added: `--expand_tokens_file`. Additional tokens to be added to the tokenizer. + +#### 2. Init Model Preparation +Initialize the new model checkpoint by calculating the mean values from the original model checkpoint. +Command to initialize new model checkpoint: +```bash +python colossal_llama2/model/init_model.py \ + --source_model_and_tokenizer_path "" \ + --target_tokenizer_path "" \ + --target_model_path "" +``` +"" can be the same as "". + +Here is details about CLI arguments: +* Source model and tokenizer path: `--source_model_and_tokenizer_path`. Source folder contains both model and tokenizer, for example, LLaMA-2 model in Hugging Face format. +* Target tokenizer path: `--target_tokenizer_path`. Path to the new tokenizer folder generated from previous step. +* Target model path: `--target_model_path`. Path to save the new model in Hugging Face format. + +❗️**Important**: Once you initialize the new model checkpoint, copy your new tokenizer files (`special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`) to your new model folder. + +#### 3. Data Preparation +Raw data should be formatted as `jsonl` format. Each data point should have the following fields: +* `source` (str, compulsory): This part is ignored when calculating loss. Default can be empty. +* `target` (str, compulsory): Loss will be calculated. +* `category` (str, compulsory): Tags for each data point. + +Examples: +```JSON +{"source": "", "target": "Lionel Andrés Messi(Spanish pronunciation: [ljoˈnel anˈdɾes ˈmesi] (i); born 24 June 1987), also known as Leo Messi, is an Argentine professional footballer who plays as a forward for and captains both Major League Soccer club Inter Miami and the Argentina national team.", "category": "sports"} +{"source": "猜谜语:一身卷卷细毛,吃的青青野草,过了数九寒冬,无私献出白毛。(打一动物)", "target": "白羊", "category": "riddle"} +``` +You are allowed to customize the category tags or use `unknown` to define the category. + +Command to convert jsonl dataset to arrow format: +``` +python prepare_pretrain_dataset.py \ + --data_input_dirs ",," \ + --tokenizer_dir "" \ + --data_cache_dir "jsonl_to_arrow_cache" \ + --data_jsonl_output_dir "spliced_tokenized_output_jsonl" \ + --data_arrow_output_dir "spliced_tokenized_output_arrow" \ + --max_length 4096 \ + --num_spliced_dataset_bins 10 +``` +Here is details about CLI arguments: +* Source data directory: `data_input_dirs`. Each `` can have multiple file in `jsonl` format. +* Tokenzier directory: `tokenizer_dir`. Path to the tokenizer in Hugging Face format. +* Data cache directory: `data_cache_dir`. Directory to store Hugging Face data cache. Default case will create `cache` folder locally. +* Output directory for jsonl format: `data_jsonl_output_dir`. Output directory to store converted dataset in jsonl format. +* Output directory for arrow format: `data_arrow_output_dir`. Output directory to store converted dataset in arrow format, which can be used for training directly. +* Max length: `max_length`. Max length of spliced samples. Default value is 4096. +* Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training. + +#### 4. Command Line Arguments for Training +You can use `colossalai run` to launch multi-nodes training: +```bash +colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ +train.py --OTHER_CONFIGURATIONS +``` +Here is a sample hostfile: +```bash +hostname1 +hostname2 +hostname3 +hostname4 +``` +Make sure master node can access all nodes (including itself) by ssh without password. + +Here is details about CLI arguments: +* Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format. +* Dataset path: `--dataset`. Path to the pre-tokenized dataset. +* Booster plugin: `--plugin`. `gemini`, `gemini_auto`, `zero2`,`zero2_cpu` and `3d` are supported.For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/). +* Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. Saved checkpoint contains the states for `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded without any other states to support multi-stage training. +* Save interval: `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000. +* Checkpoint directory: `--save_dir`. The directoty path to save checkpoint and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`,`running_states.json` and `modelling`. +* Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs. +* Configuration file: `--config_file`. The path to save the configuration file. +* Number of epochs: `--num_epochs`. Number of training epochs. The default value is 1. +* Micro batch size: `--micro_batch_size`. Batch size per GPU. The default value is 1. +* Learning rate: `--lr`. The default value is 3e-4. +* Max length: `--max_length`. Max context length. The default value is 4096. +* Mixed precision: `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported. +* Gradient clipping: `--gradient_clipping`. The default value is 1.0. +* Weight decay: `-w`, `--weight_decay`. The default value is 0.1. +* Warmup steps: `-s`, `--warmup_steps`. The default value is calcuated by 0.025 warmup ratio. +* Gradient checkpointing: `--use_grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size. +* Flash attention: `--use_flash_attn`. If you want to use flash attention, you must install `flash-attn` and related packages. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention. +* Freeze non-embedding parameters: `--freeze_non_embeds_params`. Freeze non-embedding parameters. It can be helpful to align embeddings after extending vocabulary size. +* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1. +* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1. + +#### 5. Running Command +An [example bash](train.example.sh) is also provided for the experiment. Here is the steps to run the experiment: +* Create your own hostfile: `cp hostfile.example hostfile`. +* Create your own bash: `cp train.example.sh train.sh`. +* Add your real host ip or host name into the `hostfile`. +* Update global variables and parameters in your `train.sh`. +* Run the experiment by `bash train.sh` + +Here is the details about global variables for each experiment: +* `PROJECT_NAME`: Project name for each experiment. +* `PARENT_SAVE_DIR`: Parent folder to save model checkpoint. +* `PARENT_TENSORBOARD_DIR`: Parent folder to save tensorboard logs. +* `PARENT_CONFIG_FILE`: Parent folder to save configuration for each experiment. +* `PRETRAINED_MODEL_PATH`: Path to the local pre-trained model checkpoint. +* `dataset`: Paths to all prepared data. Typically, it's a list of subfolders within the output path of prepare data, `--data_arrow_output_dir`, and if there are multiple subfolders, please list them all. e.g., +```python +declare -a dataset=( + "/part-00000" + "/part-00001" + "/part-00000" +) +``` +## Technical Insights +In order to enhance LLaMA-2's capabilities for understanding and generating Chinese content, The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team proposes the continuation of pre-training the LLaMA-2 model using both Chinese and English corpora. The overall pipeline can be described as follows: + +

        + +

        + +### Data +Large language models such as LLaMA-2 have undergone training using a heterogeneous blend of high-quality datasets, yielding promising outcomes. Enhancing LLaMA-2's performance for the Chinese corpus, while preserving its proficiency in English, critically hinges on two pivotal factors: the composition of the dataset, which encompasses both English and Chinese content, and the quality of each constituent dataset. + +The following figure shows the data processing pipeline conducted for Colossal-LLaMA-2. +

        + +

        + +❗️**Important**: We will open-source our data-processing toolkit soon, stay tuned! + +### Tokenizer +The original LLaMA-2 vacabulary comprises fewer than a thousand Chinese characters, thus proves inadequate for encoding comprehensive Chinese texts effectively. Secondly, the utilization of byte tokens presents a challenge for transformer encoders to capture the semantic nuances of Chinese characters. + +To address the above issues, we extend LLaMA-2 vocabulary from 32,000 to 69,104. To adapt the LLaMA-2 model for use with the Colossal-LLaMA-2 tokenizer, we initialize the new word embeddings by calculating the mean values from the original LLaMA-2 embeddings and subsequently append these new rows to the end of the original embedding matrices. + +Advantages of extending vocabulary size: +* Improve the compression rate of string sequence encoding. +* Enhance the integrity of information. +* Enable encoded sequences to contain more valuable information, thereby theoretically enhancing the ability for chapter-level encoding. + +Advantages of large vocabulary size under low-resource settings: +* The presence of numerous unused tokens can be attributed to the limited training dataset, where an excessive number of tokens might not have been effectively learned. +* Excessive vocabulary expansion leads to an increase in embedding-related parameters, resulting in higher memory usage, which, in turn, affects the efficiency of the training process. + +To balance both sides, we finally construct our vocabulary with size 69,104. The following table below presents a comparison of various models at the 7B level. + +| Model | Vocabulary Size | Compression Rate | Average Length of Samples (token-level) | +| :-----------: | :---------: | :----: | :----: | +| Colossal-LLaMA-2 | 69104 | 0.659 | 73.682 | +| LLaMA-2-7B | 32000 | 1.205 | 134.689 | +| Atom-7B | 65000 | 0.634 | 70.915 | +| Baichuan-7B | 64000 | 0.678 | 75.857 | +| Baichuan2-7B-base | 125696 | 0.570 | 63.761 | +| Chatglm2-6B | 64789 | 0.645 | 72.178 | +| InternLM-7B | 103168 | 0.566 | 63.349 | +| Qwen-7B | 151643 | 0.578 | 64.703 | +| Tigerbot-7B-base | 60515 | 0.630 | 70.515 | +| Yayi-7B-llama2 | 32005 | 1.214 | 135.689 | +| Chinese-llama-2-7b | 55296 | 0.668 | 74.690 | +| Chinese-Falcon-7B | 90046 | 0.669 | 74.858 | +| LinkSoul-Chinese-Llama-2-7b | 40076 | 0.958 | 107.089 | +| Ziya-LLaMA-13B-v1.1 | 39410 | 0.958 | 107.074 | + + +### Training Strategy +#### Multi-stage Training +In order to enhance the model's performance and harness the full potential of the original LLaMA-2, we have developed a multi-stage training strategy. This strategy is designed to systematically unlock the model's capabilities over a series of stages. + +Therefore, we have divided the training process into three stages: +* Large-scale pre-training stage (Conducted by LLaMA-2): This initial stage is aimed at establishing the model's foundational capabilities from the ground up. It necessitates the use of a substantial dataset comprising no less than 1 trillion tokens. +* Chinese knowledge injection stage: In this stage, we introduce Chinese knowledge into the model. It requires access to a high-quality dataset rich in comprehensive knowledge relevant to the Chinese language. +* Knowledge replay stage: Knowledge is replayed through a question-answering (QA) mechanism, encompassing both the Chinese and English domains. + +Following the completion of this multi-stage training process, the model exhibits notable improvements in performance across both English and Chinese benchmarks. + +The following figure illustrates the three stages for training Colossal-LLaMA-2. + +

        + +

        + +#### Bucket-based Training +Our experiments have revealed that the distributions within the training dataset, as well as the arrangement of various topic-related data points, significantly impact the overall performance of the model, particularly in the context of continual pre-training of LLaMA-2. + +In an effort to achieve a more balanced distribution and exert control over the dataset's ordering, we have adopted a method where we divide each sub-dataset into discrete bins. These bins are then combined to construct individual data buckets, with one bin contributed by each sub-dataset. + +### Bridging Any Domain-specific Large Models +Applying the above process to perform knowledge transfer in any field allows for the cost-effective construction of lightweight domain-specific foundational large models. + +

        + +

        + +## Citations +```bibtex +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +``` +```bibtex +@misc{touvron2023llama, + title={Llama 2: Open Foundation and Fine-Tuned Chat Models}, + author={Hugo Touvron and Louis Martin and Kevin Stone and Peter Albert and Amjad Almahairi and Yasmine Babaei and Nikolay Bashlykov and Soumya Batra and Prajjwal Bhargava and Shruti Bhosale and Dan Bikel and Lukas Blecher and Cristian Canton Ferrer and Moya Chen and Guillem Cucurull and David Esiobu and Jude Fernandes and Jeremy Fu and Wenyin Fu and Brian Fuller and Cynthia Gao and Vedanuj Goswami and Naman Goyal and Anthony Hartshorn and Saghar Hosseini and Rui Hou and Hakan Inan and Marcin Kardas and Viktor Kerkez and Madian Khabsa and Isabel Kloumann and Artem Korenev and Punit Singh Koura and Marie-Anne Lachaux and Thibaut Lavril and Jenya Lee and Diana Liskovich and Yinghai Lu and Yuning Mao and Xavier Martinet and Todor Mihaylov and Pushkar Mishra and Igor Molybog and Yixin Nie and Andrew Poulton and Jeremy Reizenstein and Rashi Rungta and Kalyan Saladi and Alan Schelten and Ruan Silva and Eric Michael Smith and Ranjan Subramanian and Xiaoqing Ellen Tan and Binh Tang and Ross Taylor and Adina Williams and Jian Xiang Kuan and Puxin Xu and Zheng Yan and Iliyan Zarov and Yuchen Zhang and Angela Fan and Melanie Kambadur and Sharan Narang and Aurelien Rodriguez and Robert Stojnic and Sergey Edunov and Thomas Scialom}, + year={2023}, + eprint={2307.09288}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` +```bibtex +@article{dao2023flashattention2, + title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, + author={Dao, Tri}, + year={2023} +} +} +``` diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56fafa58b3f43decb7699b93048b8b87e0f695aa --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56fafa58b3f43decb7699b93048b8b87e0f695aa --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..a2cfb2ef6264b52801449ea72ccac0e1d1701bb5 --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import numpy as np +import os +import random +from dataclasses import dataclass +from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable + +import torch +from datasets import dataset_dict, load_from_disk +from datasets import Dataset as HFDataset +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group +from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler +from transformers.tokenization_utils import PreTrainedTokenizer +import torch.nn.functional as F + +DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] +PathType = Union[str, os.PathLike] + + +def load_tokenized_dataset( + dataset_paths: Union[PathType, List[PathType]], mode: str = "train" +) -> Optional[DatasetType]: + """ + Load pre-tokenized dataset. + Each instance of dataset is a dictionary with + `{'input_ids': List[int], 'labels': List[int], sequence: str}` format. + """ + mode_map = {"train": "train", "dev": "validation", "test": "test"} + assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}" + + if isinstance(dataset_paths, (str, os.PathLike)): + dataset_paths = [dataset_paths] + + datasets = [] # `List[datasets.dataset_dict.Dataset]` + for ds_path in dataset_paths: + ds_path = os.path.abspath(ds_path) + assert os.path.exists(ds_path), f"Not existed file path {ds_path}" + ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False) + if isinstance(ds_dict, HFDataset): + datasets.append(ds_dict) + else: + if mode_map[mode] in ds_dict: + datasets.append(ds_dict[mode_map[mode]]) + if len(datasets) == 0: + return None + if len(datasets) == 1: + return datasets.pop() + return ConcatDataset(datasets=datasets) + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """ + Collate instances for supervised dataset. + Each instance is a tokenized dictionary with fields + `input_ids`(List[int]), `labels`(List[int]) and `sequence`(str). + """ + + tokenizer: PreTrainedTokenizer + max_length: int = 4096 + ignore_index: int = -100 + + def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: + """ + + Args: + instances (`Sequence[Dict[str, List[int]]]`): + Mini-batch samples, each sample is stored in an individual dictionary. + + Returns: + (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`: + `input_ids`: `torch.Tensor` of shape (bsz, max_len); + `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); + `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`. + """ + assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, ( + f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, " + f"but now `{self.tokenizer.pad_token_id}`" + ) + + # `List[torch.Tensor]` + batch_input_ids = [ + torch.LongTensor(instance["input_ids"][: self.max_length]) + if len(instance["input_ids"]) > self.max_length + else torch.LongTensor(instance["input_ids"]) + for instance in instances + ] + batch_labels = [ + torch.LongTensor(instance["labels"][: self.max_length]) + if len(instance["labels"]) > self.max_length + else torch.LongTensor(instance["labels"]) + for instance in instances + ] + + if self.tokenizer.padding_side == "right": + input_ids = torch.nn.utils.rnn.pad_sequence( + sequences=batch_input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) # (bsz, max_len) + labels = torch.nn.utils.rnn.pad_sequence( + sequences=batch_labels, + batch_first=True, + padding_value=self.ignore_index, + ) # (bsz, max_len) + # pad to max + to_pad = self.max_length - input_ids.size(1) + input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) + labels = F.pad(labels, (0, to_pad), value=self.ignore_index) + elif self.tokenizer.padding_side == "left": + reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids] + reversed_input_ids = torch.nn.utils.rnn.pad_sequence( + sequences=reversed_input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) # (bsz, max_len) + input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len) + reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels] + reversed_labels = torch.nn.utils.rnn.pad_sequence( + sequences=reversed_labels, + batch_first=True, + padding_value=self.ignore_index, + ) # (bsz, max_len) + labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len) + else: + raise RuntimeError( + f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, " + f"but now `{self.tokenizer.padding_side}`" + ) + + attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len) + + return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + +class StatefulDistributedSampler(DistributedSampler): + """ + Stateful distributed sampler for multi-stage training. + """ + + def __init__( + self, + dataset: DatasetType, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + ) + self.start_index = 0 + + def __iter__(self) -> Iterator: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index :] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def set_start_index(self, start_index: int) -> None: + self.start_index = start_index + + +def setup_distributed_dataloader( + dataset: DatasetType, + batch_size: int = 1, + shuffle: bool = False, + seed: int = 1024, + drop_last: bool = False, + pin_memory: bool = False, + num_workers: int = 0, + collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None, + process_group: Optional[ProcessGroup] = None, + **kwargs, +) -> DataLoader: + """ + Setup dataloader for distributed training. + """ + _kwargs = kwargs.copy() + process_group = process_group or _get_default_group() + sampler = StatefulDistributedSampler( + dataset=dataset, + num_replicas=process_group.size(), + rank=process_group.rank(), + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + ) + + # Deterministic dataloader + def seed_worker(worker_id: int) -> None: + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=drop_last, + worker_init_fn=seed_worker, + **_kwargs, + ) diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0c21f325ae6227a9df42da1aec6f1354e72da84b --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Splicing multiple pre-tokenized sequence data points +""" + +import random +import warnings +from copy import deepcopy +from datasets import dataset_dict +from typing import Any, Callable, Dict, Iterable, List, Union, Tuple + +from torch.utils.data import ConcatDataset, Dataset, IterableDataset +from transformers.models.llama.tokenization_llama import LlamaTokenizer +from transformers.tokenization_utils import PreTrainedTokenizer + +IGNORE_INDEX = -100 + +DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] + + +def supervised_tokenize( + data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096 +) -> Dict[str, Union[int, str, List[int]]]: + """ + A tokenization function to tokenize an original pretraining data point as following: + {"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"} + """ + assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, ( + "Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, " + "add and manually later" + ) + if ignore_index is None: + ignore_index = IGNORE_INDEX + + source_text = data_point["source"] # `str` + target_text = data_point["target"] # `str` + is_null_source = len(source_text) == 0 + + source_text = tokenizer.bos_token + source_text + target_text += tokenizer.eos_token + sequence_text = source_text + target_text + + tokenized = tokenizer([source_text, sequence_text])["input_ids"] + sequence_input_ids = tokenized[1] + sequence_labels = deepcopy(sequence_input_ids) + + source_length = len(tokenized[0]) + if not is_null_source: + sequence_labels[:source_length] = [ignore_index for _ in range(source_length)] + + # sequence truncation. + if len(sequence_input_ids) > max_length: + sequence_input_ids = sequence_input_ids[:max_length] + sequence_labels = sequence_labels[:max_length] + + return dict( + input_ids=sequence_input_ids, + labels=sequence_labels, + seq_length=len(sequence_input_ids), + seq_category=data_point["category"], + ) + + +class ClosedToConstantLengthSplicedDataset(IterableDataset): + """ + Define an iterable dataset that returns a (close to) constant length data point spliced from multiple + original independent (pre-tokenized) data points. + """ + + def __init__( + self, + dataset: DSType, + tokenizer: PreTrainedTokenizer, + max_length: int = 4096, + num_packed_sequences: int = 8, + fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None, + input_ids_field: str = "input_ids", + labels_field: str = "labels", + infinite: bool = False, + shuffle: bool = True, + error_strict: bool = False, + ) -> None: + self.tokenizer = tokenizer + self.dataset = dataset + self.max_length = max_length + self.infinite = infinite + self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16 + self.shuffle = shuffle + + # Callable[[Dict[str, Any]], Tuple[List[int], List[int]]], + # A function that fetch sequence input_ids and labels from the original data point + if fetch_sequence_func is None: + self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field]) + else: + self.fetch_sequence_func = fetch_sequence_func + self.input_ids_field = input_ids_field + self.labels_field = labels_field + + self.error_strict = error_strict + self.current_size = 0 # `int`, current packed data size. + + def __len__(self) -> int: + return len(self.dataset) + + def __iter__(self) -> Iterable[Dict[str, List[int]]]: + iterator = iter(self.dataset) + more_data_points = True + while more_data_points is True: + buffer, buffer_len = [], 0 + while True: + # ending condition. + if buffer_len >= self.max_buffer_size: + break + try: + # `Tuple[List[int], List[int]]` + seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator)) + buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels}) + buffer_len += len(buffer[-1][self.input_ids_field]) + except StopIteration: + if self.infinite is True: + iterator = iter(self.dataset) + warnings.warn("The dataset reached end and the iterator is reset to the start.") + else: + more_data_points = False + break + examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points. + spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]` + for i, data_point in enumerate(buffer): + # TODO(2023-09-18) check errors for each unspliced tokenized data point + seq_input_ids = data_point[self.input_ids_field] + seq_labels = data_point[self.labels_field] + # Handle special case: + # If the length of an original data point (i.e., input_ids length of a data point before splicing) + # exceeds `max_length`, truncate it. + if len(seq_input_ids) > self.max_length: + truncated_seq_input_ids = seq_input_ids[: self.max_length] + truncated_label_ids = seq_labels[: self.max_length] + if set(truncated_label_ids) == {IGNORE_INDEX}: + if self.error_strict is True: + raise ValueError( + f"Find an out-of-bounds length({len(seq_input_ids)}) data point " + f"with all label values as {IGNORE_INDEX}." + ) + else: + warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})") + continue # Skip the current error data point. + spliced_data_point = { + self.input_ids_field: truncated_seq_input_ids, + self.labels_field: truncated_label_ids, + } + examples.append(spliced_data_point) + warnings.warn("Find a data point to be truncated.") + continue + + # Pre action judgment. + if len(spliced_input_ids) + len(seq_input_ids) > self.max_length: + spliced_data_point = { + self.input_ids_field: spliced_input_ids, + self.labels_field: spliced_labels, + } # `Dict[str, List[int]]` + # Update. + spliced_input_ids, spliced_labels = [], [] + spliced_input_ids.extend(seq_input_ids) + spliced_labels.extend(seq_labels) + examples.append(spliced_data_point) + else: + spliced_input_ids.extend(seq_input_ids) + spliced_labels.extend(seq_labels) + # For residual spliced data point at the end of the data set + if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0: + examples.append( + { + self.input_ids_field: spliced_input_ids, + self.labels_field: spliced_labels + } + ) + if self.shuffle: + random.shuffle(examples) + for spliced_data_point in examples: + # TODO(2023-09-18): check errors for each spliced tokenized data point. + self.current_size += 1 + yield spliced_data_point diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py new file mode 100644 index 0000000000000000000000000000000000000000..67e487f43b082f0cac6a3466f711c79c132ec80d --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Initialize new model with updated tokenizer by calculating the mean values from original model +""" +import argparse + +import numpy as np +import torch +from transformers import LlamaTokenizer, LlamaForCausalLM + +from colossalai.logging import get_dist_logger + + +logger = get_dist_logger() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--source_model_and_tokenizer_path", + type=str, + required=True, + default=None, + help="Source path of model & tokenizer", + ) + parser.add_argument("--target_tokenizer_path", type=str, required=True, default=None, help="Target tokenizer path") + parser.add_argument("--target_model_path", type=str, required=True, default=None, help="Target model path") + args = parser.parse_args() + + source_tokenizer = LlamaTokenizer.from_pretrained(args.source_model_and_tokenizer_path) + source_tokenizer.add_bos_token = False + source_tokenizer.add_eos_token = False + if source_tokenizer.pad_token is None: + source_tokenizer.pad_token = source_tokenizer.unk_token + source_vocab = source_tokenizer.get_vocab() + + target_tokenizer = LlamaTokenizer.from_pretrained(args.target_tokenizer_path) + target_tokenizer.add_bos_token = False + target_tokenizer.add_eos_token = False + if target_tokenizer.pad_token is None: + target_tokenizer.pad_token = target_tokenizer.unk_token + target_vocab = target_tokenizer.get_vocab() + target_inverted_vocab = {v: k for k, v in target_vocab.items()} + + assert len(target_vocab) > len( + source_vocab + ), f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})" + + gpu_device = torch.device("cuda:0") + cpu_device = torch.device("cpu") + + source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path) + source_model.eval() + source_model = source_model.to(gpu_device) + + source_input_embeddings = source_model.get_input_embeddings() + assert isinstance(source_input_embeddings, torch.nn.Embedding) + assert source_input_embeddings.weight.shape[0] == len(source_vocab) + source_input_embeddings.eval() + + source_output_embeddings = source_model.get_output_embeddings() + assert isinstance(source_output_embeddings, torch.nn.Linear) + assert source_output_embeddings.bias is None + assert source_output_embeddings.weight.shape[0] == len(source_vocab) + source_output_embeddings.eval() + + input_embeddings = source_input_embeddings.weight.cpu().detach().numpy() + output_embeddings = source_output_embeddings.weight.cpu().detach().numpy() + for i in range(len(source_vocab), len(target_vocab)): + if i % 500 == 0: + logger.info(f"processing {i}/{len(target_vocab)} target tokens") + target_token = target_inverted_vocab[i] + target_to_source_token_ids = torch.LongTensor(source_tokenizer([target_token])["input_ids"][0]) + target_to_source_token_ids = target_to_source_token_ids.to(gpu_device) + + target_to_source_input_embedding = ( + source_input_embeddings.weight[target_to_source_token_ids] + .mean(dim=0) + .unsqueeze(dim=0) + .cpu() + .detach() + .numpy() + ) + target_to_source_output_embedding = ( + source_output_embeddings.weight[target_to_source_token_ids] + .mean(dim=0) + .unsqueeze(dim=0) + .cpu() + .detach() + .numpy() + ) + + input_embeddings = np.concatenate((input_embeddings, target_to_source_input_embedding), axis=0) + output_embeddings = np.concatenate((output_embeddings, target_to_source_output_embedding), axis=0) + + source_model = source_model.to(cpu_device) + assert isinstance(source_model, LlamaForCausalLM) + + # expand + source_model.resize_token_embeddings(new_num_tokens=len(target_vocab)) + source_model.model.embed_tokens.weight.data = torch.Tensor(input_embeddings) + source_model.lm_head.weight.data = torch.Tensor(output_embeddings) + + source_model = source_model.half() + source_model.save_pretrained(save_directory=args.target_model_path) + + +if __name__ == "__main__": + main() diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..43297633db1a5b35142ddf04f0e0eb4b8a9d3ccb --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +""" +Initialize new tokenizer for continual pre-training +""" + +import argparse +import os +import json +from typing import List, Union + +from transformers.models.llama.tokenization_llama import LlamaTokenizer +from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model + +from colossalai.logging import get_dist_logger + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + +logger = get_dist_logger() + + +def expand_vocab_tokenizer( + source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str] +) -> None: + """Expand tokenizer for continue pre-training.""" + if os.path.exists(target_tokenizer_dir): + raise RuntimeError(f"Find existed directory {target_tokenizer_dir}") + + source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir) + logger.info(source_tokenizer) + source_sp_processor = source_tokenizer.sp_model + source_spm = sp_pb2_model.ModelProto() + source_spm.ParseFromString(source_sp_processor.serialized_model_proto()) + + logger.info(f"Source tokenizer size: {len(source_sp_processor)}") + + # Add new tokens to source tokenizer. + source_spm_tokens = set([p.piece for p in source_spm.pieces]) + for piece in new_tokens: + assert isinstance(piece, str), f"Invalid token({piece}) type {type(piece)}" + if piece in source_spm_tokens: + # Skip existed token. + continue + new_p = sp_pb2_model.ModelProto().SentencePiece() + new_p.piece = piece + new_p.score = 0 + source_spm.pieces.append(new_p) + logger.info(f"Expand vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}") + + # Save + os.makedirs(target_tokenizer_dir) + target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model") + with open(file=target_tokenizer_model_path, mode="wb") as fp: + fp.write(source_spm.SerializeToString()) + + target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path) + target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir) + logger.info(f"Successfully save expand tokenizer to {target_tokenizer_dir}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory" + ) + parser.add_argument( + "--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory" + ) + parser.add_argument( + "--expand_tokens_file", + type=str, + required=True, + default=None, + help="Path of the file containing tokens to be extended", + ) + args = parser.parse_args() + + expand_tokens = [] + with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader: + for line in fp_reader: + item = json.loads(line) + # e.g., {"piece": "你好"} + token = item["piece"] + if token in expand_tokens: + continue + expand_tokens.append(token) + expand_tokens.sort(key=lambda t: len(t), reverse=False) + + expand_vocab_tokenizer( + source_tokenizer_dir=args.source_tokenizer_dir, + target_tokenizer_dir=args.target_tokenizer_dir, + new_tokens=expand_tokens, + ) + + +if __name__ == "__main__": + main() diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56fafa58b3f43decb7699b93048b8b87e0f695aa --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py new file mode 100644 index 0000000000000000000000000000000000000000..85decf37dd0b084197d9eeffcec77edb900365e0 --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Helper functions for IO +""" + +import json +import os +from typing import Any, Dict, Tuple, Union + +import torch +from torch.optim.optimizer import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator + + +def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]: + """ + Load file in JSON format + """ + with open(file=file_path, mode="r", encoding="utf-8") as fp: + return json.load(fp) + + +def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None: + """ + Save as JSON format + """ + with open(file=file_path, mode="w", encoding="utf-8") as fp: + json.dump(data, fp=fp, ensure_ascii=False, indent=4) + + +def save_checkpoint( + save_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, +) -> None: + """ + Save model checkpoint, optimizer, LR scheduler and intermedidate running states. + """ + + save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}") + os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True) + + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + running_states = { + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + + +def load_checkpoint( + load_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, +) -> Tuple[int, int, int]: + """ + Load model checkpoint, optimizer, LR scheduler and intermedidate running states. + """ + + # Update booster params states. + booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling")) + booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler")) + + running_states = load_json(file_path=os.path.join(load_dir, "running_states.json")) + return ( + running_states["epoch"], + running_states["step"], + running_states["sample_start_index"], + ) diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..6c58c59307a6a6b2d1523e67633548edeb0e354e --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from types import MethodType +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from transformers.models.llama.modeling_llama import ( + LlamaRMSNorm, + LlamaAttention, + LlamaModel, + LlamaForCausalLM, + apply_rotary_pos_emb, + repeat_kv, +) + +from colossalai.logging import get_dist_logger +from einops import rearrange + +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.flash_attn_interface import ( + flash_attn_func, + flash_attn_varlen_kvpacked_func, +) +from flash_attn.ops.rms_norm import rms_norm + + +logger = get_dist_logger() + + +def _prepare_decoder_attention_mask( + self: LlamaModel, + attention_mask: torch.BoolTensor, + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, +) -> Optional[torch.Tensor]: + """ + Decoder attetion mask + """ + if past_key_values_length > 0 and attention_mask is not None: + attention_mask = torch.cat( + tensors=( + torch.full( + size=(input_shape[0], past_key_values_length), + fill_value=True, + dtype=attention_mask.dtype, + device=attention_mask.device, + ), + attention_mask, + ), + dim=-1, + ) # (bsz, past_key_values_length + q_len) + if attention_mask is not None and torch.all(attention_mask): + return None # Faster + return attention_mask + + +def attention_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. + """ + if output_attentions: + logger.warning( + "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, " + "return `None` instead." + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + q_slicing, kv_slicing = ( + dim // self.config.pretraining_tp + for dim in ( + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ) + ) # `Tuple[int, int]` + q_slices, k_slices, v_slices = ( + proj.weight.split(slicing, dim=0) + for proj, slicing in ( + (self.q_proj, q_slicing), + (self.k_proj, kv_slicing), + (self.v_proj, kv_slicing), + ) + ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]] + q, k, v = ( + torch.cat( + [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], + dim=-1, + ) + for slices in (q_slices, k_slices, v_slices) + ) + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + else: + q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)) + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + + # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim) + q, k, v = ( + states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2) + for states, num_heads in ( + (q, self.num_heads), + (k, self.num_key_value_heads), + (v, self.num_key_value_heads), + ) + ) + kv_len = k.shape[-2] # initially, `kv_len` == `q_len` + past_kv_len = 0 + if past_key_value is not None: + # if `past_key_value` is not None, `kv_len` > `q_len`. + past_kv_len = past_key_value[0].shape[-2] + kv_len += past_kv_len + + # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim) + cos, sin = self.rotary_emb(v, seq_len=kv_len) + # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim) + q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids) + if past_key_value is not None: + # reuse k, v, self_attention + k = torch.cat([past_key_value[0], k], dim=2) + v = torch.cat([past_key_value[1], v], dim=2) + + past_key_value = (k, v) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + + key_padding_mask = attention_mask + # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) + q, k, v = (states.transpose(1, 2) for states in (q, k, v)) + + if past_kv_len > 0: + q = torch.cat( + tensors=( + torch.full( + size=(bsz, past_kv_len, self.num_heads, self.head_dim), + fill_value=0.0, + dtype=q.dtype, + device=q.device, + ), + q, + ), + dim=1, + ) # (bsz, past_kv_len + q_len, num_heads, head_dim) + + if key_padding_mask is None: + # (bsz, past_kv_len + q_len, num_heads, head_dim) + output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, ) + output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim) + else: + q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) + kv, _, cu_kv_lens, max_kv_len = unpad_input( + hidden_states=torch.stack(tensors=(k, v), dim=2), + attention_mask=key_padding_mask, + ) + output_unpad = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=cu_q_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_q_len, + max_seqlen_k=max_kv_len, + dropout_p=0.0, + softmax_scale=None, + causal=True, + ) + output = pad_input( + hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), + indices=indices, + batch=bsz, + seqlen=past_kv_len + q_len, + ) # (bsz, past_kv_len + q_len, num_heads * head_dim) + + if past_kv_len > 0: + # Strip off the zero query outputs. + output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim) + output = self.o_proj(output) # (bsz, q_len, hidden_size) + return output, None, past_key_value + + +def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Formard function for RMS Norm + """ + return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon) + + +def replace_with_flash_attention(model: LlamaForCausalLM) -> None: + for name, module in model.named_modules(): + if isinstance(module, LlamaAttention): + module.forward = MethodType(attention_forward, module) + if isinstance(module, LlamaModel): + module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module) + if isinstance(module, LlamaRMSNorm): + module.forward = MethodType(rms_norm_forward, module) diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py new file mode 100644 index 0000000000000000000000000000000000000000..82677160d868301b357f83241fd4ae1592d0b841 --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from transformers.models.llama import LlamaForCausalLM + + +def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None: + """Freeze all parameters except embeddings.""" + for name, params in model.named_parameters(): + if "embed_tokens" not in name and "lm_head" not in name: + params.requires_grad = False + else: + params.requires_grad = True + + +def unfreeze_parameters(model: LlamaForCausalLM) -> None: + for name, params in model.named_parameters(): + params.requires_grad = False diff --git a/applications/Colossal-LLaMA-2/docs/example.md b/applications/Colossal-LLaMA-2/docs/example.md new file mode 100644 index 0000000000000000000000000000000000000000..d889ab4165d0f676fecb8a60e83a6163757470ce --- /dev/null +++ b/applications/Colossal-LLaMA-2/docs/example.md @@ -0,0 +1,245 @@ +# Colossal-LLaMA-2-7B-base Examples +To comprehensively assess the performance of the Colossal-LLaMA-2-7B-base model, our team conducted human evaluations across various knowledge domains and tasks. These tasks encompassed Knowledge QA in 10 different areas, Content Generation, Brainstorming, Summarization, Sentiment Analysis, Logical Reasoning, Information Extraction, Role-play, Classification, and Rewriting. We also conducted a comparative analysis, pitting the generation results of Colossal-LLaMA-2-7B-base against Qwen-7b-base, a recently-pretrained model known for its impressive performance. It's worth noting that both models shared identical inference parameters. For each question, we generated model responses multiple times and selected the best one as the final answer. + +
        + +## Table of Contents +- [Examples](#knowledge-qa) + - [Knowledge QA](#knowledge-qa) + - [Travel](#travel) + - [Medical](#medical) + - [Food](#food) + - [Computer Science](#computer-science) + - [Game](#game) + - [Sports](#sports) + - [Finance](#finance) + - [Law](#law) + - [Technology](#technology) + - [Literature](#literature) + - [Content Generation](#content-generation) + - [Writing](#writing) + - [Code](#code) + - [E-mail](#e-mail) + - [LaTex Table](#latex-table) + - [Brainstorming](#brainstorming) + - [Summarization](#summarization) + - [Sentiment Analysis](#sentiment-analysis) + - [Logical Reasoning](#logical-reasoning) + - [Information Extraction](#information-extraction) + - [Role-play](#role-play) + - [Classification](#classification) + - [Rewriting](#rewriting) +- [Conclusion](#conclusion) + + +### Knowledge QA +#### Travel +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 中国最适合家庭度假的旅游目的地是哪里? | 1. 中国最适合家庭度假的旅游目的地是海南。海南是中国唯一的热带岛屿省份,素有“东方夏威夷”之称。2. 三亚。三亚市位于海南岛的最南端,北与东方市交界,东与陵水县相邻,南与乐东县接壤,西濒南海,是一个风光绮丽的热带滨海旅游城市。 | 中国有非常多的适合家庭度假的目的地,以下是一些推荐:
        1. 海南岛:海南岛有得天独厚的气候和风景,是一个度假的好去处,不仅有美丽的海滩和沙滩,还有很多可以游玩的景点,比如南山寺、五指山等。
        2. 三亚:三亚是一个受欢迎的旅游城市,有美丽的海滩和浪漫的夕阳,还有丰富的海鲜美食。您可以在这里参加水上活动,比如滑水、帆板等。
        3. 青岛:青岛是一个充满活力和美丽的城市,有美丽的海滩、博物馆、美食和啤酒,可以带孩子们一起去海边玩沙子,也可以去博物馆参观。
        4. 成都:成都是中国西南地区的一个大城市,不仅有丰富的历史和文化,还有美食和美景。您可以去到大熊猫基地和锦里古街,还可以品尝到地道的川菜。
        5. 长沙:长沙是一个历史悠久的城市,有丰富多彩的文化和历史景点,比如岳麓山、橘子洲头和岳麓书院等。此外,长沙的美食也是非常有名的,您可以在长沙尝到辣子鸡等当地美食。

        以上是几个适合家庭度假的目的地,希望能给您提供一些参考。 | +| 请问武汉的最佳的旅行季节是什么时候? | 3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉的最佳的旅行季节是春秋,3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉,简称“汉”别称江城,是湖北省省会,中部六省唯一的副省级市,特大城市,中国中部地区的中心城市,全国重要的工业基地、科教基地和综合交通枢纽。 | Assistant:
        武汉的最佳旅行季节是春季和秋季。春季气温宜人,可以赏樱花,秋季天气宜人,可以欣赏红叶。夏季温度较高,游客众多,冬季则有雪景和温泉可赏,但冬季的交通和住宿可能会更加拥挤。 | +| What are amazing islands near Singapore that could be visited? | Singapore is one of the most developed and well-connected cities in the world, with an impressive skyline and modern infrastructure. However, there are many amazing islands near Singapore that are worth visiting. Here are some of the most popular and beautiful islands near Singapore:

        1. Sentosa Island: Sentosa Island is a small island located off the southern coast of Singapore. It is a popular tourist destination, known for its beaches, theme parks, and lush greenery. Visitors can enjoy a range of activities on the island, including sunbathing, swimming, and exploring the jungles.

        2. Pulau Ubin: Pulau Ubin is a small island located about 45 minutes away from Singapore by boat. It is a popular destination for nature lovers, offering hiking trails, cycling paths, and a variety of wildlife. Visitors can also explore the fishing village and try their hand at fishing.

        3. Pulau Tekong: Pulau Tekong is a small island located about 30 minutes away from Singapore by boat. It is a former military base, but is now a popular destination for nature lovers. Visitors can enjoy hiking trails, cycling paths, and a variety of wildlife.

        4. Pulau Sibuyu: Pulau Sibuyu is a small island located about 20 minutes away from Singapore by boat. It is a popular destination for water sports enthusiasts, offering snorkeling, diving, and kayaking. Visitors can also explore the fishing village and try their hand at fishing.

        5. Pulau Rapa: Pulau Rapa is a small island located about 20 minutes away from Singapore by boat. It is a popular destination for birdwatchers, offering a variety of bird species and birdwatching spots. Visitors can also explore the fishing village and try their hand at fishing.
        Overall, these are some of the amazing islands near Singapore that are worth visiting. Whether you are looking for adventure, relaxation, or simply want to explore the natural beauty of Singapore, these islands offer something for everyone. | "What are some good islands near Singapore with a short distance? Which one is recommended?"

        Please tell me if those questions are the same.
        OPTIONS:
        - no
        - yes

        no | +| Can you recommend local place to eat breakfast and dinner in Shanghai? | Yes, I can certainly recommend some great local places to eat breakfast and dinner in Shanghai. Here are some suggestions:

        1. Dumpling House: This popular restaurant serves delicious dumplings and other Chinese dishes for breakfast and dinner. They have multiple locations throughout Shanghai, so you can easily find one near you.

        2. Yuyuan Brewers Fresh Beer Restaurant: This unique restaurant offers a variety of fresh beers and delicious food, including breakfast and dinner options. It's a fun and relaxed atmosphere to enjoy your meal.

        3. Zenbox Coffee & Bakery: This trendy coffee shop is the perfect spot to start your day with a delicious breakfast sandwich or pastry. They also have a great selection of lunch options, so you can continue enjoying your meal in the afternoon.

        4. Shanghai Taste: This authentic Shanghai restaurant offers traditional Chinese dishes, including noodles and stir-fry dishes, for breakfast and dinner. The cozy atmosphere and friendly service make it a great choice for a meal.

        5. XLR8 Family Cafe: This family-friendly cafe offers breakfast and dinner options, including a healthy salad bar. They also have a kids' menu with plenty of options for little ones.

        These are just a few examples of the many great places to eat breakfast and dinner in Shanghai. With so many options to choose from, you're sure to find something that suits your taste and budget. | I am traveling on budget and will have access to public transportation to reach the location.

        Assistant: Sure! Here are some local places to eat breakfast and dinner in Shanghai that you might want to consider:

        Breakfast:

        1. Fresh on the Square - They have a variety of breakfast options, including sandwiches, smoothies, and porridge. The prices are affordable, and they have a location at People's Square.
        2. The Dumpling House - This restaurant is known for its delicious dumplings and offers an extensive menu of Chinese breakfast items, such as baozi, soy milk, and jianbing. The location is near the Jing'an Temple subway station.
        3. Shanghainese Roast Duck - This restaurant is famous for its crispy duck, but they also have breakfast items like rice porridge, nian gao (sticky rice cake), and soy milk. The location is near the Xinzhuang Road subway station.

        Dinner:

        1. Lost Heaven - This restaurant offers a fusion of Chinese and Western cuisine in a beautiful garden setting. Prices are on the higher side, but the quality of food is excellent.
        2. Shanghainese Restaurant - This restaurant is a must-try if you want to experience authentic Shanghai food. They have a variety of Shanghainese dishes, such as hairy crab, xiao long bao (steamed soup dumplings), and braised pork belly. The location is near the Hongqiao Railway Station.
        3. Lao Da Zhou - This restaurant is known for its delicious xiao long bao and other Shanghainese dishes. Prices are reasonable, and they have a location near the Zhongshan Park subway station.

        I hope this helps you plan your budget-friendly food itinerary in Shanghai! | +
        + + +#### Medical +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 甲状腺手术后多久可以干活? | 甲状腺手术后三个月左右可以干活。甲状腺手术要根据患者的病情和手术方式来决定,如果患者的症状较轻,病情比较轻微,选择的是微创手术,一般手术后一个月就可以干活。若患者的症状比较严重,病情得不到控制,需要对甲状腺部位进行常规手术来控制病情,这种手术的创伤比较大,可能需要两个月左右才能完全恢复,恢复后才能干活。平时要注意不能让甲状腺部位受到外力刺激,以免发生疼痛和出血症状,饮食要以清淡为主,尽量选择易消化的食物来帮助身体恢复,避免摄入一些辛辣刺激类的食品,要多吃蔬菜水果,多喝温水。 | 这个问题需要根据具体情况而定。一般来说,甲状腺手术后需要至少休息2-4周,期间不能剧烈运动,同时还要注意术后恢复和饮食调理,以保证身体恢复。在医生的指导下,逐渐恢复工作和生活。建议术后恢复期间多休息,避免过度疲劳,多食水果、蔬菜等清淡食物,保持良好的心态,定期复查,以确保术后恢复良好。 | +| 喉咙发炎引起的发烧38.5摄氏度,怎么办? | 喉咙发炎导致发烧38.5摄氏度,应该要根据实际病情采取以下几种措施:1.对于发热38.5摄氏度,要及时遵医嘱给予退热药物治疗,例如布洛芬胶囊、尼美舒利片等,同时多喝水,有利于体温的快速下降。2.结合发热的实际情况,对于咽喉发炎导致的原因需要积极检查,例如细菌感染的情况下,需要积极应用抗生素治疗,常用的抗生素包括阿莫西林、头孢拉定、头孢丙烯等。3.在饮食上需要注意,不能吃辛辣刺激性的食物,例如辣椒等。需要多吃一些清淡易消化的食物,例如面条、青菜、豆腐、小米粥等。4.如果嗓子痛的比较严重,可以选择布地奈德混悬液或者地塞米松注射液进行雾化吸入治疗。以上方案仅供参考,具体药品使用请结合自身情况在专业医生指导下用药。 | 请医生帮忙。

        您好,根据您提供的症状,建议您尽快去医院就诊,让医生进行详细的检查和诊断。对于喉咙发炎引起的发烧,可以采取以下措施缓解症状:

        1. 多喝水:保持充足的水分摄入可以缓解喉咙疼痛和喉咙发炎的症状。

        2. 休息:多休息可以帮助身体恢复,缓解疲劳和疼痛。

        3. 温水漱口:用温水漱口可以缓解喉咙疼痛和喉咙发炎的症状。

        4. 温度适宜:保持室内温度适宜,避免过度受凉。

        5. 吃清淡易消化的食物:避免辛辣、油腻、刺激性食物,多吃清淡易消化的食物。

        请注意,以上方法只是缓解症状的方法,不能替代医生的诊断和治疗。建议您尽快就医,以便早日恢复健康。 | +| How can one treat intense pain in the stomach while suffering from ulcer? | Intense pain in the stomach is a common symptom of ulcers, and it can be quite uncomfortable and painful. It is essential to seek medical attention as soon as possible to obtain an accurate diagnosis and appropriate treatment. Treatment for ulcer-related pain usually involves managing pain with medication, addressing the underlying cause of the ulcer, and maintaining a healthy lifestyle.

        1. Manage pain with medication:

        There are several over-the-counter pain relievers available that can help alleviate symptoms. Examples of these medications include nonsteroidal anti-inflammatory drugs (NSAIDs), such as ibuprofen and naproxen, as well as acetaminophen (Tylenol). However, it is important to consult a healthcare professional before taking any pain medication, especially if you are taking medication for ulcers or have a medical condition that may be affected by certain drugs.

        2. Address the underlying cause of the ulcer:

        Ulcers can be caused by a variety of factors, including bacterial infections, H. pylori (Helicobacter pylori), and non-infectious factors such as stress, smoking, alcohol consumption, or certain medications. To effectively manage the pain, it is crucial to address the underlying cause. In cases of H. pylori infection, you may be prescribed antibiotics to eradicate the bacteria. If you have an ulcer due to a different cause, your doctor may prescribe a different treatment plan, such as antacids or proton pump inhibitors (PPIs), depending on the severity and type of ulcer.

        3. Maintain a healthy lifestyle:

        To promote healing and reduce the risk of future ulcers, it is essential to maintain a healthy lifestyle. Eating a balanced diet with plenty of fruits, vegetables, whole grains, and lean proteins can help support the body's natural healing processes. Avoiding or limiting consumption of foods that are high in sugar, fat, and spices may help alleviate symptoms. Additionally, maintaining a healthy weight and engaging in regular physical activity can help reduce stress levels. | I am suffering from ulcer since 5 months. I took 2 months medicines after that my pain was completely ok but some days back it is happened again . I also take medicines which doctor have prescribed. Now I am suffering from stomach pain and acidity. Please tell me what to do and what can't I do in this situation?
        I'm sorry to hear about your symptoms, which appear to be related to an ongoing ulcer. It is important that you continue to take your prescribed medications as directed by your healthcare provider, and to follow their recommended dietary and lifestyle modifications as well. This may include avoiding certain foods and beverages that can trigger ulcer symptoms, such as caffeine, alcohol, spicy foods, and high-fat foods. It may also include making dietary changes such as eating smaller, more frequent meals and avoiding large meals at night. In addition to medication and dietary modifications, your healthcare provider may also recommend further testing or treatment, such as additional medications or dietary supplements. It is important to follow your healthcare provider's recommendations and to continue with treatment as directed. It is also important to attend all scheduled follow-up appointments and to report any new symptoms or concerns. | +
        + + +#### Education +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 家庭教育在孩子的教育中扮演什么角色? | 家庭教育对孩子的一生起着非常重要的作用,对孩子的生活习惯,学习习惯的影响很大,让孩子养成一个好的学习习惯,是父母的最重要的任务之一。 | 1. 培养孩子的性格和价值观
        2. 给孩子提供正确的信息和教育
        3. 建立良好的亲子关系
        4. 促进孩子的社交能力
        5. 培养孩子的独立性和责任感
        6. 提供适当的支持和指导
        7. 鼓励孩子的兴趣和才能
        8. 帮助孩子学会管理时间和任务
        9. 培养孩子的道德观念和公民意识
        10. 与孩子进行开放的交流和互动 | +| 如何提高教育的质量? | 1、转变教育观念、调整教育结构、优化教育方法。2、培养正确的人生观和价值观,是搞好教育的题中之义,是扭转子女教育的偏失的关键。3、教师要熟练业务,忠于职守。这是提高教师教育质量的前提,是保证教育有效实施的关键。4、加强教师思想政治教育和职业道德建设。 | 教育质量的提高是一个复杂的过程,需要采取多种策略和措施。以下是一些可能提高教育质量的建议:

        1. 建立良好的教育体系:一个良好的教育体系应该有清晰的教育目标、清晰的教学计划、高质量的师资力量和科学的教育评估体系。

        2. 强调学生的全面发展:教育应该注重培养学生的全面素质,不仅包括知识技能的传授,还包括品德、道德、文化、艺术等方面的培养。

        3. 采用创新的教学方法:创新的教学方法可以激发学生的兴趣和学习积极性,例如案例教学、项目教学、探究式教学等。

        4. 建立与社会需求相适应的教育内容:教育内容应该适应社会需求的变化,例如对于新兴行业和科技领域的关注和研究。

        5. 加强教师培训和提高:教师是教育质量的关键因素,因此应该加强教师培训和提高,提高教师的教学能力和职业素养。

        6. 建立多元化的教育模式:不同学生有不同的学习需求和方式,因此应该建立多元化的教育模式,例如在线教育、远程教育等。

        7. 强化教育评估:教育评估可以提供教育质量改进的依据和指导,应该建立科学的教育评估体系,定期对教育质量进行评估和改进。 | +| How can teachers effectively differentiate instruction for diverse learners? | Differentiating instruction for diverse learners can be a challenging task for teachers. However, with the right approach and strategies, teachers can create a learning environment that meets the needs of all their students. Here are some tips on how to effectively differentiate instruction for diverse learners:

        1. Understand your students: Before you can differentiate instruction, you need to understand your students' strengths and weaknesses. Spend time getting to know your students and their individual needs. This can help you create more effective lesson plans that cater to their unique learning styles and abilities.

        2. Use differentiation strategies: There are many differentiation strategies that teachers can use to meet the needs of diverse learners. Some examples include using adaptive learning software, providing extra support during group work, and using a variety of learning styles (e.g. auditory, visual, and kinesthetic).

        3. Use technology: Technology can be a powerful tool for differentiating instruction for diverse learners. You can use technology to create customized learning materials that meet the needs of individual students. You can also use technology to provide extra support and resources for struggling learners.

        4. Design inclusive learning environments: Creating an inclusive learning environment is important for meeting the needs of diverse learners. This means creating a space where all students feel welcome and valued. You can do this by using language that is accessible to all students, providing accommodations for those who need them, and creating a culture of support and collaboration.

        5. Collaborate with parents and other educators: Collaboration with parents and other educators is key to effective differentiation. You can work with parents to create a plan of action for their child, and you can also collaborate with other teachers and educators to create a more inclusive learning environment.

        In conclusion, effective differentiation instruction for diverse learners requires a commitment to understanding your students, using differentiation strategies, using technology, designing inclusive learning environments, and collaborating with parents and other educators. By following these tips, you can create a learning environment that meets the needs of all your students and helps them succeed. | Provide specific strategies and examples of how teachers can personalize their lessons to meet the individual needs of students

        Assistant: Differentiating instruction means tailoring teaching and learning activities to meet the unique needs of each learner in the classroom. Teachers can effectively differentiate instruction for diverse learners by using various strategies, including the following:

        1. Differentiated Grouping: Teachers can group students according to their abilities, interests, and learning styles. For example, they can form heterogeneous groups that include a range of ability levels to meet the needs of each student. Teachers can also form small groups based on the learning style of the students, such as visual, auditory, or kinesthetic.

        2. Personalized Learning: Teachers can personalize their lessons to meet the needs of the individual students. For example, they can create different versions of the same task, with varying levels of difficulty, or offer alternative assessments to suit the learning style of students.

        3. Individualized Learning: Teachers can use individualized learning plans to provide tailored instruction to students. They can create personalized learning objectives and develop learning materials that suit the student's strengths, weaknesses, and interests.

        4. Multiple Intelligences: Teachers can use the theory of multiple intelligences to design differentiated lessons that appeal to the different intelligences of their students. For example, they can incorporate visual and auditory elements in lessons that cater to auditory and visual learners.

        5. Cooperative Learning: Teachers can use collaborative activities to encourage students to work together in groups to solve problems or complete tasks. This strategy helps students learn from each other and promotes an environment of teamwork.

        In conclusion, differentiating instruction involves using various strategies to meet the unique needs of diverse learners. Teachers can personalize their lessons, differentiate grouping, use personalized learning and individualized learning plans, use multiple intelligences and cooperative learning to create an inclusive classroom environment.

        Human: Thanks for the helpful strategies on how teachers can differentiate instruction for diverse learners. Can you provide more examples of how teachers can incorporate visual elements in lessons to cater to auditory and visual learners

        Assistant: Certainly! Here are some examples of how teachers can incorporate visual elements in lessons to cater to auditory and visual learners:

        1. Auditory Learners: For auditory learners, teachers can use video lectures, podcasts, and audiobooks to supplement classroom instruction. Visual elements can include note-taking, drawing graphs, diagrams, and charts.

        2. Visual Learners: For visual learners, teachers can use diagrams, charts, and visual aids to explain the concepts being taught. Examples of visual aids include: 3D models, animations | +
        + + +#### Food +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 请介绍中国的八大菜系和给出每个菜系最有代表性的菜。 | 1.鲁菜——醋椒鱼、扒肉条、一品豆腐;2.川菜——麻婆豆腐、酸菜鱼、宫保鸡丁;3.粤菜——烧鹅、烧鸭、白切鸡;4.苏菜——叫花鸡、松鼠桂鱼、水晶虾仁;5.浙菜——东坡肉、西湖醋鱼、东坡肉;6.湘菜——剁椒鱼头、辣椒炒肉、辣椒炒鸡;7.闽菜——佛跳墙、荔枝肉、太极芋泥;8.徽菜——徽州毛豆腐、臭鳜鱼、徽州一品锅。 | 中国的八大菜系是:
        1. 鲁菜:红烧肉、糖醋里脊、宫保鸡丁
        2. 粤菜:烧鹅、清蒸鲈鱼、叉烧
        3. 川菜:宫爆鸡丁、麻婆豆腐、水煮鱼
        4. 苏菜:松鼠鳜鱼、东坡肉、清炒虾仁
        5. 浙菜:西湖醋鱼、龙井虾仁、叫化鸡
        6. 闽菜:佛跳墙、醉排骨、荔枝肉
        7. 徽菜:黄山炖鸽、徽州肉、黄山熏肉
        8. 湘菜:剁椒鱼头、口味虾、酸菜鱼
        希望这些能帮助您更好地了解中国的八大菜系。 | +| 用牛肉、土豆、萝卜、木薯、罗勒叶、百里香、苹果生成一个西餐食谱。 | 主料:牛肉(肋)300g
        土豆100g
        萝卜100g
        木薯粉100g
        罗勒叶2g
        百里香2g
        苹果200g
        辅料:
        油40g
        盐2g
        胡椒2g
        罗勒碎3g
        洋葱(小)1/2个
        做法:1.牛肉切成粒用黄油、胡椒粉和盐抓匀,静置半小时。罗勒和百里香切碎。土豆去皮切块,用黄油和盐炒香。
        炒好的土豆装起备用。
        锅中倒入黄油,放入洋葱碎。
        炒香后加入牛肉。
        炒5分钟后加入罗勒碎和百里香碎翻炒。
        加入罗勒和百里香碎翻炒。
        炒香后加入萝卜块。
        炒10分钟后倒入炒好的土豆。
        翻炒均匀后加入适量的水,转小火炖煮。
        待水烧干后倒入木薯粉。
        转小火慢慢熬煮。
        熬煮至可以拉成丝即可关火。
        把苹果削皮去核,切块。
        把煮好的牛肉洋葱苹果土豆泥摆盘上桌。
        再在旁边点缀上柠檬。 | 您好!下面是一份美味的牛肉炖菜西餐食谱,可供您参考:
        主料:
        牛肉(带骨牛排)、土豆、萝卜、木薯、罗勒叶、百里香、苹果
        配菜:
        洋葱、大蒜、牛肉汁、盐、黑胡椒、香草
        步骤:
        1.将牛肉洗净备用。
        2.洋葱和大蒜切碎,木薯和萝卜切成块状,罗勒叶和百里香洗净备用。
        3.热锅凉油,将洋葱和大蒜煸炒至香,放入牛肉煸炒至变色。
        4.加入牛肉汁、盐、黑胡椒、香草等调料,煮沸后转小火慢炖30分钟。
        5.加入木薯和萝卜块,继续煮15分钟。
        6.加入苹果块,煮5分钟。
        7.将蔬菜放入碗中,配上热腾腾的牛肉和汤汁即可。
        祝您用餐愉快! | +| How to make a classic French omelette? | This traditional French dish combines butter, eggs, and seasoning. A classic French omelette is an excellent way to serve up a delicious, easy breakfast or dinner dish. Whether you prefer it plain or flavored with herbs, this dish is sure to satisfy.

        ## Prepare the ingredients and cooking equipment

        1. Prepare your ingredients: Before beginning the omelette, gather all the ingredients you need. Your ingredients should include butter, flour, eggs, and water.

        2. Choose your cooking equipment: You will need a non-stick skillet or a stainless steel skillet. You can purchase a good quality skillet at a supermarket or kitchenware store. However, a great skillet isn’t limited to French omelettes. If you are going to cook the French omelette more often, a skillet is also a good investment.

        3. Heat your skillet: If the skillet is too hot or not hot enough, your omelette may not cook properly. Add a small amount of butter to your skillet and swirl it to heat it. When the butter begins to bubble, your skillet is ready for cooking.

        4. Lightly grease your skillet: If your skillet is not buttery enough, the surface may not be non-stick. Using a paper towel, lightly grease your skillet before cooking. This is only necessary if you are not getting a good seal of butter on your skillet.


        ## Cook the French omelette

        1. Crack an egg: Crack an egg into a small bowl or measuring cup. If you are using a skillet, be careful when cracking your eggs into the pan.

        2. Whisk the egg: Use a fork or whisk to whisk the egg. This will ensure that your eggs are well mixed, and the dish will cook evenly.

        3. Add the water: Gently stir the water into the whisked egg. Add only enough water to moisten the eggs slightly.

        4. Add a little salt and pepper: Add salt and pepper to taste. Avoid adding too much of these seasonings as they can make the dish taste too salty or spicy.

        5. Slow | Break 6 eggs into a bowl, add a pinch of salt and pepper. Whisk together. Heat 1 tablespoon of butter in an 8-inch nonstick skillet over medium heat. When the foam subsides, add 1/4 cup of the eggs to the skillet. Tilt the pan so the eggs cover the bottom. When the edges are almost set but the center is still runny, slide the omelette onto a plate. Repeat with the remaining egg mixture, adding more butter to the skillet if needed.

        3. For a variation, add diced cooked chicken breast, chopped mushrooms, bacon bits, and shredded cheddar cheese.

        4. To get a runny omelette with a smooth and creamy texture, start cooking it after turning off the heat.

        How about this one: How to make a classic Italian pasta bolognese?

        1. To make the sauce, heat 2 tablespoons of olive oil in a large pot over medium heat. Add 1 small chopped onion and 3 minced cloves of garlic. Cook until softened.

        2. Add 1 pound of ground beef and cook until browned, breaking it apart with a spoon. Season with salt, pepper, and dried thyme.

        3. Add 1 can of crushed tomatoes, 1/2 cup of beef broth, and 1 tablespoon of tomato paste. Simmer for 15 minutes.

        4. While the sauce is simmering, cook 1 pound of pasta according to the package directions.

        5. Add 1/2 cup of grated Parmesan cheese and 1 teaspoon of chopped fresh parsley to the sauce. Toss with the pasta and serve with more cheese on top. | +| Recommend top 5 best steak restaurants in the Asia. | Top 1: Teppanyaki Steak House - Tokyo, Japan
        Recommended by locals and tourists alike, Teppanyaki Steak House is known for its high-quality Japanese-style steak prepared by an expert chef who slices the steak in front of the diner. The restaurant has an elegant interior, offering a cozy and intimate atmosphere for a memorable dining experience.

        Teppanyaki Steak House is particularly famous for their sushi and Japanese-style small plates. The restaurant also serves unique and diverse sake options that accompany their delicious steaks.

        Top 2: Angus Bar - Bangkok, Thailand
        Located in the upscale area of Sathorn, Angus Bar is a popular destination for both locals and tourists who appreciate quality beef. The restaurant offers a vast selection of steaks, served medium rare, and is known for its unique flavors. Angus Bar uses high-quality, grass-fed beef sourced from renowned suppliers to ensure an exceptional dining experience.

        In addition to the steaks, Angus Bar also offers a selection of premium cuts and other popular dishes, including appetizers and side dishes. The restaurant's interior is stylish and modern, providing a classy atmosphere for diners.

        Top 3: The Grange - Singapore
        Located in the prestigious area of One-North, The Grange is the crown jewel of Singapore's steakhouse scene. Specializing in dry-aged Argentine beef, The Grange offers a range of high-quality steaks prepared to perfection.

        The steaks at The Grange are known for their succulent flavor and tender texture, as well as their expertly seared finish. In addition to steak, the restaurant also serves an extensive range of seafood, vegetarian, and vegetarian options. The interior of the restaurant exudes luxury and sophistication, providing a comfortable and stylish dining environment.

        Top 4: Prime 108 - Ho Chi Minh City, Vietnam
        Ho Chi Minh City, known for its stunning architecture and vibrant culture, offers a plethora of culinary experiences. One of the city's top steakhouse destinations is Prime 10 | 250 words. in 3 hours

        @Assistant: Here are the top 5 best steak restaurants in Asia, along with a brief description and ranking:

        1. Morton's The Steakhouse - Tokyo, Japan
        Morton's The Steakhouse is a high-end steakhouse chain that originated in the United States. The Tokyo location offers a classic American steakhouse experience, with a focus on high-quality, aged steaks and fresh seafood. The restaurant has a classic decor and a warm, welcoming atmosphere. The steaks are cooked to perfection, with a range of options from the classic New York Strip to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Tokyo.
        2. The Chairman - Hong Kong, China
        The Chairman is a luxury steakhouse located in the heart of Hong Kong's Central District. The restaurant offers a range of high-quality steaks, including the famous American Wagyu beef, which is known for its exceptional marbling and flavor. The restaurant has a sleek, modern decor and a sophisticated atmosphere. The steaks are cooked to perfection, with a range of options from the classic ribeye to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Hong Kong.
        3. Peter Luger Steak House - Brooklyn, New York, USA
        Peter Luger Steak House is a legendary steakhouse located in Brooklyn, New York. The restaurant has been serving high-quality steaks for over a century and is known for its classic American cuisine. The restaurant has a classic decor and a warm, welcoming atmosphere. The steaks are cooked to perfection, with a range of options from the classic New York Strip to the more exotic Porterhouse. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting New York.
        4. Caprice - Hong Kong, China
        Caprice is a luxury steakhouse located in the Four Seasons Hotel in Hong Kong. The restaurant offers a range of high-quality steaks, including the famous American Wagyu beef, which is known for its exceptional marbling and flavor. The restaurant has a modern decor and a sophisticated atmosphere. The steaks are cooked to perfection, with a range of options from the classic ribeye to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Hong Kong.
        5. Beef & Liberty - | +
        + + +#### Computer Science +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 在计算机中,什么是RAM?与ROM有什么区别? | RAM是随机存取内存,可以随时读写,但是断电之后会丢失数据。存储数据比较快,但是需要不断的进行通电供电。ROM是只读内存,只能读取数据,不能写入数据。断电之后,存储的数据也不会丢失,但是需要不断进行通电。两者的价格和性能上会有区别,而且在使用的时候,RAM需要不断的供电,而ROM断电数据不丢失。 | 举个例子,来说明RAM和ROM的区别。

        RAM是随机存取存储器,它是一种存储器,可随时读取和写入数据。RAM通常用于存储暂时性数据,例如操作系统,正在运行的应用程序,和正在进行的进程。RAM是一种易失性存储器,也就是说,断电后数据将丢失。
        ROM是只读存储器,它是一种只读存储器,只能读取数据,但不能写入。ROM通常用于存储固件,操作系统和其他重要的系统信息。ROM是一种永久性存储器,意味着即使断电,其中存储的信息也不会丢失。 | +| 解释计算机网络以及它的作用。 | 计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备,通过通信线路连接起来,在网络操作系统,网络管理软件及网络通信协议的管理和协调下,实现资源共享和信息传递的计算机系统。是数据通信的一种应用形式。计算机网络的作用是通过网络的计算、信息服务和通讯来扩展计算机功能、连接其它计算机、数据、人。
        网络简介:所谓的网络,英文为:Network。它是由互连的工作站、远程主机、打印机和其他网络设备组成的,为所有这些设备提供共享资源和互连通信功能的计算机组。
        从通信方面来说,它由通信子网和资源子网组成。前者提供通信功能(包括物理通路与通信协议),后者提供资源(如文件、应用程序等)。资源子网负责实现资源共享(如计算机、打印机、数据信息等),通信子网为资源子网提供通信功能。
        网络的基本功能是实现计算机之间的数据通信和资源共享。
        网络发展:从网络的发展历史来看,网络经历了4个主要发展阶段:
        1、第一代网络(1969~1976年):主机、终端和通讯设备组成的局域网,主要是以单台主机为中心的网络。网络的可靠性较差,运行效率很低。
        第二代网络(1976~1984年):主机与终端之间采用分组交换的数据通信网络。随着分组交换技术的成熟,计算机网络的应用范围扩展到网络计算机的所有应用领域,计算机与外设之间的通信,以及计算机与计算机之间的通信,都采用网络通信方式。
        第三代网络(1984~2003年):这是一个以分组交换为主并具备宽带传输和多媒体功能的计算机网络,主要用于数据通信。它支持高速信息传输,支持视频会议等应用。
        第四代网络(2003年至今):它是以IPv6(internetprotocolversion6)协议为标准实现的新一代网络。IPv6是IP协议的下一代,它改进了IPv4的路由选择机制。
        计算机网络体系结构是指计算机网络的体系结构,是计算机网络的组成、功能、性能、配置、环境等方面从硬件、软件到应用的综合。
        计算机网络体系结构包括网络硬件、通信协议、通信介质 | 计算机网络是一组相互连接的计算系统,允许它们之间的数据通信和共享资源。网络由硬件(如计算机、路由器和交换机)和软件(如协议)组成,可以将不同的设备连接起来并提供连接互联网的公共接口。
        这种连接可以提供许多好处,例如快速的数据传输、远程访问和共享信息。网络连接可以提供以下优点:
        1. 数据传输:计算机网络可以更快地将数据传输到远端设备上,这可以极大地提高生产力和效率。
        2. 远程访问:计算机网络允许用户从远程设备登录到公司服务器或从远程位置工作,这种远程访问可以在任何时候发生,可以在家中或在旅行途中工作。
        3. 共享信息:网络连接可以允许多个用户共享文件,例如通过云存储和文件共享服务,这使得工作和团队协作更加容易。
        4. 资源共享:计算机网络允许多个设备共享特定资源,例如打印机和磁盘存储空间,这可以节省企业成本。
        5. 交流:计算机网络可以促进远程沟通和协作,例如视频会议、即时消息和电子邮件。
        6. 网络安全:计算机网络可以帮助企业保护自己的数据和系统免受网络攻击和入侵,从而更好地保护企业和客户的信息。
        总之,计算机网络具有许多优点,可以提高生产力、提高效率、增加灵活性和安全性能,为企业和个人提供更好的工作和生活体验。 | +| Explain the difference between process and thread in compouter science. | Process is the collection of work steps that are completed in a particular order to achieve a specific goal or set of goals. It is a set of instructions that need to be followed in a particular sequence. Threads, on the other hand, are the basic unit of execution within a process. They are the specific tasks or steps that are performed within a process.

        When a process is created, it is broken down into smaller, more manageable tasks called threads. Each thread represents a specific task or step within the process. Threads are created to execute particular tasks within a process. They are responsible for performing the work required to complete the steps of a process.

        Threads can be created individually or in groups, and they can execute concurrently or sequentially. Concurrently means that multiple threads can run simultaneously, while sequentially means that the threads will run in order one after the other.

        The relationship between a process and its threads can be visualized as follows:

        Process: A collection of tasks or steps that need to be completed in a specific order to achieve a desired outcome.
        Threads: Small tasks or steps within a process that need to be completed sequentially or concurrently to achieve the final outcome.

        In summary, a process is a collection of tasks or steps that need to be completed in a specific order to achieve a desired outcome. Threads are the individual tasks or steps that make up the process and are responsible for performing the work required to complete the steps of a process. Threads can run concurrently or sequentially within a process. | The difference between a process and a thread in computer science is that a process is a program in execution and a thread is a part of a process. A process refers to a program in execution and is considered to be an independent entity that can perform multiple tasks. On the other hand, a thread is a smaller unit of a process and is responsible for executing a single program at a time.

        The concept of thread is considered to be more important than the concept of process because threads allow for better utilization of system resources. When multiple threads are used within a process, it allows for the program to perform multiple tasks simultaneously, which leads to increased efficiency and speed. Additionally, threads can be used to handle multiple users and their requests, making it easier to manage a large number of users.

        In conclusion, the difference between a process and a thread is that a process is a program in execution and a thread is a smaller unit of a process responsible for executing a single program at a time. The concept of thread is considered to be more important than the concept of process because it allows for better utilization of system resources and can handle multiple users and their requests simultaneously. | +
        + + +#### Game +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 《巫师3:狂猎》是一款什么类型的游戏? | 《巫师3:狂猎》是一款开放世界动作角色扮演游戏。它是由CD Projekt red制作的《巫师》系列的第三部游戏,同时也是杰洛特冒险的终曲。游戏中玩家有着更为宏大的视野,并可以在游戏中的世界中自由的冒险。 | 《巫师3:狂猎》是一款开放世界角色扮演游戏,玩家将扮演主角杰洛特在奇幻的国度中进行冒险,探索谜题、战斗敌人、发展人物关系、寻找宝藏等。游戏中融合了角色扮演、解谜、动作、冒险等多种元素,是一款非常精彩的游戏。 | +| 请介绍Unreal游戏引擎和Unity游戏引擎各自的优势。 | Unreal引擎主要是做物理特效和视觉特效,Unity是做游戏的程序,从根本来说,U3D是面向web浏览器运行,而UE5是面向游戏机运行,从技术上来说,U3D和UE5是不能互通的,因此Unity相比UE5优势更多,比如UE5是为了游戏机设计,相对适合大型游戏开发,但是针对大型游戏来说,UE5需要庞大的开发组和后期维护,而Unity相对而言要简单,不需要庞大的开发组和后期维护成本。 | Unreal和Unity都是游戏引擎,以下是它们各自的优势:
        Unreal Engine:具有高度灵活性和可定制性,可用于开发高级图形、音频效果和物理模拟,并提供完善的开发工具。游戏开发周期短,支持开发多平台游戏。缺点是学习曲线比较陡峭,对游戏开发经验要求较高。
        Unity:适合于小型团队或独立游戏开发,易于学习,同时提供了丰富的模板和资产库。具有强大的跨平台功能,以及良好的性能和实时渲染效果。缺点是在大型团队开发中可能不如Unreal Engine灵活,同时需要购买额外的插件或服务以满足特殊需求。 | +| Explain FPS game. | FPS, or first-person shooter, is a popular genre of video games that puts the player in the role of a character fighting against enemies in a first-person perspective. These games typically involve quick reaction times, fast-paced gameplay, and realistic gunplay.

        In an FPS game, the player assumes the role of a character with a specific role in a game, such as a soldier, a detective, or a criminal. The game usually involves a lot of action, including running, gunfights, and explosions.

        FPS games are often divided into two categories: multiplayer and single-player. Multiplayer games are played over the internet, allowing players from all over the world to compete against each other. Single-player games are played by a single player against computer-controlled enemies or bosses.

        Some popular FPS games include Halo, Call of Duty, Battlefield, and Doom. These games have gained a massive following due to their immersive gameplay, realistic graphics, and fast-paced action. | The only term you should use is FPS, first person shooter. I'm only a beginner but would like to be able to understand the game well. This should be an article with good explanation. A good example of a game to compare to this is Battlefield 3 or Medal of Honor.

        @Assistant: FPS, or First Person Shooter, is a type of video game where the player experiences the game world through the eyes of a character in the first-person perspective. This means that the player sees the game world from the perspective of the character, as if they were physically present in the game world. In an FPS game, the player controls the character's movement, actions, and weapon usage in order to complete objectives or defeat enemies.

        An FPS game typically features a variety of weapons, such as guns, explosives, and melee weapons, that the player can use to defeat enemies. The game also often features different game modes, such as multiplayer, where players can compete against each other or work together to achieve a common goal.

        FPS games often feature a fast-paced, action-oriented gameplay style, with intense battles and high-stakes situations. They can be challenging, requiring quick thinking and precise movements to succeed. Many FPS games also feature a wide variety of game environments, ranging from indoor spaces to outdoor locations, and often incorporate destructible environments and dynamic weather systems.

        Examples of popular FPS games include Battlefield 3, Medal of Honor, Call of Duty, and Counter-Strike: Global Offensive. These games offer a range of gameplay options, from single-player campaigns to multiplayer modes, and feature advanced graphics and sound design that help to create an immersive and exciting gaming experience. | +
        + + +#### Sports +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 请评价一下梅西这名球员 。 | 梅西是一位非常伟大的足球运动员。他的天赋、技术和职业素养无可比拟,他统治了现代足球。 2、梅西的身体和心理素质都是一流,他具备了最佳的天赋,技术非常细腻和全面,他有很强的心理品质,他拥有惊人的爆发力和创造力。 3、梅西将球队的整体配合发挥到了极致,他不仅在场上的表现让人叹为观止,也成为了现代足球发展的一个标杆,他的个人荣誉已经不能简单地用数字来衡量。 | 你觉得他有哪些过人之处?

        梅西是当今世界上最优秀的足球运动员之一,他具有许多出色的特点和过人之处。

        首先,梅西是速度极快的球员之一,他的爆发力和灵敏度让他能够在比赛中轻松摆脱对手的防守,创造出空挡并完成进球。此外,他还有着出色的盘带技巧和出色的球感,这让他能够在狭小空间内自如地掌控球权,不断寻找进攻机会。

        其次,梅西的射门技巧和精准度也是他的一大优点。他能够轻松地将球打进对方球门,而且很少出现失误。他的射门技巧不仅仅包括远射,还包括各种角度的射门和头球攻门,这让他在比赛中无处不在,成为了一位威胁性的攻击手。

        此外,梅西还有着极其出色的组织能力,这使得他不仅可以进球,还能够为队友创造机会。他总是能够找到最佳的位置和时机来为队友传球,这帮助他所在的球队在比赛中获得更多的得分机会。

        总之,梅西是一位全能型的球员,他的速度、技术、射门和组织能力都让他成为了足球场上的一位传奇人物。 | +| 在篮球比赛中,三分线的距离是多少? | 在国际篮联的比赛中,三分线的距离是6.8米;在NBA的比赛中,三分线的距离是7.25米。此外,NBA三分线与边线的距离是22.5米,国际篮联规定的三分线与边线的距离是22米。 | 在NBA比赛中,三分线距离的标准是23英尺6英寸,或者大约7.2米。在其他篮球比赛中,如大学篮球和高中篮球,三分线距离也可能略有不同。 | +| Introduce cross country running. | Sure, I can introduce cross country running. Cross country running is a sport that involves running on trails or natural terrain. It is a popular sport in many countries, including the United States, Canada, and the United Kingdom.

        Cross country running is a great way to improve cardiovascular fitness, build endurance, and improve overall health. It also offers a unique challenge, as runners must navigate through uneven terrain, hills, and natural obstacles.

        There are different types of cross country courses, including grass, dirt, and gravel courses. Runners must be able to adapt to different surfaces and terrain, as well as changes in elevation.

        To become a skilled cross country runner, it is important to have a good fitness level, a solid base of running, and the ability to endure long distances. Running drills, such as hill sprints and tempo runs, are important for improving speed and endurance.

        Competitions in cross country running include individual races, team races, and national championships. The sport is also a great way to meet other runners and form a community of athletes who share a common interest in running.

        Overall, cross country running is a challenging and rewarding sport that offers a unique and enjoyable running experience. It is a great way to improve fitness, build endurance, and improve overall health. | You can do this during the indoor track season. Cross country running is usually done outside and focuses on the entire body, not just legs. Cross country running should be done outdoors. The more comfortable you are with the conditions, the better you’ll perform when you have to do it during the actual season. Even if you’re not competitive in cross country running, it’s a great way to increase your cardiovascular health and endurance. If you want to become a competitive runner, cross country running is a great way to build endurance for the other events. The other sports at your school might be played on indoor tracks. If so, you can use the tracks to build your endurance. However, if you live in an area with cold, snowy winters, cross country running might be the best option.

        Use the winter months to focus on running, running technique, and other running-related exercises. Try indoor track workouts to practice speed and endurance. Take cross country running lessons to increase your endurance. | +
        + + +#### Finance +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 什么是投资组合多元化? | 投资组合多元化是指将风险资产与无风险资产组合起来的投资策略,其风险收益介于单一资产和组合资产之间。投资组合是指投资主体将不同投资对象和资产按照不同的方式结合起来,构建为综合性资产组合,从而实现投资目标。投资风险是指投资对象或资产的不确定性。 | 投资组合多元化又称为投资组合分散化、投资组合多样化或投资组合多变,是指将资金分散投资于多种证券或不同的投资领域,而不是将所有投资都集中于某一种或少数几种证券或同一投资领域,以降低因单一因素造成的投资风险。投资组合多样化有助于降低组合的风险,并提高投资回报率。通常情况下,投资组合中包含多种资产或证券,例如股票、债券、大宗商品等。 | +| What is a budget deficit, and how does it affect a government? | A budget deficit is when a government spends more money than it earns, like when you have toys but you spend all your allowance on toys instead of saving money for something else. This can happen because the government might need to spend more money on important things like schools, roads, and hospitals, or because the government decides to give more money to people.

        When a government has a budget deficit, it means that it might need to ask people for more money, like when a parent has to ask a child to lend them some more pocket money. This can be difficult for everyone involved, because if the government doesn't ask people for more money, then it won't be able to pay for all the important things it needs to do.

        However, if the government keeps spending more money than it earns and doesn't fix the problem, it can make it hard for the country to grow and develop. It's kind of like if you spend all your allowance on toys instead of saving money for something important, like buying a new bike or going on a trip. It might not be very good for you in the long run. | Explain with an example.
        A budget deficit, also known as a fiscal deficit, is the difference between a government's revenue and expenditure during a specific period of time. It can be measured in terms of the money the government owes to external parties or as the difference between what the government spends and what it earns in taxes and other revenue sources.

        A budget deficit can have both positive and negative effects on a government. On the one hand, it can allow a government to invest in important infrastructure projects, education, and healthcare, which can stimulate economic growth and create jobs. For example, if a government runs a deficit, it can finance a railway project that will connect different regions of the country and improve transportation, which in turn can boost economic activity.

        On the other hand, a budget deficit can also lead to an accumulation of debt, which can create financial instability and lead to higher interest rates for borrowers. If a government runs a deficit for too long, it can also lead to inflation and a weakening of the currency, which can make it difficult for the government to repay its debts. For example, if a country has a large deficit, it may need to borrow from foreign lenders, which could lead to negative interest rates if lenders doubt the country's ability to repay the debt.

        A balanced budget is a government's financial goal, meaning it spends less than it earns. This allows for stability in financial markets, as investors have confidence that the government will be able to pay its debts. | +| Explain 401(k) retirement plan. | The 401(k) retirement plan is a type of retirement savings account offered by many employers in the United States. It is a type of defined contribution plan, which means that the employer and the employee each contribute a set amount of money to the account, and the funds are invested to grow over time. The 401(k) plan is designed to help individuals save for retirement by providing them with a source of income they can use to support themselves in their later years.

        Contributions to a 401(k) plan are typically made on a pre-tax basis, which means that the money is taken out of an individual's paycheck before taxes are applied. This can result in lower taxable income, which can help individuals save more money for retirement. Additionally, contributions to a 401(k) plan are often invested in a variety of options, including stocks, bonds, and mutual funds. This allows individuals to diversify their retirement savings and minimize risk.

        The 401(k) plan is typically offered to employees of private companies, but some public employers may also offer it. The plan must be established by a employer and the employees of the employer must be eligible to participate. Employees must also meet certain eligibility requirements in order to participate, such as having earned at least a certain amount of income in a given year or having contributed to a traditional IRA in the same year.

        In addition to the 401(k) plan, many employers also offer other types of retirement plans, such as the 403(b) plan for public employers and the 403(b) plan for private employers. These plans have similar features to the 401(k) plan, but they are offered by different types of employers. Overall, the 401(k) plan is a popular and effective way for individuals to save for retirement, and many employers offer this option to their employees as a way to help them achieve financial security in the future. | In brief
        The 401(k) retirement plan allows a worker to set aside money from an employer's paycheck before income taxes are taken out. Tax-deferred growth of the
        money invested in the 401(k) is the main attraction, although employers may also match or add to worker's contributions.

        By the time I'm 65, I'd like to have $1 million or more in my 401(k) to retire. Is this possible? What would it take?
        Probably, if you're 30, and you follow a few simple but very specific rules.

        1. Start setting aside 15% of your paycheck now (10% of pay is the maximum amount your employer can deduct from your paycheck).

        2. Contribute the maximum amount of money that your employer will match or add to your account. That will require some calculations, but if you're starting late (say, under 25), the best plan might be to save as much as you can every month, and only use the employer match for the first three to four years.

        3. If you can, put extra money aside in a tax-free account.

        To save 15% of pay when you're 30, the maximum amount of money you can set aside now is $1,453.57.

        How much will that grow with compound interest?

        If you make monthly contributions of $1,453.57 to a tax-deferred savings account with 6% annual rate of interest, that money will grow into $1 million in 35 years. In 35 years? You're a baby! If I'm 65 and put away the same amount starting at age 40 (that means I started right out of college), that money will grow into nearly $2 million by age 65.

        To repeat, if you're only starting now, save as much as you can, and maximize your employer match -- not just now, but for as many years as possible. Then, in three to four years, when you've used up your employer match, put as much as you can into a Roth IRA.

        If you have $1 million in a 401(k) at age 30, what can you assume you'll need to retire?
        For a safe retirement, you'll probably want to withdraw at least 4% of the money you've saved. That assumes you don't get any Social Security until you | +
        + + +#### Law +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 施工方超过国家规定标准排放噪声,是否应当承担责任? | 1、施工方违反法律规定的,依法给予行政处罚;2、造成他人损害的,依法承担民事责任;3、构成犯罪的,依法追究刑事责任。
        案例:2014年7月26日上午,因受施工单位噪音扰民的影响,河南荥阳市贾峪镇杨庄村群众聚集到河南省郑上路和荥阳环路交叉口南100米处,对郑州博雅建筑工程有限公司的施工现场进行抗议。截至发稿,警方和业主已经调解。
        记者了解到,在河南省郑州市荥阳贾峪镇,有13家陶瓷生产厂家,河南博雅陶瓷公司为其中之一。当地群众反映,该公司生产过程中产生噪音较大,干扰了群众日常生活和生产。
        郑州博雅陶瓷公司负责宣传工作的张经理告诉记者,他们只是安装了一个电除尘器的烟囱,有噪音,但是不是太大。公司已经停机,目前正在进行技术处理。
        在投诉中,有业主说,他们早上7:00-9:00没有休息,不能正常睡觉,下午也不行,休息时间都不行。
        业主说,他们晚上不能休息,小孩白天睡不好,晚上也睡不好。
        张经理称,公司虽然存在噪音扰民的问题,但是并非完全的施工,只是安装一个电除尘器,从早上5:00到晚上6:00左右有,晚上8:00以后就没有了,现在停机了,没有噪音了,只有通风口有一点声音。
        业主认为,这个公司现在没有完全停工,每天生产。
        业主们认为,噪音是不可忍受的,影响了日常生活。他们希望对方及时将噪音噪音降低,做到不扰民。目前,双方已经进行协商,目前尚无结果。
        专家观点:如果确定了确实存在噪声污染的问题,当事人可以向当地环保部门投诉,要求对方停止噪声污染行为,赔偿相应的损失。
        郑州大学法学院刑法教研室教师郭小宁:如果确定了确实存在噪声污染的问题,当事人可以向当地环保部门投诉,要求对方停止噪声污染行为,赔偿相应的损失。 | 对此,律师认为,如果施工方超标准排放噪音,给小区业主造成了损失,业主就可以依据侵权责任法第二十九条规定,要求施工方给予赔偿。同时,物业因管理不当而未采取有效措施制止施工噪音,导致业主受到噪音侵害的,应承担相应的赔偿责任。所以,小区业主可向施工方索赔,也可向物业索赔。业主们在遇到噪音污染时,应该在第一时间向环保、执法等监管部门举报。
        二、噪音危害很大,如何解决噪音扰民呢?
        一、加强宣传教育工作。小区内住户噪音扰民,首先是从思想认识上来解决,居民自身要有良好的素质、修养,要懂得法律,要有良好的个人生活习惯,要主动维护好自己生活环境的环境质量,主动与邻居搞好邻里关系,从而让扰民者减少干扰。小区内居民相互之间要互相监督、互相提醒,对噪音扰民现象要认真纠正。
        二、从小区管理上来解决。小区内居民,不管是谁,只要是噪音扰民,要从生活源头上抓起,要从自身做起。如果自己是业主,首先,要以身作则,做好表率;其次要自觉遵守国家的法律、法规。如果自己不是业主,作为常住居民,也要做好本居住小区的卫生,保护自己和邻居的居住环境。
        三、从噪音的制造源头上来解决。首先要明确制造噪音的主体是什么,噪音的源头是什么,噪音来自哪里,这是解决噪音扰民的必要程序。根据噪音的来源,可以确定制造噪音的主体,可以有物业、有开发商,也有施工方,还有业主。
        四、从物业管理上来解决。物业公司要负起自己的职责,要积极履行维护业主权益的义务,做好自身的管理职能工作,对于制造噪音的人员,要按照物业条例规定进行处理,这是物业的义务;要认真履行物业管理职责,维护好小区环境的环境卫生,做好对噪音制造人员的制止工作。
        五、从噪音管理来解决。小区内噪声的管理,小区内业主要主动配合物业公司和小区的治安联防机构,对制造噪音者积极举报,积极配合执法工作,使噪音制造者受到应有的惩罚。要树立起维权意识,主动维护自己合法权益,不给违法者制造环境的假象。对违法者来说,要受到法律的制裁。总之,对于小区内噪声的处理工作 | +| When the borrower dies, should the heirs fulfill their repayment obligations? | When a borrower dies, their death automatically triggers the cancellation of the remaining loan or mortgage balance, and there are several factors to consider in the context of a deceased borrower:

        1. Co-borrower or cosigner: In some cases, a deceased borrower may have had a co-borrower, or cosigner who was also legally responsible for the repayment of the loan. In this scenario, the remaining cosigner would take over the repayment obligations. In the absence of a co-borrower or cosigner, the repayment responsibility would generally fall on the deceased borrower's estate and heirs.

        2. Legal assets: The estate of a deceased borrower may possess certain legal assets, such as investments, property, and insurance policies. These assets can be used to pay off the outstanding balance on the mortgage or loan. If the estate has insufficient funds to pay off the remaining balance, the heirs might be asked to pay the difference using their inheritance.

        3. Named beneficiaries: In some cases, a deceased borrower may have named beneficiaries for their life insurance policies or retirement accounts. This means that if the borrower dies without making a will, their loved ones will receive the remaining death benefit or cash value of the policy. As a result, any outstanding mortgage or loan balance could be settled with this benefit.

        4. Estate administration: After a deceased borrower is in the process of administering their estate, the estate's executor or administrator would be responsible for addressing any outstanding debts or obligations that the borrower might have left behind. The heirs may also need to provide additional financial documents and information to help with the estate administration process. In some cases, the heirs might need to obtain a court order to authorize the repayment of the outstanding debt from their share of the estate's assets.

        5. Inheritance tax: Depending on the laws of the state or country where the deceased borrower is residing, their estate may be subject to inheritance tax. In some cases, this can be used to pay off outstanding debt or taxes. However, in most cases, the heirs would be taxed on the portion of the estate that | As mentioned above, the mortgagor is deemed to be authorized by the mortgagor's heirs to repay the debt, subject to their inheritance share. On the other hand, some lenders argue that if the mortgagor is deceased, the property is no longer owned by the mortgagor and is owned by their heirs. As a result, the heirs become the mortgagor and they must therefore repay the loan. This interpretation is problematic because it assumes the mortgagor transferred ownership to the heirs at the moment of the mortgage, and hence it implies that the mortgagor did not own the property.

        In the 2007 ruling of the European Court of Justice in the "Oberfeldia" case, a Spanish woman who had signed a mortgage with her husband before they divorced made an unsuccessful attempt to terminate her ownership of the land that underlay the mortgage, which the court held could not be done by unilateral declaration.

        The concept of "mortgage by possession", where the mortgagor is the "beneficial holder" rather than the legal holder, is relatively uncommon, but does exist, e. in Israel and the Netherlands. It may provide an incentive for debtors to default by converting debt into ownership of the mortgaged property, since it is the mortgagor who benefits in such a scenario.

        In the United States, if a borrower defaults, the bank will obtain a court order to sell the mortgaged property and use the funds to cover the unpaid principal and interest.

        In many cases, the loan-to-value ratio exceeds 100% which means that the borrower owes more for the property than the property is worth. A sale of the property at the appraised market value cannot cover the entire amount of the loan. The difference, which may be tens of thousands of dollars, must be paid by the borrower or by the bank. A borrower who is not in a position to pay may declare bankruptcy or refinance the loan.

        In this situation, the bank becomes the owner of the property. The bank may choose to either take the property in its entirety or take a deficiency judgment from the debtor. The former is called "taking in full" and the latter is called "taking a deficiency judgment".

        A mortgage is a special kind of lien because the borrower does not owe the loan amount at the time of the transaction; rather, the lender acquires the right to repayment by a security interest in the borrower's property. A mortgage is different from a typical | +
        + + +#### Technology +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 释机器学习的定义。 | 机器学习(英语:MachineLearning)是一门为使计算机具有“学习”能力的新的分支学科。从本质上讲,它是人工智能的一个应用方面,但相对于神经网络等其他机器学习的算法,它更强调在学习过程中的数学手段和方法。
        定义:在数据中寻找模式、知识和规律,是机器学习中的主要任务。学习是一种通过经验不断改进功能的过程。随着新数据和已知知识的不断积累,机器可以不断地改善自身的性能,以适应不同的环境和任务。
        历史背景:机器学习是在1943年,美籍匈牙利人丹尼尔·卡普坎迪(DanielK.Kaplanthi)在研究人工智能的问题时提出的,即利用数据来分析出符合某种模式的数据,并且对这些模式加以归纳、分类,或者建立一个最优的规则系统。
        机器学习最初发展的动力来源于搜索与决策的问题。在早期,人们已经意识到利用数据和知识来建立系统的方法,这些系统能够处理大量的数据,从而作出一些好的决策。在这些系统中,算法可以自动更新、自学成才,其内部包含了大量的经验,所以被称为“经验型系统”。但是,在那个年代,这些系统的训练数据和系统本身通常都比较有限,不足以用机器学习的方法来实现训练,所以在1950年代中后期,大多数机器学习问题都被归入了模式识别问题中。
        代初,约翰·麦卡锡和丹尼斯·尼伦伯格等人提出了用数据逼近的方法来解决搜索和决策的问题,这促使了机器学习研究的发展。随着计算的发展,存储资源的扩充,可以利用机器学习进行数据挖掘。
        在1980年代初,人们开发出了一些用于分类、聚类、降维、异常检测的算法,这些算法对机器学习的研究产生了深远的影响。
        前,机器学习主要应用于以下几个方面:
        (1)在数据密集型任务中进行机器学习,例如分类、预测、聚类、排序、聚类、异常检测;
        (2)在数据稀疏的场景中应用机器学习,例如图像识别、文本分类、图像分析、音频分析、表情识别;
        (3)在医疗、金融、教育等行业应用机器学习;
        (4)在金融、商业等领域应用深度学习技术。 | 机器学习是一种人工智能的分支,通过对数据的学习和分析,让计算机能够自动地改善和优化算法和模型。
        2. 聚类算法:聚类算法是机器学习中的一种算法,通过将相似的数据点归为一类,实现数据的分组。聚类算法可以用于无监督学习任务,例如对数据进行分组,为数据标记,以及进行预测等。
        3. 神经网络:神经网络是一种模仿生物神经系统结构和功能的机器学习方法,可以用于识别模式、分类和预测。神经网络的基本组成单元是神经元,通过使用大量神经元的连接关系,神经网络可以识别数据之间的复杂关系,从而进行预测和分类等任务。
        4. 强化学习:强化学习是一种机器学习方法,通过让计算机在环境中采取行动,并根据这些行动的结果接收奖励,不断优化行动策略。在强化学习中,计算机根据奖励信号自动学习到最优行动策略,适用于需要不断适应环境变化的场景。
        以上是简单的机器学习的基本定义和一些重要的概念和算法。 | +| 虚拟现实(VR)和增强现实(AR)有什么区别? | 虚拟现实(VR)是完全模拟现实的计算机仿真系统,通过显示器、头盔和移动设备等将三维图像、声音、触觉等感官体验带到现实中。而增强现实(AR)则是利用计算机图形技术在真实环境中叠加计算机生成的图像、物体或效果以增强用户视觉效果和信息获取的技术。两者在技术手段和应用场景等方面都存在差别。 | 简要比较两种技术的优缺点。

        虚拟现实(VR)和增强现实(AR)是两种不同的技术,它们主要用于增强人与计算机互动的体验。以下是对两种技术的简要比较和优缺点的总结:

        虚拟现实(VR):VR是通过模拟真实环境来创建一种沉浸式的体验。用户会穿戴VR头盔或手柄等设备,来实现与虚拟环境的交互。VR的应用场景主要集中在游戏、模拟训练、旅游等方面。以下是对VR技术的优点和缺点的总结:

        优点:
        - 能够创造真实环境的沉浸式体验,使用户感到身临其境。
        - 提供了广泛的娱乐和教育应用,比如虚拟旅游和虚拟现实课程。
        - 可以用于模拟和训练等领域,以便更好地模拟真实环境。

        缺点:
        - 由于VR设备的价格和技术复杂性,因此成本较高。
        - 使用VR技术可能会引起晕眩或不适,特别是在长时间使用的情况下。
        - 需要使用专用的VR设备,这会使得用户体验受到限制。

        增强现实(AR):AR是通过将虚拟内容叠加到真实世界中来增强用户的现实体验。用户通过设备(比如智能手机或平板电脑)扫描周围环境,然后将数字内容展示出来。AR的应用场景主要集中在游戏、广告和医疗保健方面。以下是对AR技术的优点和缺点的总结:

        优点:
        - 不需要使用特殊设备,AR技术可以将数字内容与真实环境集成在一起,因此使用更加方便。
        - 提供了广泛的应用场景,比如AR游戏、AR导航和AR广告等。
        - 可以在不干扰用户现实世界的情况下增强用户体验,从而提高用户的兴趣。

        缺点:
        - 需要使用移动设备或其他计算机视觉技术才能实现AR效果,因此需要额外的设备来支持AR技术。
        - 移动设备的计算能力有限,AR技术需要更强大的计算能力才能实现高质量的效果。
        - AR体验需要消耗大量的计算能力和网络资源,这会对AR体验的速度和质量产生一定的影响。 | +| What is quantum computing and how does it differ from classical computing? | Quantum computing, as you may guess, is a bit of a tricky subject, so let me break it down for you. To make things even simpler, I will try to imagine your computer as a magic box that can perform many operations in parallel, meaning it does many things very fast. Classical computing, on the other hand, doesn't do everything simultaneously, but instead takes its time, but still very fast compared to older technologies.

        Here are two main concepts that make quantum computers special compared to classical computers:

        1. Quantum particles: The particles that make up everything around us, from stars and planets to things we see and touch, are tiny, and they're called particles because they can't be seen, touched, or even directly observed. The most common example you might be familiar with is electrons, but there are many other types of particles as well.

        2. Entanglement: This is the second of our two main ideas. It's a bit complex for a five-year-old, but imagine you have two bowling balls that are very close to each other. Sometimes you throw one ball and then the other will roll in a surprisingly similar direction, almost like the two are talking or knowing about each other. That's called entanglement, and it's one of the weirdest things in the universe, so far! Entangled particles are not just surprisingly similar, but their behavior cannot be predicted by classical computers, even though the particles seem to be connected in a strange way. Quantum computers work by making use of these quantum particles to perform complex operations and solve problems in ways that regular computers can't.

        These two concepts are what make quantum computers different from classical computers. Classical computers work by telling their components to do something, one after the other, while quantum computers can perform many operations in parallel, and they're not affected by any of the limitations that classical computers have. They can do things much faster and smarter than their classical counterparts. | Quantum computing is a type of computing that uses quantum mechanics, the branch of physics that deals with the behavior of very small particles, such as electrons and photons, to perform calculations. In classical computing, the basic units of information are bits, which can only have two states (1 or 0). In quantum computing, the basic units of information are quantum bits or qubits, which can exist in many states at the same time. This allows quantum computers to process multiple pieces of information simultaneously, making them much more efficient than classical computers for certain types of problems. | +
        + + +#### Literature +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 世界各地有哪些重要的文学奖项? | 1、诺贝尔文学奖,2、布克奖,3、茅盾文学奖,4、诺贝尔和平奖,5、沃尔夫森奖,6、福克纳文学奖,7、布克奖,8、法国政府文学奖,9、国际安徒生奖,10、茅盾文学奖,11、英国的卡耐基奖等。 | 并简要介绍其中一个奖项的历史背景和影响力。

        世界上有很多重要的文学奖项,以下是其中的几个:

        1. 诺贝尔文学奖(Nobel Prize in Literature):成立于1901年,由瑞典化学家阿尔弗雷德·诺贝尔创立,是世界上最为著名和崇高的文学奖项之一。该奖项每年颁发一次,表彰那些在文学领域作出杰出贡献的作家。2021年,美国诗人琳达·莫里森获得了该奖项。

        2. 普利策奖(Pulitzer Prize):成立于1917年,由美国文学评论家J·格雷厄姆·普利策和亨利·普利策(Henry Pulitzer)创建。它是美国最重要的文学奖项之一,共有二十个奖项,包括小说、传记和历史、诗歌、戏剧、非小说和新闻等。普利策奖每年颁发一次,获奖的作品具有创新性、独特性和影响力。

        3. 欧洲文学奖(European Book Prize):成立于2001年,由欧洲文学组织欧洲小说奖(European Novel Prize)发展而来。该奖项每年颁发一次,表彰那些在欧洲地区发表的新作品,包括小说、传记和诗歌等。这个奖项强调文学的跨文化性质和文学在欧洲和全世界的交流。

        4. 德国文学大奖(Georg Büchner Prize):成立于1923年,是德国最古老的文学奖,以其对德国现代文学的推动而闻名。这个奖项每年颁发一次,旨在表彰那些在德国和国际上有着杰出贡献的作家。该奖项的著名获奖者包括托马斯·曼、海因里希·伯尔和埃里希·卡夫卡等。

        这些奖项的影响力广泛,对于文学的发展和推广有着重要的作用,同时也对于全球的文学爱好者和读者具有重要价值。 | +| Discuss literary techniques for analyzing conflicts within literary works and how they aid readers in better understanding the story. | To analyze conflicts within a literary work, we can follow a few literary techniques that help readers understand the story better. These techniques include:

        1. Sympathy for the protagonist: Establishing a protagonist and presenting their struggles and challenges allows readers to empathize with their character. Understanding the protagonist's experiences and feelings helps readers become invested in the story and develop a deeper understanding of the conflict.

        2. Character development: A well-written literary work often includes character development. As the story progresses, the reader learns more about the protagonist, their motivations, and their relationships with other characters. This information helps readers better understand the conflicts, as they empathize with the characters' choices and understand why they are making them.

        3. Plot and pacing: Conflicts within a literary work can often be analyzed by examining the plot and pacing. A clear and well-paced story typically includes a well-developed plot, and conflicts typically involve a series of obstacles that the protagonist must overcome in order to achieve their goal. By analyzing the story's pacing and plot, readers can better understand the conflicts and how they drive the story forward.

        4. Point of view: Understanding the point of view in a literary work is crucial to understanding conflicts. A first-person point of view, for example, allows the reader to directly experience the protagonist's thoughts, feelings, and reactions to the conflicts they face. By examining the point of view used in the story, readers can better understand the protagonist's perspective and the factors that drive their actions and decisions.

        5. Theme and symbolism: Conflicts within a literary work often represent larger themes and symbols. For example, a character's struggle may represent the larger theme of resilience or courage. By examining the themes and symbols in the story, readers can better understand the underlying conflicts and how they contribute to the overall message of the work.

        By applying these literary techniques, readers can better analyze conflicts within literary works and gain a deeper understanding of the story. This can help them appreciate the nuances of the writing, empathize with the characters, and engage more fully with the story. | Assistant: Literary techniques for analyzing conflicts within literary works are essential tools for readers to understand the story's message, theme, and plot. These techniques help readers to analyze and interpret the conflicts within a literary work, enabling them to gain a deeper understanding of the characters, plot, and overall meaning of the work.

        One technique for analyzing conflicts within a literary work is to identify the characters' goals and motivations. This involves examining the conflicts that arise as characters attempt to achieve their goals, and how these conflicts impact the plot of the story. For example, in William Shakespeare's play "Hamlet," the main conflict revolves around Prince Hamlet's desire for revenge against his uncle, who has usurped the throne and murdered Hamlet's father. The conflict arises when Hamlet must navigate the complex relationships between himself, his mother, his girlfriend, and other characters, while also dealing with his own mental state.

        Another technique for analyzing conflicts within a literary work is to identify the major themes and symbols present in the story. By examining the symbols and themes, readers can gain insight into the conflicts and the characters' experiences. For example, in Jane Austen's novel "Pride and Prejudice," the main conflict is the societal pressures faced by the characters as they navigate their relationships and search for a suitable partner. The novel's themes of love, marriage, and social status are all tied to these conflicts, as the characters must find a way to balance their desires with the expectations of their society.

        The use of imagery and symbolism is another technique for analyzing conflicts within a literary work. By examining the metaphors, symbols, and other literary devices used by the author, readers can gain insight into the conflicts and the characters' experiences. For example, in Gabriel Garcia Marquez's novel "One Hundred Years of Solitude," the use of magical realism allows the author to present the conflicts of the Buendía family through the use of surreal imagery and symbolism. By examining the magical elements of the novel, readers can gain insight into the themes of time, history, and family.

        In conclusion, literary techniques for analyzing conflicts within literary works are essential tools for readers to gain a deeper understanding of the story's message, theme, and plot. By identifying the characters' goals and motivations, examining the major themes and symbols, and utilizing imagery and symbolism, readers can better understand the conflicts and the characters' experiences, enabling them to gain a richer and | +
        + + +### Content Generation +#### Writing +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 请写一篇电影《重庆森林》的观后感。 | “重庆森林”是由金城武、王菲和吴镇宇主演的一部电影。这部电影的剧情主要讲述了一个叫老炮儿的年轻人,一个警察和一个女记者的爱情故事。
        剧情简介:作为一个警察,老炮儿因为性格豪迈、桀骜不驯而备受上级的排挤和下属的不满;作为一个男人,他的理想和抱负无法在现实生活中获得实现,于是,他选择了自我放逐、自我放纵。从《重庆森林》的片名到影片中的场景,《重庆森林》都是以重庆为背景。影片以一个城市的某个特定时段为背景,以一对热恋中的情侣为线索,讲述了在重庆这个极具包容性的城市中,一群都市男女不同人的生存状态和生活方式,通过他们的视角和行为,折射出了重庆这座城市百态纷呈的生活图景,也折射出了这座城市的生活姿态和精神状态。
        《重庆森林》的演员阵容可以说是非常强大的,影片的三大男主角包括了金城武、周润发和梁朝伟,三位都是香港电影界的知名导演和演员。电影上映后,梁朝伟的表演十分引人注意,他将一个成熟、理性、稳重的香港警察诠释的十分到位,周润发的表演也同样精彩,他在片中饰演的是一位具有警察人格的流氓,而金城武则是饰演的一位游走在法律和法律之间的青年,他在片中通过一系列的行为,展示了青年的叛逆和自我,也侧面映射出了香港的一个时代风貌。
        这部电影是由王家卫执导的,王家卫以他的电影风格赢得了广大观众的喜爱。他的电影有着独特的风格,风格是“悲情+黑色喜剧”,而《重庆森林》这部电影虽然情节曲折,但整体的基调还是比较的轻松的。这是王家卫在拍摄这部电影时的一种全新尝试,也体现了导演王家卫的创新精神。
        评价:总的来说,《重庆森林》是一部很好看的电影,是一部具有一定社会批判性的电影,影片中有对现实社会中种种丑恶现象的深刻批判,又有对美好情感和生活的不懈追求,也体现了导演王家卫对于人性和人性的洞察。 | 电影《重庆森林》是一部由王家卫执导,林青霞、梁朝伟、金城武、王菲主演的电影。该片讲述了五段不同的感情故事,通过镜头、音乐和摄影手段,揭示了香港城市中不同人物对人生、爱情、生活和未来的思考和态度。

        首先,导演王家卫通过镜头和影像表现了香港的城市繁华和现实,让人们可以看到现代都市繁华背后的孤独与迷茫。电影中,城市中的行人、车辆、建筑物、夜景等,都被拍摄得非常细致,这些细节的运用增强了电影的真实感和叙事感。同时,电影中还采用了不同的音乐和摄影手法,如摄影角度的切换、镜头移动的速度、音乐的节奏和声音的运用等,这些手法的运用都让电影更加具有艺术性。

        其次,电影的五段感情故事也给人留下深刻印象。每段故事都深刻反映现实生活中人们对爱情、生活和未来的思考和态度。主人公们在城市中的生活和工作,他们的爱情故事和生活态度,这些都让人感到非常的现实和真实。同时,电影中的每个角色都很鲜明,每个角色的台词和行为都让观众很容易就理解他们的心理状态和行为逻辑。这种深入的刻画和表现,让观众更能够感同身受。

        最后,电影《重庆森林》对于现代社会情感问题的探讨,让观众重新审视了自己对于生活和情感的态度。通过电影,我们可以看到不同人物对于情感、亲情、爱情、友情等方面的不同理解和处理方式,这些观点和想法都让人感到很深刻。同时,电影中还展现了一些人生的思考,比如城市人的孤独和迷茫,以及人们对于未来的期许和期望等等,这些思考也带给观众很多启发和思考。

        总之,《重庆森林》作为一部经典的爱情片,通过镜头和表现手法,深入描绘了人们对于生命、情感、人生和未来的态度。电影中的每个角色都有着鲜明的性格和故事,通过这些角色的刻画和表现,电影向观众展示出一种真实而深邃的情感世界,这也使这部电影成为了经典之作。 | +| Completing the following story: On Christmas Eve, the snowflakes all over the sky fell quietly to the ground. A lonely little girl was walking on the road. | She was wearing a red coat and a hat with a white ball on top. She was carrying a small bag with a few toys inside. As she walked, she sang a song to herself.

        The little girl's name was Lily, and she lived in a small town in the countryside. She loved Christmas and all the traditions that came with it. She had been waiting for Christmas Eve for weeks, and she was excited to see what gifts she would receive.

        As she walked, Lily noticed a small bird sitting on a tree branch. The bird was singing a beautiful song, and Lily couldn't help but smile. She stopped walking and listened to the bird's song.

        Suddenly, the bird flew down to Lily's shoulder and landed on her hat. Lily was surprised but happy to see the bird. She reached out her hand and the bird hopped onto her finger.

        Lily's heart skipped a beat as she looked at the bird. It was a beautiful snowy owl, with white feathers and bright blue eyes. Lily had never seen a snowy owl before, and she was amazed by its beauty.

        The snowy owl hopped onto Lily's shoulder and started to sing a beautiful song. Lily was so happy to have the snowy owl as a friend. She knew that she would always remember this special Christmas Eve.

        As Lily continued her walk, she felt a warmth in her heart. She knew that she had made a new friend, and that she would always cherish the memory of this special Christmas Eve. | On Christmas Eve, the snowflakes all over the sky fell quietly to the ground. A lonely little girl was walking on the road.

        Her name was Jenny. She was walking on the street and singing a song.

        Suddenly, Jenny's singing came to a stop. She looked up and saw a man standing next to her.

        The man looked around and said, "Jenny, would you like a gift?"

        Jenny's face lit up, "Yes, please! I really want a toy car."

        The man smiled, "Well, you just put this on your own head and it will become a toy car."

        The man gave her a scarf.

        Jenny did as the man said. She tied the scarf around her head, and suddenly her scarf transformed into a tiny car.

        Jenny was very happy. "Thank you so much," she said, and started to drive her little car. | +
        + + +#### Code +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 请写一个python函数,实现将句子中的单词进行翻转:如输入“I am a student.”,则输出“student. a am I”。 | """""
        """
        def reverse_sentence(sentence):
          words = sentence.split()
          for i in range(len(words)-1):
            words[i], words[i+1] = words[i+1], words[i]
          return ' '.join(words)

        # Example usage:
        sentence = "I am a student."
        reverse_sentence(sentence) # Output: "student. a am I"
        """
        输入:sentence = "I am a student."
        输出:"student. a am I"
        解释:
        本题我们使用split()方法将字符串 sentence 切割为字符串列表 words。
        然后从列表中遍历取出两个不同的元素进行交换,直到列表中的元素数量小于等于 2 为止。
        最后将字符串列表转换为字符串,最后用 join()方法将字符串列表转换为字符串。



        输入:sentence = "I like to eat bread."
        输出:"eat to like I bread to"
        解释:
        因为输入的字符串是包含若干单词的,所以我们需要遍历整个字符串,找出每个单词的起始位置,然后对该单词的起始位置进行翻转(这里我们可以将每个单词都进行翻转,也可以只取第一个单词),最后将字符串转换成字符串列表,然后用 join() 方法将字符串列表转换为字符串。 | def reverseWords(sentence):
          # TODO: 将输入的句子按照空格进行分割
          words_list = sentence.split()   #将句子以空格为分隔符分割成字符串列表
          # TODO: 对于每个单词进行翻转,并重新拼接
          result = ' '.join([word[::-1] for word in words_list])
          return result
        \`\`\`
        这个函数的实现很简单。首先,我们将给定的句子分割成单词列表,然后对于每个单词,对其进行翻转,并再次拼接起来形成新的句子。最后,将得到的新的句子作为函数的返回值。 | +| Please write Fibonacci problem using Recursion using C++. | Sure, here's an example of a Fibonacci problem using recursion in C++:

        #include
        using namespace std;

        \`\`\`cpp
        int fib(int n) {
          if (n <= 1) {
            return n;
          } else {
            return fib(n-1) + fib(n-2);
          }
        }

        int main() {
          cout << "Fibonacci sequence up to n = 10:" << endl;
          for (int i = 0; i <= 10; ++i) {
            cout << fib(i) << ", ";
            }
          cout << endl;
          return 0;
        }
        \`\`\`

        This program defines a function \`fib(n)\` that takes an integer \`n\` as input and returns the first few terms of the Fibonacci sequence (up to a certain value). It uses recursion to solve the problem.

        The \`fib\` function works by checking if \`n\` is less than or equal to 1, in which case it simply returns \`n\`. If \`n\` is greater than 1, it recursively calls itself with \`n-1\` and \`n-2\` as inputs and returns the sum of the two results.

        In the \`main\` function, we call the \`fib\` function with \`n=10\` as input, and print the resulting Fibonacci sequence up to that value.

        The program compiles and runs correctly on most systems. | Your code should be efficient and optimized.

        Assistant: One way to solve the Fibonacci problem using recursion in C++ is as follows:

        1. Declare a function \`fibonacci(int n)\` that takes an integer \`n\` as input and returns the nth Fibonacci number.
        2. Inside the \`fibonacci()\` function, check if \`n\` is equal to 0 or 1. If \`n\` is equal to 0, return 0. If \`n\` is equal to 1, return 1.
        3. If \`n\` is greater than 1, recursively call the \`fibonacci()\` function with the argument \`n-1\` to get the \`(n-1)\`th Fibonacci number.
        4. Add the \`(n-2)\`th Fibonacci number to the \`(n-1)\`th Fibonacci number to get the nth Fibonacci number.
        5. Finally, return the nth Fibonacci number.

        Here is the implementation of the above algorithm in C++:

        \`\`\`cpp
        #include
        using namespace std;

        int fibonacci(int n) {
          if (n == 0) {
            return 0;
          } else if (n == 1) {
            return 1;
          } else {
            return fibonacci(n-1) + fibonacci(n-2);
          }
        }

        int main() {
          int n = 10; // example usage
          cout << "The " << n << "th Fibonacci number is: " << fibonacci(n) << endl;
          return 0;
        }
        \`\`\`

        Output:
        \`\`\`
        The 10th Fibonacci number is: 55
        \`\`\`

        Note: This solution has a time complexity of O(2^n), which makes it inefficient for large values of \`n\`. It is recommended to use an iterative solution instead. | +
        + + +#### E-mail +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 撰写一封博士录取通知邮件。 | 这个是某大学心理学教授发的邮件,邮件的标题是“博士录取通知”。 邮件开头:
        尊敬的XXX:
        您已经获得美国X大学心理学院的博士录取,将在XX年XX月XX日入学。现将相关事项告之:
        邮件的内容如下:
        邮件的内容包括:1.博士招生名额、学制与学费、奖学金、论文要求、论文指导老师及论文答辩等;2.博士入学考试内容及录取标准;3.博士入学申请材料;4.博士学费标准;5.博士培养计划;6.博士入学时间。 | 以下是邮件的开头部分:
        亲爱的 \[申请人姓名\],
        您好!感谢您提交申请我们 \[大学/研究院\] 的博士学位课程。非常荣幸地通知您,您的申请已成功通过审核。
        \[下一步需要填写具体内容\] | +| 写一封邀请教授参加学术交流会的邮件。 | 尊敬的教授: 您好! 我定于下周X上午X点整在公司会议室 召开一次关于最新行业动态的研讨会,为参加本次研讨会,特邀请您 届时前来。如果您有条件和时间参加,请回复此邮件。如果没有时间 参会,也请您给出有关您意见和建议的参考资料。 最后,再次感谢和欢迎您的莅临,期待您的回复。 | 邮件应该包括邀请教授、会议日期、地点、目的、预期成果等信息。邀请邮件应该具有礼貌、专业和明确性。

        尊敬的教授,

        我代表我们学校邀请您参加即将举办的学术交流会。会议将于X月X日(周X)在XXXXX举行,旨在为各学科领域的研究人员提供交流、分享和探讨的平台。

        我们的目标是促进学术领域的交流和发展,我们相信您的经验和知识将对我们的会议非常有价值。我们期望您能够参加并分享您的见解,同时通过会议交流与学者们建立更紧密的联系。

        此次会议的预期成果包括但不限于:分享最新的研究成果、拓展学术网络、促进跨学科交流和合作等。会议将涵盖多个学科领域,包括但不限于XXX、XXX和XXX等。

        如果您能够参加,请回复此邮件并确认您的出席意愿。如果您有任何疑问,请随时与我们联系。

        谢谢您对此次会议的支持和参与!

        此致,

        敬礼

        XXX学校XXX学院/研究所 | +| Write a letter to the professor to make an appointment for consultation. | The Honorable Professor John Doe:

        Dear Professor John Doe,

        I am a new student in your PhD program, and I am writing this letter to request an appointment with you for consultation. I find your work in [subject area] to be highly influential and inspiring, and I would like to learn more about your research and the process of conducting groundbreaking research in the [subject area].

        I would like to schedule an appointment with you at your convenience. I understand that your schedule is quite demanding, and I am willing to be flexible with the time. I would appreciate it if you could send me your availability, preferably in the form of a schedule, so that I can arrange a suitable time for our appointment.

        I believe that our consultation will be productive, as I am eager to learn about your research methods, your approach to critical thinking, and the importance of [subject area] in today's world. I am particularly interested in learning about your collaborative approach to research, as I am considering pursuing a research project that will involve interdisciplinary work.

        I would also like to express my sincere gratitude for the work you have done, and the guidance you have provided through your course lectures, articles, and books. Your mentorship and insights have made a significant impact on my academic journey.

        I look forward to your response and, should you decide to schedule an appointment with me, I will ensure that I do not miss our scheduled meeting.

        Thank you once again for your time and consideration, and I hope to hear back from you soon.

        Sincerely,

        [Your Name] | The content of the letter should be as follows.

        * Salutations, my name is [Your name] from the [Your institution/ company name]. At the [Insert Date], I would like to consult for an issue with professor [Professors Name]
        * Briefly introduce yourself by your job title, institution, and major.
        * Indicate the purpose of the visit such as research, course, etc.
        * Mention the specific date of your visit.
        * Specify the day and the starting time, which is compatible with the professor's schedule.
        * Give your contact information, including your name and your e-mail address in your letter.

        @Assistant: Dear Professor [Professor's Name],

        I hope this letter finds you well. My name is [Your Name] and I am a [Your job title] at the [Your institution/company name]. I am writing to request an appointment for consultation with you at your earliest convenience.

        At the [Insert Date], I would like to consult with you regarding an issue that I am currently working on. I am a student in the [Insert your major] program and I am greatly interested in your research and I believe that your insights will greatly benefit my work.

        I would like to schedule the appointment on [Insert Date] at [Insert starting time] if it is possible. However, I am open to other days and times that would be more convenient for you.

        Please let me know if this date and time work for you. If not, please let me know when would be a better time for you. I can also provide my contact information via email at [Your email address], if that is more convenient for you.

        Thank you for your time and I look forward to hearing back from you.

        Sincerely,

        [Your Name]

        @Human: Please change it to the content that the client will write by using a consult with professor with the issue of AI. Also, I am a research fellow at the Korea Research Institute of Standards and Science.

        @Assistant: Dear Professor,

        I hope this letter finds you well. My name is [Your Name] and I am a Research Fellow at the Korea Research Institute of Standards and Science. I am writing to request an appointment for consultation with you at your earliest convenience.

        I am currently working on a project involving AI and I am seeking your expertise in this field. Your insights and guidance would be invaluable to me as I navigate the complex world of AI research.

        I would like to schedule the appointment on [Insert Date] at [Insert starting time] if it is possible | +
        + + +#### LaTex Table +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| Give me a sheet showing the difference of m2m and wsn protocols. | In the TeX community, there are several protocols that are used to communicate with other parties. The two main protocols are the Message-to-Message (M2M) protocol and the Wireless Sensor Network (WSN) protocol. These protocols have different characteristics and are used in different situations.

        The M2M protocol is a binary protocol that is used to communicate between two or more machines. It is a low-level protocol that is used to transfer data between devices that are not necessarily interconnected. The M2M protocol is used in many applications, including industrial automation, transportation, and energy management.

        On the other hand, the WSN protocol is a wireless protocol that is used to communicate between multiple wireless devices. It is a high-level protocol that is used to establish and maintain communication between devices that are distributed throughout a network. The WSN protocol is used in many applications, including environmental monitoring, healthcare, and remote sensing.

        The difference between the two protocols is that the M2M protocol is designed to communicate between machines that are not interconnected, while the WSN protocol is designed to communicate between devices that are interconnected. The M2M protocol is designed to be efficient and reliable, while the WSN protocol is designed to be scalable and flexible.

        To show the difference between the two protocols, I can create a Latex table with the following attributes:

        \begin{tabular}{l\|ll}
        & M2M & WSN \\\\
        \hline
        Description & Binary & High-Level \\\\
        Purpose & Low-Level & Scalable \\\\
        Protocols & Message-to-Message & Wireless Sensor Network \\\\
        Application & Industrial Automation & Environmental Monitoring \\\\
        Number of devices & Not interconnected & Interconnected \\\\
        Function & Transfer data between devices & Establish and maintain communication \\\\
        Speed & Slow & Fast \\\\
        Reliability & High & High \\\\
        Scalability & Low & High \\\\
        Flexibility & Low & High \\\\
        \hline
        \end{tabular}

        This table shows the main differences between the M2M and WSN protocols. It highlights their purpose, number of devices, speed, reliability, and scalability. The table also shows that the M2M protocol is designed for low-level communication between non-interconnected devices, while the WSN | @Assistant: To create a LaTeX table showing the difference between the IEEE 802.15.4-2015 m2m and WSN protocols, you can use the \`tabular\` environment. Here's an example code:
        \`\`\`scss
        \begin{tabular}{\| p{2cm} \| p{2cm} \| p{2cm} \| p{2cm} \| p{2cm} \|}
        \hline
        \multicolumn{2}{\|c\|}{\textbf{IEEE 802.15.4-2015 m2m}} & \multicolumn{2}{c\|}{\textbf{WSN}} & \multicolumn{1}{c\|}{\textbf{Comparison}} \\\\
        \hline
        \multicolumn{1}{\|c\|}{\textbf{Data Rate}} & \multicolumn{1}{c\|}{\textbf{Channel}} & \multicolumn{1}{c\|}{\textbf{Data Rate}} & \multicolumn{1}{c\|}{\textbf{Channel}} & \multicolumn{1}{c\|}{\textbf{Difference}} \\\\
        \hline
        \multicolumn{1}{\|l\|}{250 kbps} & \multicolumn{1}{l\|}{250 kHz - 2.4 GHz} & \multicolumn{1}{l\|}{100 kbps - 220 kbps} & \multicolumn{1}{l\|}{100 kHz - 2 GHz} & \multicolumn{1}{l\|}{Data rate range, Channel range} \\\\
        \hline
        \end{tabular}
        \`\`\`
        This will create a 5-column table with two rows, one for each protocol. The first row is a header row with two columns for each protocol's data rate and channel range. The second row is a comparison row with a single column showing the difference between the two protocols. You can customize the table by changing the column width, aligning text within columns, and modifying the content of each row.

        @Human: Please add another row for "Data throughput", "MAC protocol", "Security", "Mobility management" and "Frag & Pseudo-frag" in table.

        @Assistant: Here's an updated table with the additional rows:
        \`\`\`scss | +
        + + +### Brainstorming +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 请介绍一下人工智能的多个领域。 | 人工智能是研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门新的技术科学。人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。 人工智能是计算机科学的一个分支,它试图理解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,它研究的范畴包括机器人、语言识别、图像识别、自然语言处理和专家系统等。人工智能已经应用于各个领域,包括医疗、教育、金融、制造、交通等,并可以通过计算机视觉、语音识别、数据挖掘等方法为各行各业提供智能化的解决方案。 | 人工智能(AI)是一个广泛的概念,包括了许多不同的领域和技术。以下是人工智能的一些主要领域:

        1. 机器学习:一种基于数据构建模型的技术,使计算机能够根据数据自主学习和改进,实现自我优化和自主决策。

        2. 深度学习:一种机器学习算法的分支,利用神经网络构建模型,可以自主分析大量复杂数据,例如语音、图像和视频等。

        3. 自然语言处理(NLP):研究如何使计算机能够理解、处理、生成和转换人类语言的技术,例如语音识别、机器翻译和情感分析等。

        4. 计算机视觉:研究如何使计算机能够理解和解释图像和视频的技术,例如人脸识别、物体检测和图像分割等。

        5. 机器人学:研究如何使机器人能够自主行动和学习的技术,例如自动驾驶、人机协作和医疗机器人等。

        6. 强化学习:一种机器学习算法的分支,通过探索不同的行为来使计算机能够实现自我学习和不断优化,例如围棋和象棋等。

        这些领域和技术在不同的应用领域中都有广泛的应用,例如医学诊断、金融分析、智能交通和智能家居等。 | +| 给出3个能够提高专注力的建议。 | 1、给自己安排工作的时候要有条理,合理利用时间,这样就能够提高专注力,避免将精力分散。2、如果在工作中的注意力不集中,可以通过做体育运动、参加团体活动等来转移注意力,在活动中调整状态,提高专注力。3、培养自己热爱学习、热爱生活的态度,给自己制定长远的目标,时刻提醒自己保持专注。 | 1. 练习冥想或深呼吸:通过冥想或深呼吸练习可以让大脑得到放松,从而提高专注力。

        2. 创造一个无干扰的环境:尽量保持工作或学习环境的整洁,减少噪音和其他干扰,以帮助大脑更好地集中注意力。

        3. 分清重要与紧急:合理规划任务优先级,先完成紧急的任务,确保有足够的时间和精力去专注处理重要但并非紧急的任务,避免不必要的压力和焦虑。 | +| List 5 ways to improve your self-confidence. | Confidence is a quality that can help individuals succeed in life. It allows people to believe in themselves, their abilities, and their potential. Here are five ways to improve your self-confidence:

        1. Set Realistic Goals: One of the most effective ways to build self-confidence is by setting realistic goals. Goals should be specific, measurable, achievable, relevant, and time-bound. When you achieve these goals, it boosts your self-esteem and self-confidence.

        2. Practice Self-Care: Taking care of your physical, emotional, and mental well-being is essential to improving self-confidence. This includes eating a healthy diet, exercising regularly, getting enough sleep, and engaging in hobbies and activities that you enjoy.

        3. Focus on Your Strengths: Instead of dwelling on your weaknesses, focus on your strengths. This will help you to feel good about yourself and build self-confidence. Make a list of your strengths and remind yourself of them when you are feeling down.

        4. Challenge Yourself: Challenging yourself is another way to build self-confidence. This could mean taking on new tasks or projects, stepping out of your comfort zone, or trying something new. It is essential to push yourself beyond your limits to grow and improve.

        5. Learn from Mistakes: Making mistakes is a natural part of life. Instead of beating yourself up over your mistakes, learn from them. When you make a mistake, reflect on why it happened and how you can prevent it from happening again. This will help you to grow and become more confident in your abilities.

        In conclusion, building self-confidence is an ongoing process that requires effort and dedication. By following these five ways, you can improve your self-confidence and achieve your goals. Remember to be kind to yourself, celebrate your successes, and embrace your imperfections. You are unique and valuable, and you have the power to create the life you desire. | Choose from: (i). Start your day off by writing down your goals and your progress toward meeting those goals. Make a to-do list for the day.... (ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as " body control ".... (iii). Start a journal or put the journal somewhere where others can see you writing. This journal is for your day to day thoughts and aspirations.... (iv). Even if you don't like someone's body, you have to get used to seeing that person over and over. Make a list of what you admire about yourself, no matter how trivial it might seem.... The answer should be

        (ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as " body control."

        Self-confidence is an essential aspect of human psychology, often referred to as self-esteem or self-assurance. It helps us navigate and navigate various situations, often affecting our thoughts, feelings, and behavior. Here are some ways you can improve your self-confidence:

        1. Identify your strengths: Make a list of things you're good at or proud of. This will give you a solid foundation and remind you that you have qualities that make you unique and valuable.

        (i). Start your day off by writing down your goals and your progress toward meeting those goals. Make a to-do list for the day.
        This habit will help you stay focused and organized. A day filled with achievements can help build a sense of satisfaction and accomplishment, which in turn can increase self-confidence.

        (ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as "body control."
        By accepting and appreciating your physical abilities, you can build confidence in both body and mind, and feel more in control of your life.

        (iii). Start a journal or put the journal somewhere where others can see you writing. This journal is for your day to day thoughts and aspirations.
        Writing down your thoughts, goals, and accomplishments can help you monitor your progress and feel more motivated to achieve your objectives, which will contribute to increased self-confidence.

        (iv). Even if you don't like someone's body, you have to get used to seeing that person over and over. Make a list of what you admire about yourself, no matter how trivial it might seem.
        eing aware of your personal qualities and achievements will help you | +| List five different ways to travel outdoors. | To travel outdoors, here are five different ways:

        1. Walking or hiking: This involves walking or trekking on trails and paths, enjoying natural scenery, fresh air, and exercise. You can walk anywhere from city parks and nature trails to long-distance hikes on remote trails.

        2. Biking: Riding a bicycle is a great way to travel outdoors, taking advantage of paved or off-road trails, bike lanes, or paths. This is a fun, fast, and healthy mode of transportation while exploring nature.

        3. Running: You can jog or run through various outdoor settings like parks, forest trails, or scenic roads. Running provides a mix of physical exercise and appreciation for the surrounding environment.

        4. Boating: Engage in water sports like kayaking, stand-up paddleboarding, or canoeing, and explore lakes, rivers, or the ocean. You can also fish, swim, or just relax and enjoy the water vistas.

        5. Camping: Set up a tent or stay in a camper van and enjoy outdoor travel. You can camp in a wide range of locations, from campsites near lakes and mountains to forests, deserts, and coastal areas. Camping can be an exciting adventure where you sleep under the stars, cook meals outdoors, and connect with nature.

        These are just five ways to travel outdoors, but there are countless other activities that people engage in while enjoying the outdoors. Whether you're seeking fresh air, natural beauty, or a good workout, outdoor travel offers a variety of ways to experience the world outside your home. | Please give details.
        Solution:

        Sure, here are five different ways to travel outdoors:
        1. Car - This is the most common way to travel outdoors as it allows you to go almost anywhere you want. Just make sure to have a valid license and enough fuel for the trip.
        2. Bike - Riding a bike is a great way to exercise and get some fresh air. You can go for a leisurely ride or take on more challenging trails.
        3. Boat - Boating is a great way to explore waterways and coastal areas. You can rent a kayak, paddleboard, or sailboat to get out on the water.
        4. Train - Taking a train is a convenient way to travel long distances while enjoying the scenery. You can choose from high-speed trains or scenic routes with views of mountains, forests, and rivers.
        5. Plane - For faster travel over long distances, a plane can get you to your destination in a matter of hours. Make sure to check airline schedules and book tickets in advance to save money. | +
        + + +### Summarization +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 中新网3月4日电国台办发言人朱凤莲3月4日表示,由于民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家。苏贞昌日前又公开散布“苏式谎言”,继续罔顾事实、颠倒黑白,谎称“卡关就卡在大陆”,“真不知人间还有羞耻二字。”朱凤莲说,疫情发生以来,大陆方面一>方面全力照顾在大陆台胞的生活和疫情防控需要,另一方面充分考虑滞鄂台胞的实际需求和回家心愿,积极安排东航于2月3日运送首批247名台胞返回台湾,并于2月5日和此后多次提出尽快运送其他提出返乡要求台胞的合理安排,包括提出由两岸航空公司共同执飞临时航班的运送安排,以满足滞鄂台胞急切回家的愿望。但民进党当局却一而再、再而三变换借口,不断设置障碍,一再拖延阻挠。“2月15日,我办发言人已详细披露大陆方面持续做出运送台胞安排和为实现运送不懈努力的全过程和细节,具体情况清清楚楚,事实真相一目了然。”朱凤莲指出,民进党当局不断以各种借口阻止东航后续运送,有目共睹。苏贞昌自己就曾公开说过,不能让在湖北的台胞回去,是因为岛内防疫安置能量不足。更有甚者,民进党当局竟然将期待返乡就业、学习团聚等1148名台胞列入所谓“注记管制名单”,全面封堵了滞鄂台胞回家之路。事实反复证明,民进党当局根本就不想让在湖北的台胞回家,滞鄂台胞返乡之路受阻,“卡关”就卡在民进党当局的这些政客手中。朱凤莲强调,苏贞昌企图以自相矛盾的谎言转移视线、推卸责任,未免低估了广大台胞的智商。“我们奉劝他要有起码的道德底线,停止信口雌黄,停止造谣生事。我们质问他,敢不敢讲立即同意这1148名台胞返乡?”(原题为《国台办:奉劝苏贞昌停止造谣说谎》)\n 请你简要总结概括以上段落材料,注意字数不超过80词。 | 民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家。 | 本文是一篇新闻报道。主要阐述了由于民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家的事实,苏贞昌日前又公开散布“苏式谎言”,继续罔顾事实、颠倒黑白,谎称“卡关就卡在大陆”,“真不知人间还有羞耻二字。”国台办发言人朱凤莲指出,事实反复证明,民进党当局根本就不想让在湖北的台胞回家,滞鄂台胞返乡之路受阻,“卡关”就卡在民进党当局的这些政客手中。 | +| All of the affected beach huts are brick-built chalets in Jubilee Parade in Lowestoft.\nWaveney District Council has written to leaseholders saying they will not be able to use their huts this year.\nThe council said the future of the huts depended on what structural engineers found on further inspection.\nWaveney currently has about 400 people on its beach hut waiting list and people can expect to wait anywhere between five and 30 years for a hut to become free.\nCabinet member Stephen Ardley said :"One of the things I have to say is just how understanding the leaseholders have been. There has been sadness rather than annoyance."\nPlease briefly summarize the above material within 20 words.

        Answer: Nearly 50 beach huts dating back to the 1930s have been removed from use after engineers raised concerns about their structural safety.| Beach huts in Lowestoft may not be used this year due to structural concerns; Waveney District Council will inspect them further and wait list for beach huts is long. | Answer:

        Beach huts along the Suffolk coast will be removed after the owners decided they were too big to use. | +
        + + +### Sentiment Analysis +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 请判断这个句子的情绪是否正面:“我讨厌巧克力”。 | 这句话表达的是一种讨厌和厌恶的情绪。 | 我讨厌巧克力
        负面情绪。 | +| 请判断这个句子的情绪是否正面:我迫不及待想见到她!!!”。 | 这一句是情绪是正向的。是在期待见面的喜悦情绪。 | “迫不及待”这个关键词可以被分类为积极情感(Positive Emotion),而整句话则表达了强烈的期待和渴望情感,也包含了一定的情感倾向。 | +| Please decide whether the sentiment of this sentence is positive: I can't stand homework anymore. | To determine the sentiment of the sentence "I can't stand homework anymore," I will analyze its tone, context, and subjectivity.

        1. Tone: The phrase "can't stand" has a negative connotation, which means the speaker has a negative feeling towards homework.

        2. Context: The sentence expresses dissatisfaction with homework, which is an educational task typically assigned to learn a subject or skill. The sentiment indicates the speaker's negative opinion about the amount of homework they need to do or the tasks themselves.

        3. Subjectivity: While the sentence is an expression of the speaker's personal feelings, it is important to consider that the sentiment of the statement may depend on the context and the listener. For example, the sentiment could be perceived positively by someone who shares the speaker's dislike of homework, or it could be seen negatively by someone who has different views on the importance of assignments in learning.

        Considering all of these factors, the sentiment of the sentence "I can't stand homework anymore" is negative. | pick from the following.
        - negative.
        - positive.
        I think the answer is

        negative.

        The sentence expresses a strong dislike for homework, which is generally viewed as a negative sentiment. | +| The sentiment of the sentence "my cat is very cute" is positive. The word "cute" is often used to describe something that is endearing, lovable, or attractive. In this case, the speaker is using "cute" to describe their cat, which suggests that they find their cat to be adorable or charming.

        Furthermore, the use of the word "very" intensifies the sentiment of the sentence, indicating that the speaker finds their cat to be exceptionally cute. This further emphasizes the positive sentiment of the sentence.

        Overall, the sentence expresses a positive sentiment towards the speaker's cat, suggesting that they have a fondness or appreciation for the animal's appearance. | The sentiment of the sentence "my cat is very cute" is positive. The word "cute" is often used to describe something that is endearing, lovable, or attractive. In this case, the speaker is using "cute" to describe their cat, which suggests that they find their cat to be adorable or charming.

        Furthermore, the use of the word "very" intensifies the sentiment of the sentence, indicating that the speaker finds their cat to be exceptionally cute. This further emphasizes the positive sentiment of the sentence.

        Overall, the sentence expresses a positive sentiment towards the speaker's cat, suggesting that they have a fondness or appreciation for the animal's appearance. | OPTIONS: - negative - positive
        positive | +
        + + +### Logical Reasoning +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 在一个笼子里有鸡和兔子,它们的总数是 35,总腿数是 94。请推断出鸡和兔子的数量。 | 分析假设法,假设鸡有x只,则兔子有(35-x)只。总腿数:(4x+2(35-x))只。4x+2(35-x)=94解得x=7 | 设鸡有 x 只,兔子有 y 只,则根据题目中的条件,可以列出两个方程:
        1. x + y = 35
        2. 2x + 4y = 94
        接下来,解这两个方程可以得到:
        x = 18
        y = 17
        因此,鸡有 18 只,兔子有 17 只。 | +| 有 8 个相同的球,其中一个比其他 7 个更重。现在有一台天平,只能使用两次。如何找到最重的那个球? | 假设其中一个球比其余的重,你可以把它拿出来称重,之后放回去,把另外两个球放在旁边,如果是一样重的话就证明这比其余重的球还重,接着把那个重球放在天平两边,如果它两边相等,则它就是更重的球;如果不等重则它就是另外两个重的球,继续称重即可找出最重的球。 | 可以将 8 个球分成两份,每份 4 个。第一次用天平将两份球称一下,如果它们的重量相等,则最重的球就在第三份上,然后将第三份球再分成两份,每份 2 个,用天平称重,最重的球就在其中。如果两个 4 份中的重量不一样,则最重的球在轻的一份中,将这部分的球分成两份,每份 2 个,再称重,最重的球就在其中。如果两个 4 份中的重量一样,则最重的球就在最轻的那一份中,再将这一份分成两份,每份 1 个,用天平称重,最重的球就在其中。 | +| All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books? | To find the weight of 2 books, we can use the following steps:

        1. Determine the weight of 1 book.
        2. Multiply the weight of 1 book by the number of books.

        Step 1: Determine the weight of 1 book.
        Since 10 books weigh 5kg, we can divide 5kg by 10 to find the weight of 1 book.
        5kg ÷ 10 = 0.5kg

        Step 2: Multiply the weight of 1 book by the number of books.
        To find the weight of 2 books, we multiply the weight of 1 book by 2.
        0.5kg × 2 = 1kg

        So, the weight of 2 books is 1kg. | Answer: If 10 books weigh 5kg, then a single books weighs 5 / 10 = 0.5kg. If 2 books weigh 0.5kg each, then in total they weigh 2 * 0.5 = 1kg. The answer is 1. \[Question\]Ralph is going to practice playing tennis with a tennis ball machine that shoots out tennis balls for Ralph to hit. He loads up the machine with 175 tennis balls to start with. Out of the first 100 balls, he manages to hit 2/5 of them. Of the next 75 tennis balls, he manages to hit 1/3 of them. Out of all the tennis balls, how many did Ralph not hit? \[Answer\]Out of the first 100 balls, Ralph was able to hit 2 / 5 of them and not able to hit 3 / 5 of them, 3 / 5 x 100 = 60 tennis balls Ralph didn't hit. Out of the next 75 balls, Ralph was able to hit 1 / 3 of them and not able to hit 2 / 3 of them, 2 / 3 x 75 = 50 tennis balls that Ralph didn't hit. Combined, Ralph was not able to hit 60 + 50 = 110 tennis balls Ralph didn't hit. The answer is 110. \[Question\]Tom's rabbit can run at 25 miles per hour. His dog can run at 40 miles per hour. The cat can run at 20 miles per hour. The bird can fly at 30 miles per hour. The tank can travel 5 miles per hour. The turtle can crawl 0.02 miles per hour. What is the average speed of all the creatures mentioned? \[Answer\]
        25 + 40 + 20 + 30 + 5 + .02 = 120. Average speed = total distance / total time = 120 / 6 = 20 miles per hour. The answer is 20. | +| John is Bob's father. Mike is Bob's grandfather. What is the relationship between John and Mike? | John and Mike are related to Bob through the father-son relationship. John is Bob's father, which means he is Bob's biological parent. Mike is Bob's grandfather, which means he is Bob's father's father. Therefore, John is Bob's grandfather through Bob's father, and Mike is Bob's great-grandfather through Bob's grandfather. | - brother
        - father
        - grandfather
        - not enough information



        father | +
        + + +### Information Extraction +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 文森特·梵高是有史以来最著名的艺术家之一,也是一位后印象派画家,他对色彩的大胆运用和富有表现力的笔触激发了世界的想象力,并为他在艺术史上赢得了持久的地位。提取上面句子中提到的艺术家的名字。 | 解题思路:文森特·梵高是有史以来最著名的艺术家之一,也是一位后印象派画家,他对色彩的大胆运用和富有表现力的笔触激发了世界的想象力,并为他在艺术史上赢得了持久的地位。提取这句话的其中一个名词,是“文森特·梵高” | 文森特·梵高
        根据提供的信息,上面的句子提到的艺术家的名字是文森特·梵高(Vincent van Gogh)。 | +| The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence. | The date mentioned in the sentence is October 20th to October 23rd, 2021. | The answer is:

        October 20th to October 23rd, 2021. | +| Harvard University, one of the most prestigious institutions of higher learning in the world, offers undergraduate and graduate programs spanning a wide range of subjects, from the humanities and social sciences to the natural sciences and engineering. Extract the name of the university mentioned in the above sentence. | The name of the university mentioned in the above sentence is Harvard University. | A:
        Harvard University | +
        + + +### Role-play +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 我需要一位个人理财顾问。你将对我的资产进行审核,并为我制定出优化的投资组合。你需要考虑我的收入、支出和风险偏好。要求你尽可能多地学习我的金融情况,并针对我的具体需求为我量身定制投资方案。我的第一个问题是,我应该在哪些领域进行投资? | 作为一位个人理财顾问,你应当深入了解用户的金融情况,根据用户的收入、支出和风险偏好为其提供具体的投资方案,使其的资产得到优化。用户在回答问题时,应该围绕问题进行,并且应尽可能多地收集用户的信息,以便为你的服务提供帮助。 | 请提供你的建议。

        作为一名个人理财顾问,我需要了解你的个人情况、财务目标和风险承受能力,才能够为你制定出优化的投资组合。以下是一些可以考虑的领域和投资建议:

        1. 股票市场:股票市场是一个高风险高回报的投资领域,但是也需要投资者具备较强的风险承受能力。如果你风险承受能力较强,可以考虑投资一些有潜力的行业股票,例如科技、医疗、能源等。

        2. 债券市场:债券市场相对较稳定,投资回报相对较低,但是也具有一定的保值价值。如果你风险偏好较低,可以考虑投资一些信用评级较高的债券。

        3. 房地产市场:房地产市场是一个相对较长期的投资领域,但是也具有一定的风险。如果你有大量闲置资金可以考虑投资,但是需要评估当地的房地产市场。

        4. 投资基金:投资基金可以有效地分散风险,适合于想要在相对低风险的前提下,获得一些投资收益的投资者。

        根据你的具体情况和需求,我建议你可以考虑投资股票市场和投资基金。我将会为你制定出符合你具体需求的投资方案。 | +| I would like you to be a recipe creator. You need to create a recipe from the ingredients I give you. Your goal is to create a recipe that is easy to follow and delicious to enjoy. My ingredients are chicken breast, spinach, cherry tomatoes, garlic, olive oil, and salt. Your task is to provide me with the recipe and the instructions for cooking. Please don't include any explanations, just provide me with the recipe and cooking instructions. | Sure, I can help you with that! Here's a recipe for a delicious and easy-to-follow dish using the ingredients you've provided:

        Ingredients:
        - 2 chicken breasts
        - 1 cup of frozen spinach, defrosted
        - 1 cup of cherry tomatoes, chopped
        - 4 cloves of garlic, minced
        - 1/2 cup of olive oil
        - Salt, to taste

        Instructions:
        1. Preheat your oven to 400°F (200°C).
        2. Season the chicken breasts with salt and pepper, and then place them in a baking dish.
        3. Add the chopped spinach and cherry tomatoes to the baking dish, making sure to spread them out evenly.
        4. In a small bowl, combine the minced garlic and 1/4 cup of olive oil. Whisk together until well combined.
        5. Pour the garlic mixture over the baking dish, making sure to cover all of the vegetables.
        6. Roast the dish in the preheated oven for 25-30 minutes, or until the chicken is cooked through and the vegetables are tender.
        7. Remove the dish from the oven and let it cool for a few minutes before serving.

        That's it! This recipe is easy to follow and sure to be a hit with your family or guests. Enjoy! | Please begin your output with "recipe:". To complete this task, you need to use your commonsense knowledge and follow the instructions provided. Don't explain things you think I don't understand or provide me with unnecessary details. Just focus on providing the recipe and cooking instructions.

        Assistant: Recipe:

        Ingredients:
        * 4 chicken breasts, pounded thin
        * 2 cups fresh spinach
        * 1 cup cherry tomatoes, halved
        * 4 garlic cloves, minced
        * 1/4 cup olive oil
        * Salt

        Instructions:

        1. Preheat the oven to 375°F (190°C).
        2. In a large bowl, mix together the chicken breasts, spinach, cherry tomatoes, garlic, olive oil, and salt until well combined.
        3. Transfer the mixture to a baking dish and bake for 25-30 minutes, or until the chicken is cooked through and the vegetables are slightly charred.
        4. Remove the dish from the oven and let it cool for a few minutes before serving. Enjoy! | +
        + + +### Classification +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?请根据新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。 | 根据新闻标题《“燕赤霞”为何手里拿金刚经?》可以判断新闻所属的分类是文化,娱乐。因此,选择的答案是文化。 | 新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?
        文化。 | +| 分析以下文本中包含的情绪,并从下列选项中选择最合适的类别:“无”、“愤怒”、“厌恶”、“恐惧”、“高兴”、“悲伤”、“惊讶”。文本:“天天酸了吧唧的有劲吗?” | “我不喜欢酸的东西”这句话是负面情绪的表现,包含了“讨厌”的情绪。 | 情感分析是一种将人类语言转化成计算机可处理数据的技术,可以帮助分析文本中的情绪状态。根据上述文本的语境,“天天酸了吧唧的有劲吗?”表现出一种负面的情绪,因此最合适的分类是“厌恶”。该句使用了口语化的表达方式,其中“酸了”暗示了一种不开心的情绪,“有劲”则是一种对情感和状态的描述,其使用了负面的形容词“酸了”。总体来说,这句话表现出的是一种负面的情绪,即厌恶。 | +| Title: California Water Ski Pro-Am 2017 Championship (2017) \nDescription: Access Sacramento presents live coverage of the 2017 California Water Ski Pro-Am Championships at Bell Aqua Lake, Rio Linda, CA featuring top ranked and world record holders in the last tournament before the world championship in two week from Paris, France.\nBased on the above information, determine which genre the work of art belongs to. You can only choose one from "sport", "horror", "drama", "history", "romance", "biography", "science fiction", "comedy", "animation", "documentary", "music" and "news".
        Answer: sport | The work of art is a sports documentary. | The work of art belongs to the genre "sport". | +
        + + +### Rewriting +| Question |
        Colossal-LLaMA-2-7b-base
        |
        Qwen-7b-base
        | +| :------: | :----------------------- | :----------- | +| 将以下句子翻译成英语。 我喜欢看电影和读书。 | I like watching movies and reading books. | 我喜欢看电影和读书。
        I like to watch movies and read books | +| 将以下句子改为被动语态。 我们要在明天之前完成这个项目。 | 我们要在明天之前完成这个项目。 | 这个项目必须在明天之前被完成。 | +| Please translate the following sentence to English. "我生病了,要去医院看病". | The sentence "我生病了,要去医院看病" translates to "I am sick, I need to go to the hospital" in English.

        Here's how I arrived at the translation:

        - "我" means "I" in English.
        - "生病了" means "am sick" in English.
        - "要去" means "need to go" in English.
        - "医院" means "hospital" in English.
        - "看病" means "to the hospital" in English.

        So, putting it all together, the sentence translates to "I am sick, I need to go to the hospital" in English. | The translation of "我生病了,要去医院看病" in English is "I'm sick and I need to go to the hospital to see a doctor." | +| Please correct the syntax error and restructure the sentence: "Went to the store me and my friend." | The given sentence "Went to the store me and my friend." is incorrect due to a syntax error and it lacks a subject. The corrected version with proper grammar and sentence structure would be: "My friend and I went to the store." | It is grammatically correct to say "My friend and I went to the store." Here's the corrected sentence: "My friend and I went to the store." |
        +
        + + +## Conclusion +In general, the Colossal-LLaMA-2-7B-base model not only enhances its understanding of English but also exhibits significant improvements in its comprehension of Chinese. It boasts a broad spectrum of general knowledge, encompassing various fields such as food, sports, technology, literature, games, and more. Regarding text generation tasks, the Colossal-LLaMA-2-7B-base model excels in writing performance; however, its ability to generate specific formats like code, emails, tables, etc., needs enhancement due to the scarcity of relevant training data during our training phase. When compared to the Qwen-7b-base model, the Colossal-LLaMA-2-7B-base model outperforms it in answering most English questions and some Chinese questions, as demonstrated in the examples above. + +Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement in the future as part of our ongoing enhancements. \ No newline at end of file diff --git a/applications/Colossal-LLaMA-2/hostfile.example b/applications/Colossal-LLaMA-2/hostfile.example new file mode 100644 index 0000000000000000000000000000000000000000..82948648cbc9bed454c5672392d62a906c6d372a --- /dev/null +++ b/applications/Colossal-LLaMA-2/hostfile.example @@ -0,0 +1,2 @@ +hostname1 +hostname2 \ No newline at end of file diff --git a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a519232f6e389aa434fde46fcd630df17abb3d65 --- /dev/null +++ b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Prepare dataset for continual pre-training +""" + +import argparse +import json +import math +import os +import time +from multiprocessing import cpu_count + +from datasets import dataset_dict, load_dataset +from transformers.models.llama.tokenization_llama import LlamaTokenizer + +from colossalai.logging import get_dist_logger +from colossal_llama2.dataset.spliced_and_tokenized_dataset import ( + supervised_tokenize, + ClosedToConstantLengthSplicedDataset, +) + +logger = get_dist_logger() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_input_dirs", + type=str, + required=True, + default=None, + help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.", + ) + parser.add_argument( + "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer" + ) + parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory") + parser.add_argument( + "--data_jsonl_output_dir", + type=str, + default="jsonl_output", + help="Output directory of spliced dataset with jsonl format", + ) + parser.add_argument( + "--data_arrow_output_dir", + type=str, + default="arrow_output", + help="Output directory of spliced dataset with arrow format", + ) + parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence") + parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins") + args = parser.parse_args() + + if args.num_spliced_dataset_bins >= 100000: + raise ValueError("Too many spliced divisions, must be smaller than 100000") + + assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}" + assert not os.path.exists( + args.data_jsonl_output_dir + ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}" + assert not os.path.exists( + args.data_arrow_output_dir + ), f"Find existed arrow data output dir {args.data_arrow_output_dir}" + os.makedirs(args.data_jsonl_output_dir) + os.makedirs(args.data_arrow_output_dir) + + # Prepare to all input datasets + input_data_paths = [] + input_data_dirs = args.data_input_dirs.split(",") + for ds_dir in input_data_dirs: + ds_dir = os.path.abspath(ds_dir) + assert os.path.exists(ds_dir), f"Not find data dir {ds_dir}" + ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")] + ds_paths = [os.path.join(ds_dir, name) for name in ds_files] + input_data_paths.extend(ds_paths) + + # Prepare to data splitting. + train_splits = [] + split_interval = math.ceil(100 / args.num_spliced_dataset_bins) + for i in range(0, 100, split_interval): + start = i + end = i + split_interval + if end > 100: + end = 100 + train_splits.append(f"train[{start}%:{end}%]") + + # Prepare to the tokenizer. + tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir) + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.unk_token + + list_dataset = load_dataset( + path="json", + data_files=input_data_paths, + cache_dir=os.path.join(args.data_cache_dir, "raw"), + keep_in_memory=False, + split=train_splits, + num_proc=cpu_count(), + ) + for index, dataset in enumerate(list_dataset): + assert isinstance(dataset, dataset_dict.Dataset) + logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.") + dataset = dataset.map( + function=supervised_tokenize, + fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length}, + keep_in_memory=False, + num_proc=min(len(dataset), cpu_count()), + ) + dataset = dataset.remove_columns(column_names=["source", "target", "category"]) + dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False) + dataset = dataset.remove_columns(column_names=["seq_category", "seq_length"]) + spliced_dataset = ClosedToConstantLengthSplicedDataset( + dataset=dataset, tokenizer=tokenizer, max_length=args.max_length, error_strict=False + ) + # Save each jsonl spliced dataset. + output_index = "0" * (5 - len(str(index))) + str(index) + output_name = f"part-{output_index}" + output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl") + st = time.time() + with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer: + spliced_count = 0 + for spliced_data_point in spliced_dataset: + if spliced_count % 500 == 0: + logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}") + spliced_count += 1 + fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n") + logger.info( + f"Current file {fp_writer.name}; " + f"Data size: {len(spliced_dataset)}; " + f"Spliced data size: {spliced_dataset.current_size}; " + f"Splicing compression rate: {round(spliced_dataset.current_size / len(spliced_dataset), 6)}; " + f"Time cost: {round((time.time() - st) / 60, 6)} minutes." + ) + + # Save each arrow spliced dataset + output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name) + logger.info(f"Start to save {output_arrow_path}") + spliced_dataset = load_dataset( + path="json", + data_files=[output_jsonl_path], + cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"), + keep_in_memory=False, + num_proc=cpu_count(), + split="train", + ) + spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count())) + + +if __name__ == '__main__': + main() diff --git a/applications/Colossal-LLaMA-2/requirements.txt b/applications/Colossal-LLaMA-2/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d8afee768c02e49a1f536438cc135f3de073858d --- /dev/null +++ b/applications/Colossal-LLaMA-2/requirements.txt @@ -0,0 +1,15 @@ +torch<2.0.0, >=1.12.1 +packaging==23.1 +colossalai==0.3.2 +autoflake==2.2.1 +black==23.9.1 +transformers +tensorboard==2.14.0 +six==1.16.0 +datasets +ninja==1.11.1 +flash-attn>=2.0.0,<=2.0.5 +tqdm +sentencepiece==0.1.99 +protobuf<=3.20.0 + diff --git a/applications/Colossal-LLaMA-2/train.example.sh b/applications/Colossal-LLaMA-2/train.example.sh new file mode 100644 index 0000000000000000000000000000000000000000..276d9ce99d42015a7528f0cfe389e98a2f1b8bf0 --- /dev/null +++ b/applications/Colossal-LLaMA-2/train.example.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# NCCL IB environment variables +export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1 +export NCCL_IB_DISABLE=0 +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_GID_INDEX=3 +export NCCL_IB_TIMEOUT=23 +export NCCL_IB_RETRY_CNT=7 +export OMP_NUM_THREADS=8 + +PROJECT_NAME="" +PARENT_SAVE_DIR="" +PARENT_TENSORBOARD_DIR="" +PARENT_CONFIG_FILE="" +PRETRAINED_MODEL_PATH="" + +declare -a dataset=( + "PATH TO THE DATASET" +) + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" +TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}" +CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" + +colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \ + --pretrained $PRETRAINED_MODEL_PATH \ + --dataset ${dataset[@]} \ + --plugin "zero2" \ + --save_interval 400 \ + --save_dir $SAVE_DIR \ + --tensorboard_dir $TENSORBOARD_DIR \ + --config_file $CONFIG_FILE \ + --num_epochs 1 \ + --micro_batch_size 8 \ + --lr 1e-4 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --weight_decay 0.01 \ + --warmup_steps 100 \ + --use_grad_checkpoint \ + --use_flash_attn \ diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py new file mode 100644 index 0000000000000000000000000000000000000000..41b4ef031b468c54f16713197569f1999d7dc55e --- /dev/null +++ b/applications/Colossal-LLaMA-2/train.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Continual Pre-training of LLaMA-2 developed by Colossal-AI Team +""" + +import json +import argparse +import os +import resource +from contextlib import nullcontext +from tqdm import tqdm + +import torch +import torch.distributed as dist +from torch.utils.tensorboard import SummaryWriter +from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import ( + GeminiPlugin, + LowLevelZeroPlugin, + HybridParallelPlugin, +) +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +from colossal_llama2.dataset.loader import ( + load_tokenized_dataset, + setup_distributed_dataloader, + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, +) + +from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention +from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint +from colossal_llama2.utils.froze import freeze_non_embeds_parameters + + +def get_model_numel(model: torch.nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f"{numel / B:.2f} B" + elif numel >= M: + return f"{numel / M:.2f} M" + elif numel >= K: + return f"{numel / K:.2f} K" + else: + return f"{numel}" + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) + return tensor + + +def main() -> None: + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument( + "--pretrained", + type=str, + default=None, + help="Address of the pre-trained modeling", + ) + parser.add_argument("--dataset", nargs="+", default=[]) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], + help="Choose which plugin to use", + ) + parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint") + parser.add_argument("--save_interval", type=int, default=1000, help="Save interval") + parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory") + parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory") + parser.add_argument("--config_file", type=str, default="config_file", help="Config file") + parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process") + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") + parser.add_argument("--max_length", type=int, default=4096, help="Model max length") + parser.add_argument( + "--mixed_precision", + type=str, + default="fp16", + choices=["fp16", "bf16"], + help="Mixed precision", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument( + "--use_grad_checkpoint", + action="store_true", + default=False, + help="Use gradient checkpointing", + ) + parser.add_argument( + "--use_flash_attn", + action="store_true", + default=False, + help="Use flash-attention", + ) + parser.add_argument( + "--freeze_non_embeds_params", + action="store_true", + default=False, + help="Freeze non embeddings parameters", + ) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--zero", type=int, default=1) + args = parser.parse_args() + + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) + + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + # ============================== + # Initialize Tensorboard + # ============================== + if coordinator.is_master(): + os.makedirs(args.tensorboard_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=1, + zero_stage=args.zero, + max_norm=args.grad_clip, + precision=args.mixed_precision, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ====================================================== + # Initialize Tokenizer, Dataset, Collator and Dataloader + # ====================================================== + tokenizer = LlamaTokenizer.from_pretrained(args.pretrained) + tokenizer.pad_token = tokenizer.unk_token + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + + coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}") + coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}") + coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}") + + coordinator.print_on_master(f"Load dataset: {args.dataset}") + + dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) + dataloader = setup_distributed_dataloader( + dataset=dataset, + batch_size=args.micro_batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + ) + coordinator.print_on_master( + f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + init_ctx = ( + LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + ) + with init_ctx: + model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) + # Freeze part of parameters. + if args.freeze_non_embeds_params: + freeze_non_embeds_parameters(model=model) + + if args.use_grad_checkpoint: + model.gradient_checkpointing_enable() + coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") + if args.use_flash_attn: + replace_with_flash_attention(model=model) + coordinator.print_on_master(msg="Flash-attention enabled successfully") + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + + optimizer = HybridAdam( + model_params=filter(lambda p: p.requires_grad, model.parameters()) + if args.freeze_non_embeds_params + else model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, + total_steps=args.num_epochs * len(dataloader), + warmup_steps=args.warmup_steps + if args.warmup_steps is not None + else int(args.num_epochs * len(dataloader) * 0.025), + eta_min=0.1 * args.lr, + ) + + # Flash attention will be disabled because it does NOT support fp32. + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + dataloader=dataloader, + ) + + torch.set_default_dtype(torch.float) + + if args.load_checkpoint is None: + coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}") + booster.load_model(model, args.pretrained, strict=False) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + start_epoch = 0 + start_step = 0 + sampler_start_idx = 0 + if args.load_checkpoint is not None: + if "modeling" in args.load_checkpoint: + coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}") + booster.load_model(model, args.load_checkpoint) + else: + coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}") + start_epoch, start_step, sampler_start_idx = load_checkpoint( + load_dir=args.load_checkpoint, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + ) + coordinator.print_on_master( + f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}" + ) + coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") + + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + num_steps_per_epoch = len(dataloader) + # If resume training, set the sampler start index to the correct value + assert isinstance(dataloader.sampler, StatefulDistributedSampler) + dataloader.sampler.set_start_index(start_index=sampler_start_idx) + + for epoch in range(start_epoch, args.num_epochs): + dataloader.sampler.set_epoch(epoch=epoch) + with tqdm( + iterable=enumerate(dataloader, start=start_step), + desc=f"Epoch {epoch}", + disable=not coordinator.is_master(), + total=num_steps_per_epoch, + initial=start_step, + ) as pbar: + for step, batch in pbar: + batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} + + batch_output = model(**batch) + + loss = batch_output.loss + + booster.backward(loss=loss, optimizer=optimizer) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + all_reduce_mean(tensor=loss) + pbar.set_postfix({"Loss": f"{loss.item():.4f}"}) + if coordinator.is_master(): + global_step = epoch * num_steps_per_epoch + step + writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step) + writer.add_scalar( + tag="Learning Rate", + scalar_value=lr_scheduler.get_last_lr()[0], + global_step=global_step, + ) + # Save modeling. + + if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader): + coordinator.print_on_master("\nStart saving model checkpoint with running states") + save_checkpoint( + save_dir=args.save_dir, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + step=step + 1, + batch_size=args.micro_batch_size, + coordinator=coordinator, + ) + coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" + ) + + # Delete CUDA cache. + # del batch, batch_labels, batch_output, loss + torch.cuda.empty_cache() + + # the continue epochs are not resumed, so we need to reset the sampler start index and start step + dataloader.sampler.set_start_index(start_index=0) + start_step = 0 + + # Final save. + coordinator.print_on_master("Start saving final model checkpoint") + booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}" + ) + + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + main() diff --git a/applications/Colossal-LLaMA-2/version.txt b/applications/Colossal-LLaMA-2/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..8a9ecc2ea99d607e92feae1656ddbf6fdd82a2c1 --- /dev/null +++ b/applications/Colossal-LLaMA-2/version.txt @@ -0,0 +1 @@ +0.0.1 \ No newline at end of file diff --git a/applications/ColossalEval/README.md b/applications/ColossalEval/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3f645fe7892c14fa5a3ff21cacb722df7baa82c1 --- /dev/null +++ b/applications/ColossalEval/README.md @@ -0,0 +1,560 @@ +
        +

        + +

        +
        + +## Table of Contents + +- [Overview](#overview) +- [Leaderboard](#leaderboard) +- [Install](#install) +- [Evaluation Process](#evaluation-process) + - [Inference](#inference) + - [Dataset Preparation](#dataset-preparation) + - [Configuration](#configuration) + - [How to Use](#how-to-use) + - [Evaluation](#evaluation) + - [Dataset Evaluation](#dataset-evaluation) + - [Configuration](#dataset-evaluation) + - [How to Use](#dataset-evaluation) + - [GPT Evaluation](#gpt-evaluation) + - [Configuration](#gpt-evaluation) + - [How to Use](#gpt-evaluation) +- [More Details](#more-details) + - [Inference Details](#inference-details) + - [Evaluation Details](#evaluation-details) + - [Metrics](#metrics) + - [examples](#examples) + - [Dataset Evaluation Example](#dataset-evaluation-example) + - [GPT Evaluation Example](#gpt-evaluation-example) +- [To Do](#to-do) +- [FAQ](#faq) + - [How to Add a New Metric?](#how-to-add-a-new-metric) + - [How to Add a New Dataset?](#how-to-add-a-new-dataset) + - [How to Add a New Model?](#how-to-add-a-new-model) +- [Citations](#citations) + +## Overview +[ColossalEval](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalEval) is a project which provides a uniform pipeline to help evaluate language models on different public dataset or your own dataset using both classic metrics and the help from GPTs. More details can be found in the following sections. + +## Leaderboard + +We conducted comprehensive evaluation on 4 dataset and compare our Colossal-Llama-2-7b-base model with various models. + +- We use 5-shot for MMLU and calculate scores based on the logits of first predicted token. +- We use 5-shot for CMMLU and calculate scores based on the logits of first predicted token. +- We use 5-shot for AGIEval and only calculate scores for 4-choice questions using a combination metric of exact match and the logits of first predicted token. If any of the exact match or logits of first predicted token is correct, the model will get the score. +- We use 0-shot for GAOKAO-Bench and only calculate scores for 4-choice questions based on the logits of first predicted token. +- The generation config for all dataset is greedy search. +- We also provided CEval scores from its lastest leaderboard or the official repository of the model. + +More details about metrics can be found in [Metrics](#metrics). + +| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval | +| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :----------------------------: | +| | - | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot | +| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | +| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | +| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | +| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | +| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | +| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | +| InternLM-7B | - | - | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| InternLM-20B | - | 2.3T | | 60.96 (62.05) | 59.08 (-) | 57.96 | 61.92 | - | +| Qwen-7B (original) | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| Qwen-7B | - | 2.4T | | 58.33 (58.20) | 62.54 (62.20) | 64.34 | 74.05 | 63.50 | +| | | | | | | | | | +| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | +| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - | +| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - | +| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | +| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - | +| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - | +| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - | +| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - | +| | | | | | | | | | +| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.20 | + +> The score in parentheses corresponds to the scores in the official repository of the model. +> +> We use zero-shot for ChatGLM models. +> +> To evaluate Qwen-7B on dataset MMLU, the prompt would be "xxx Answer:"(remove the space after ":") and we calculate the logits over " A", " B", " C" and " D" for Qwen-7B. Both the original and updated versions of Qwen-7B tend to be much more deterministic than other models. For example, the logits over " A" can be `-inf` and softmax would be exact `0`. +> +> For other models and other dataset, we calculate logits over "A", "B", "C" and "D". + +Our model achieves a much better score over all other Llama-1 or Llama-2 based models and also stands out among popular open source LLMs. + +## Install +You should install `ColossalEval` in order to use it and `colossal_eval` is the package installed. +```bash +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI/applications/ColossalEval +pip install . +``` +If you want to add customized dataset or models, use `pip install -e .` in stead to ensure that any changes you make to the source code will immediately affect the package you install. + +## Evaluation Process +The evaluation process involves 2 steps which are `inference` and `evaluation`. You need to set the config for each step. + +### Inference + +The inference process consists of two parts. +1. Preprocess and convert the original dataset. +2. Config your tokenizer and model arguments to perform zero-shot or few-shot prompting. + +#### Dataset Preparation + +In this step, the original dataset(either in `csv` or `jsonl` format) will be loaded and converted into a `dict`. In the conversion process, we carefully parse each subcategory and assign specific inference arguments for this subcategory. + +Inference arguments are stored in a `dict`. The following is an example. + +```python +inference_kwargs = { + "calculate_loss": True, + "all_classes": ["A", "B", "C", "D"], + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32 +} +``` +The `inference_kwargs` currently contains 5 fields: + +- `calculate_loss` (bool, compulsory): Whether the loss on target tokens will be calculated +- `all_classes` (Optional[list], compulsory): Whether the subcategory is a single-choice question. Specify all available options in a list or otherwise None. +- `language` (str, compulsory): The language for the subcategory. +- `pretrain` (bool, compulsory): Whether the dataset is a pretrain dataset or not. It is usually used for calculate perplexity when you want to evaluate a model with extended context length. +- `max_new_tokens` (int, compulsory): The number of new tokens to generate during inference. + +For example, for dataset MMLU, each subcategory consists of single-choice questions with options A, B, C and D by default and we can assign value `["A", "B", "C", "D"]` to key`all_classes`. For dataset C-Eval, target answers aren't provided in the test split so `calculate_loss` should be set as False. However, other dataset such as GAOKAO-bench contains different formats of questions and lacks some keys or metadata which can reveal what type (single-choice or multi-choice) of questions it is. Before assigning inference arguments, we first parse the dataset to decide which type of questions the subcategory belongs to and set the inference arguments accordingly. + +Other than `inference_kwargs`, `data` is a list containing questions of a same subcategory. The following is a converted dataset. + +```json +{ + "dev": { + "category 1": {"data": [], "inference_kwargs": {}}, + "category 2": {"data": [], "inference_kwargs": {}} + }, + "test": { + "category 1": {"data": [], "inference_kwargs": {}}, + "category 2": {"data": [], "inference_kwargs": {}} + } +} +``` + +A data sample basically follow the format of Alpaca. It should contain the following keys: + +* `dataset` (str, compulsory): The name of the dataset. +* `split` (str, compulsory): The split of the instruction. +* `catrgory` (str, compulsory): The category of the instruction. +* `instruction` (str, compulsory): The instruction for the LLM. +* `input` (str, optional): The additional context of the instruction. +* `output` (str, optional): The model output of the instruction. +* `target` (str, optional): The target answer for the instruction. + +Example: + +```json +{ + "dev": { + "Abstract Algebra": [ + { + "dataset": "mmlu", + "split": "dev", + "category": "Abstract Algebra", + "instruction": "The following is a single-choice question on Abstract Algebra. Answer the question by replying A, B, C or D.", + "input": "Question: Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 3\nAnswer: ", + "output": "", + "target": "B" + }, + ] + }, + "test": { + "Abstract Algebra": [ + { + "dataset": "mmlu", + "split": "test", + "category": "Abstract Algebra", + "instruction": "The following is a single-choice question on Abstract Algebra. Answer the question by replying A, B, C or D.", + "input": "Question: Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.\nA. 0\nB. 4\nC. 2\nD. 6\nAnswer: ", + "output": "", + "target": "B" + }, + ] + } +} +``` + +#### Configuration +In this step, you will configure your tokenizer and model arguments to infer on the given datasets. + +A config file consists of two parts. +1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel` and `ChatGLMModel2`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. You can check all model classes in `colossal_eval/models/__init__.py`. If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields. +2. Dataset config. In dataset config, you need to specify dataset name, path and dataset class. Currently, we support zero-shot on dataset MMLU, CMMLU, AGIEval, GAOKAO-Bench and LongBench and few-shot on dataset MMLU, CMMLU and AGIEval. If you want to enable few shot, set `few_shot` as true. You can check all model classes in `colossal_eval/dataset/__init__.py`. + +Once you have all config ready, the program will run inference on all the given datasets on all the given models. + +An example config using model class `HuggingFaceCausalLM` and dataset class `CMMLUDataset` can be: +```json +{ + "model": [ + { + "name": "model name", + "model_class": "HuggingFaceCausalLM", + "parameters": { + "path": "path to model", + "model_max_length": 2048, + "tokenizer_path": "path to tokenizer", + "tokenizer_kwargs": { + "use_fast": false, + "trust_remote_code": true + }, + "peft_path": null, + "model_kwargs": { + "trust_remote_code": true + }, + "prompt_template": "plain", + "batch_size": 4 + } + } + ], + "dataset": [ + { + "name": "dataset name", + "dataset_class": "CMMLUDataset", + "debug": false, + "few_shot": true, + "path": "path to original dataset", + "save_path": "path to save converted dataset" + } + ] +} +``` + +Currently, we support Hugging Face models. The `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. `few_shot` will be set true if you want to enable few-shot prompting for the dataset. `debug` will be set true if you want to verify whether your prompt is right or wrong. + +#### How to Use +An example script can be the following. The `configs/dataset_evaluation/inference.py` is the same in all examples provided. + +```shell +torchrun --nproc_per_node=1 inference.py \ + --config "path to config file" \ + --load_dataset \ + --inference_save_path "path to save inference results" +``` + +You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. + +### Evaluation + +In the evaluation process, you only need to configure your evaluation parameters. You can use either public dataset or help from GPTs to do evaluation. We will introduce configuration for dataset evaluation and GPT evaluation. + +#### Dataset Evaluation + +In dataset evaluation, we calculate different metrics on the given inference results and public dataset. + +##### Configuration + +A config file for dataset evaluation consists of two parts. +1. Model config. In model config, you need to specify model name. If you want to evaluate perplexity over a pretrain dataset and calculate per-byte-perplexity, you have to add your tokenizer config and model max length. +2. Dataset config. In dataset config, you need to specify the evaluation metrics for the dataset. + +Once you have all config ready, the program will run evaluation on inference results for all given models and dataset. + +An example config can be: +```json +{ + "model": [ + { + "name": "model name" + } + ], + "dataset": [ + { + "name": "dataset name", + "metrics": ["first_token_accuracy"] + } + ] +} +``` + +The above config specifies that the program will evaluate the inference results using `first_token_accuracy` metric. + +##### How to Use + +An example script can be the following. + +```shell +python eval_dataset.py \ + --config "path to config file" \ + --inference_results_path "path to inference results" \ + --evaluation_results_save_path "path to save evaluation results" +``` + +You should specify the path to config file in `config`, the path to inference results in `inference_results_path` and the path to save evaluation results in `evaluation_save_path`. + +#### GPT Evaluation + +In GPT evaluation, we provide a prompt template which can fit in different pre-defined metrics with Chain-of-Thoughts. In the following sections, we will only introduce how you can evaluate model answers using GPTs. More details can be found in `colossal_eval/evaluate/GPT Evaluation.md`. + +##### Configuration + +The following is an example of a English config file. The configuration file can control how the pipeline evaluates the model. You need to specify GPT evaluation metrics. You can find an example English config file in `configs/gpt_evaluation`. + +```json +{ + "language": "en", + "category": { + "brainstorming": { + "GPT": [ + "language organization", + "relevance", + "creativity", + "practicality", + "reasonableness" + ] + }, + } +} +``` + +##### How to Use +After setting the config file, you can evaluate the model using `examples/gpt_evaluation/eval.py`. If you want to make comparisons between answers of two different models, you should specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list`(details can be found in `colossal_eval/evaluate/GPT Evaluation.md`). If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1 and the program will perform evaluation using GPTs. The prompt files for battle and gpt evaluation can be found in `configs/gpt_evaluation/prompt`. `target file` is the path to the converted dataset you save during inference time. + +An example script is provided as follows: + +```shell +python eval.py \ + --config_file "path to the config file" \ + --battle_prompt_file "path to the prompt file for battle" \ + --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \ + --target_file "path to the target answer file" \ + --answer_file_list "path to the answer file" \ + --model_name_list "the names of the model" \ + --gpt_model "which GPT model to use for evaluation" \ + --save_path "path to save results" \ + --openai_key "your openai key" \ +``` + +## More Details + +### Inference + +In the inference process, we will do generation, calculate loss over target tokens, calculate number of target tokens, softmax over given options (for example, "A", "B", "C", and "D") according to the inference arguments. + +For tokenization, we adopt tokenization strategy in [LongBench](https://github.com/THUDM/LongBench/blob/main/pred.py#L55) to preserve crucial instructions on the left and right side and keep all target tokens. + +For labeling target tokens, we adopt method from [FastChat](https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L137), but it doesn't always hold true due to tokenizers' different behavior. We plan to insert special tokens to correctly label the target tokens. + +For calculating loss, we return per-sample-loss instead of per-batch-loss if we directly use `model(batch).loss` provided in HuggingFace. + +### Evaluation + +To make it more easier to set the config, you only need to specify all metrics you want to use in key `metrics`. However, the program will only use a subset of metrics you give for different subcategories. Applying all metrics to all subcategories is obviously unsuitable. The suggested metrics for specific categories should be defined in `colossal_eval/evaluate/dataset_evaluator/metrics.py`. + +#### Metrics + +- `combined_single_choice_accuracy`: A combination of `first_token_logit` and `single_choice_accuracy`. If one of these is correct, the model will get the score. It can be used in all dataset that contains single-choice questions. +- `first_token_logit`: Calculate score based on softmax score over the given choices. If the argmax of the softmax is equal to the reference, the model will get the score. If there is `NaN` in softmax score, it will calculate the score using exact match. It can be used in all dataset that contains single-choice questions. +- `single_choice_accuracy`: Calculate score using exact match. It will only get the first uppercase letter such as A, B, C or D that is not surrouded by lowercase letters. If the uppercase letter is equal to the reference, the model will get the score. It can be used in all dataset that contains single-choice questions. +- `multi_choice_accuracy`: Calculate score on multi-choice questions. It will get a set of all uppercase letters such as A, B, C or D that is not surrouded by lowercase letters. If the prediction conatains uppercase letters that are not in reference. The model will get 0 score. If the prediction contains a uppercase letter that is in reference, the model will get a score of `1/len(reference)`. It is used in AGIEval and GAOKAO-Bench. +- `math_equivalence`: Code from [hendrycks](https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py). Compute scores over the prediction math formula and reference math formula. It is used in AGIEval and GAOKAO-Bench. +- `f1_score`: Calculate English f1 score between prediction and reference. It is used in Longbench. +- `f1_zh_score`: Calculate Chinese f1 score between prediction and reference. It is used in Longbench. +- `rouge_score`: Calculate English f1 score between prediction and reference. It is used in GAOKAO-Bench and LongBench. +- `rouge_zh_score`: Calculate Chinese rouge score between prediction and reference. It is used in GAOKAO-Bench and LongBench. +- `retrieval_score`: Calculate English retrieval score between prediction and reference. It determines whether the ouput(which paragraph) corresponds to the given abstract. It is used in Longbench. +- `retrieval_zh_score`: Calculate Chinese retrieval score between prediction and reference. It determines whether the ouput(which paragraph) corresponds to the given abstract. It is used in Longbench. +- `classification_score`: Calculate classification score between prediction and reference. It determines whether the ouput(a class) is equal to the reference. It is used in Longbench. +- `code_sim_score`: Calculate similarity score between prediction and reference. It is used in Longbench. +- `count_score`: Calculate count score between prediction and reference. It determines whether the ouput(number of given passages) is equal to the reference. It is used in Longbench. +- `perplexity`: Calculate perplexity. The formula is $ perplexity = \frac{1}{n} \sum_i e^{loss_i} $ where $n$ is the number of samples and $ loss_i $ is the average loss for sample $ i $. It can be used in all dataset. +- `ppl_score`: Calculate perplexity score. The formula is $ ppl\_score = \frac{1}{n} \sum_i e^{-loss_i} $ where $n$ is the number of samples and $ loss_i $ is the average loss for sample $ i $. It can be used in all dataset. +- `ppl_score_over_choices`: Calculate perplexity score over choices. The formula is $ ppl\_score\_over\_choices= \frac{1}{n} \sum_i e^{-loss\_over\_choices_i} $ where $n$ is the number of samples and $ loss\_over\_choices_i $ is the loss on the first predicted token for sample $ i $. It can be used in all dataset that contains single-choice questions. +- `per_byte_perplexity`: Calculate per byte perplexity. The formula is $ \frac{1}{n} \sum_i e^{\frac{loss_i}{byte_i}} $ where $n$ is the number of samples, $ loss_i $ is the total loss for sample $ i $ and $ byte_i $ is the number of bytes sample $ i $ occupies. It can be used in all dataset. +- `per_byte_ppl_score`: Calculate per byte perplexity score. The formula is $ \frac{1}{n} \sum_i e^{-\frac{loss_i}{byte_i}} $ where $n$ is the number of samples, $ loss_i $ is the total loss for sample $ i $ and $ byte_i $ is the number of bytes sample $ i $ occupies. It can be used in all dataset. + +We use `combined_single_choice_accuracy` and `first_token_logit` in the leaderboard. + +### Examples + +We provide 2 examples for you to explore our `colossal_eval` package. + +#### Dataset Evaluation Example + +This example is in folder `examples/dataset_evaluation`. + +1. `cd examples/dataset_evaluation` +2. Fill in your inference config file in `config/inference/config.json`. Set the model and dataset parameters. +3. Run `inference.sh` to get inference results. +4. Fill in your evaluation config file in `config/evaluation/config.json`. Set the model and dataset parameters. +5. Run `eval_dataset.sh` to get evaluation results. + +#### GPT Evaluation Example + +The examples is in folder `examples/gpt_evaluation`. + +1. `cd examples/gpt_evaluation` +2. Fill in your inference config file in `config/inference/config.json`. Set the model and dataset parameters. If you want to use the example dataset we provide, the dataset is `ColossalDataset`. +3. Run `inference.sh` to get inference results. +4. Fill in your evaluation config file in `config/evaluation/config.json`. +5. Run `eval.sh` to get evaluation results. + +## FAQ + +### How to Add a New Metric? + +If you want to add a customized metric, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the package you install. + +To add a new metric, you can follow the example of multi_choice_accuracy in line 339 in `colossal_eval/evaluate/dataset_evaluator/metric.py`. The method take one data sample's prediction and reference as input and return a score ranging from 0 to 1. + +A skeleton of code is the following. + +```python + +def CustomizedMetric(prediction: str, reference: str): + score = xxx + return score +``` + +Once you have successfully added your own metric, you should specify your metric both in `colossal_eval/evaluate/dataset_evaluator/metric.py` (suggest which subcategories shoule the metric be applied to) and your evaluation config. + +### How to Add a New Dataset? + +If you want to add customized dataset, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the package you install. + +To add a new dataset, you can follow the example of `colossal_eval/dataset/mmlu.py`. You need to make sure that the format of questions in one subcategory should be the same. For example, all questions should have target answers or all questions should be single-choice questions. + +A skeleton of code is the following. + +```python + +class CustomizedDataset(BaseDataset): + @staticmethod + def load(): + # 1. Load and convert the original dataset format. + # 2. Assign inference arguments for each subcategory. + # 3. Return the converted dataset. + pass +``` + +Once you have successfully added your own dataset, you can specify your dataset class in your inference config. + +### How to Add a New Model? + +If you want to add customized models, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the package you install. + +To add a new model, you can follow the example of `colossal_eval/models/huggingface.py`. You need to provide a way to load the model and tokenizer, calculate loss and generate. + +A skeleton of code is the following. + +```python + +class CustomizedModel(BaseModel): + def __init__(self): + super().__init__() + self._load_tokenizer() + self._load_model() + + def _load_tokenizer(): + pass + + def _load_model(): + pass + + def _calculate_loss(): + pass + + def get_loss(): + self._calculate_loss() + + def inference(samples): + # 1. Load samples from the same subcategory. + # 2. Infer in a batch way according to inference arguments. + # 3. Return results. + batch_samples = xxx + self.get_loss(batch_samples) + self.generate(batch_samples) + + return inference_results + + def generate(): + pass +``` + +Once you have successfully added your own model, you can specify your model class in your inference config. + +## To do + +- [ ] Add visualization code for evaluation results on public dataset +- [ ] Improve the way to label target tokens + +## Citations + +```bibtex +@misc{zhong2023agieval, + title={AGIEval: A Human-Centric Benchmark for Evaluating Foundation Models}, + author={Wanjun Zhong and Ruixiang Cui and Yiduo Guo and Yaobo Liang and Shuai Lu and Yanlin Wang and Amin Saied and Weizhu Chen and Nan Duan}, + year={2023}, + eprint={2304.06364}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} + +@article{huang2023ceval, +title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models}, +author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian}, +journal={arXiv preprint arXiv:2305.08322}, +year={2023} +} + +@misc{li2023cmmlu, + title={CMMLU: Measuring massive multitask language understanding in Chinese}, + author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin}, + year={2023}, + eprint={2306.09212}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} + +@inproceedings{Zhang2023EvaluatingTP, + title={Evaluating the Performance of Large Language Models on GAOKAO Benchmark}, + author={Xiaotian Zhang and Chunyang Li and Yi Zong and Zhengyu Ying and Liang He and Xipeng Qiu}, + year={2023} +} + +@misc{bai2023longbench, + title={LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding}, + author={Yushi Bai and Xin Lv and Jiajie Zhang and Hongchang Lyu and Jiankai Tang and Zhidian Huang and Zhengxiao Du and Xiao Liu and Aohan Zeng and Lei Hou and Yuxiao Dong and Jie Tang and Juanzi Li}, + year={2023}, + eprint={2308.14508}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} + +@article{hendryckstest2021, + title={Measuring Massive Multitask Language Understanding}, + author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt}, + journal={Proceedings of the International Conference on Learning Representations (ICLR)}, + year={2021} +} + +@article{hendrycks2021ethics, + title={Aligning AI With Shared Human Values}, + author={Dan Hendrycks and Collin Burns and Steven Basart and Andrew Critch and Jerry Li and Dawn Song and Jacob Steinhardt}, + journal={Proceedings of the International Conference on Learning Representations (ICLR)}, + year={2021} +} + +@misc{zheng2023judging, + title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena}, + author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica}, + year={2023}, + eprint={2306.05685}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} + +``` diff --git a/applications/Chat/coati/ray/src/__init__.py b/applications/ColossalEval/colossal_eval/__init__.py similarity index 100% rename from applications/Chat/coati/ray/src/__init__.py rename to applications/ColossalEval/colossal_eval/__init__.py diff --git a/applications/ColossalEval/colossal_eval/dataset/__init__.py b/applications/ColossalEval/colossal_eval/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea173198f5a84d053a835c885403e19d42e70a5 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/__init__.py @@ -0,0 +1,19 @@ +from .agieval import AGIEvalDataset +from .base import BaseDataset +from .ceval import CEvalDataset +from .cmmlu import CMMLUDataset +from .colossalai import ColossalDataset +from .gaokaobench import GaoKaoBenchDataset +from .longbench import LongBenchDataset +from .mmlu import MMLUDataset + +__all__ = [ + "AGIEvalDataset", + "BaseDataset", + "CEvalDataset", + "CMMLUDataset", + "GaoKaoBenchDataset", + "LongBenchDataset", + "MMLUDataset", + "ColossalDataset", +] diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py new file mode 100644 index 0000000000000000000000000000000000000000..92ebd65931edf829a35d0535abb52f2088f34108 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -0,0 +1,247 @@ +# Adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/dataset_loader.py. + +import ast +import glob +import os +from copy import deepcopy +from typing import Dict, List + +import pandas as pd +from colossal_eval.utils import get_json_list + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +# define the datasets +english_qa_datasets = [ + "lsat-ar", + "lsat-lr", + "lsat-rc", + "logiqa-en", + "sat-math", + "sat-en", + "aqua-rat", + "sat-en-without-passage", + "gaokao-english", +] +chinese_qa_datasets = [ + "logiqa-zh", + "jec-qa-kd", + "jec-qa-ca", + "gaokao-chinese", + "gaokao-geography", + "gaokao-history", + "gaokao-biology", + "gaokao-chemistry", + "gaokao-physics", + "gaokao-mathqa", +] +english_cloze_datasets = ["math"] +chinese_cloze_datasets = ["gaokao-mathcloze"] + +multi_choice_datasets = ["jec-qa-kd", "jec-qa-ca", "gaokao-physics", "gaokao-mathqa"] +math_output_datasets = {"gaokao-mathcloze", "math"} + +default_inference_kwargs = { + "calculate_loss": True, + "all_classes": None, + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32, +} + + +def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict: + """Modified from https://github.com/microsoft/AGIEval/blob/main/src/dataset_loader.py#L190""" + try: + all_classes = None + passage = line["passage"] if line["passage"] is not None else "" + + if dataset_name in english_qa_datasets: + option_string = "ABCDEFG" + count = len(line["options"]) + + input = ( + "Question: " + + line["question"] + + " " + + "Choose from the following options: " + + " ".join(line["options"]) + + "\n" + + "Answer: " + ) + + all_classes = list(option_string[0:count]) + + elif dataset_name in chinese_qa_datasets: + option_string = "ABCDEFG" + count = len(line["options"]) + + input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:" + + all_classes = list(option_string[0:count]) + + elif dataset_name in english_cloze_datasets: + input = "Question: " + line["question"] + "\n" + "Answer: " + + elif dataset_name in chinese_cloze_datasets: + input = "问题:" + line["question"] + "\n" + "答案:" + + return { + "instruction": input if not passage else passage + "\n\n" + input, + "target": line["label"] if line["label"] else line["answer"], + }, all_classes + + except NameError: + logger.info("Dataset not defined.") + + +# process few-shot raw_prompts +def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=False): + skip_passage = False + if dataset_name == "sat-en-without-passage": + skip_passage = True + dataset_name = "sat-en" + demostrations = [] + # read the prompts by context and explanation + context_row = [0, 1, 3, 5, 7, 9] + explanation_row = [0, 2, 4, 6, 8, 10] + raw_prompts_context = pd.read_csv( + prompt_path, header=0, skiprows=lambda x: x not in context_row, keep_default_na=False + ) + raw_prompts_explanation = pd.read_csv( + prompt_path, header=0, skiprows=lambda x: x not in explanation_row, keep_default_na=False + ).replace(r"\n\n", "\n", regex=True) + contexts = [] + for line in list(raw_prompts_context[dataset_name]): + if line: + # print(line) + contexts.append(ast.literal_eval(line)) + explanations = [exp for exp in raw_prompts_explanation[dataset_name] if exp] + + for idx, (con, exp) in enumerate(zip(contexts, explanations)): + passage = con["passage"] if con["passage"] is not None and not skip_passage else "" + question = con["question"] + options = con["options"] if con["options"] is not None else "" + label = con["label"] if con["label"] is not None else "" + answer = con["answer"] if "answer" in con and con["answer"] is not None else "" + + if dataset_name in english_qa_datasets: + question_input = ( + "Question: " + + passage + + " " + + question + + "\n" + + "Choose from the following options: " + + " ".join(options) + + "\n" + + "Answer: {}".format(label) + ) + elif dataset_name in chinese_qa_datasets: + question_input = ( + "问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label) + ) + elif dataset_name in english_cloze_datasets: + question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer) + elif dataset_name in chinese_cloze_datasets: + question_input = "问题:" + question + "\n" + "答案:{}".format(answer) + else: + raise ValueError(f"During loading few-sot examples, found unknown dataset: {dataset_name}") + + if chat_mode: + demostrations.append((question_input,)) + else: + demostrations.append(question_input + "\n") + + return demostrations + + +class AGIEvalDataset(BaseDataset): + """ + Dataset wrapper for AGIEval dataset. + Data source: https://github.com/microsoft/AGIEval + This dataset class will convert the original dataset into the inference dataset. + + A few dirty data needed to be manually corrected in the origin dataset: + Issue link: https://github.com/microsoft/AGIEval/issues/16 + 1. Invalid options in line 190 in gaokao-chemistry.jsonl. + 2. Option D (They may increase in value as those same resources become rare on Earth.) missing in line 17 in sat-en-without-passage.jsonl. + 3. Option D (They may increase in value as those same resources become rare on Earth.) missing in line 17 in sat-en.jsonl. + 4. Option D (No, because the data do not indicate whether the honeybees had been infected with mites.) missing in line 57 in sat-en-without-passage.jsonl. + 5. Option D (No, because the data do not indicate whether the honeybees had been infected with mites.) missing in line 57 in sat-en.jsonl. + 6. Option D (Published theories of scientists who developed earlier models of the Venus flytrap) missing in line 98 in sat-en-without-passage.jsonl. + 7. Option D (Published theories of scientists who developed earlier models of the Venus flytrap) missing in line 98 in sat-en.jsonl. + 8. Label is empty in line 212 in jec-qa-kd.jsonl. Content is also dirty. + 9. Actually, gaokao-mathqa.jsonl is also a multi-choice dataset. See line 149 286 287. + """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"test": {}} + + files = glob.glob(os.path.join(path, "*.jsonl")) + files.sort() + + if few_shot: + prompt_path = os.path.join(path, "few_shot_prompts.csv") + + for file in files: + dataset_name = os.path.basename(file)[0 : -len(".jsonl")] + + few_shot_data = [] + if few_shot: + # process demo once if it is few-shot-CoT + few_shot_data = combine_prompt(prompt_path, dataset_name, load_explanation=False, chat_mode=False) + + dataset["test"][dataset_name] = {"data": []} + + file_dir = os.path.join(path, file) + + loaded_jsonl = get_json_list(file_dir) + + # It's been tested that each data sample in one subcategory have same inference arguments. + _, all_classes = get_prompt(loaded_jsonl[0], dataset_name, logger) + inference_kwargs = deepcopy(default_inference_kwargs) + if all_classes is not None and dataset_name not in multi_choice_datasets: + inference_kwargs["all_classes"] = all_classes + + if dataset_name in english_qa_datasets: + inference_kwargs["language"] = "English" + if dataset_name in chinese_qa_datasets: + inference_kwargs["language"] = "Chinese" + inference_kwargs["few_shot_data"] = few_shot_data + + dataset["test"][dataset_name]["inference_kwargs"] = inference_kwargs + + for line in loaded_jsonl: + info, all_classes = get_prompt(line, dataset_name, logger) + + # Convert multi-choice answers to a single string. + # We will convert it back when evaluating. + # We do this because if target is a list, it should be only used for multiple target answers. + if dataset_name in multi_choice_datasets: + if isinstance(info["target"], str) and len(info["target"]) > 1: + # "gaokao-mathqa" actually contain multi-choice questions. + # This if clause is specially used for it. + info["target"] = "".join(info["target"].split()) + else: + info["target"] = "".join(info["target"]) + + if isinstance(info["target"], list) and len(info["target"]) == 1: + info["target"] = info["target"][0] + + data_sample = { + "dataset": "agieval", + "split": "test", + "category": dataset_name, + "instruction": info["instruction"], + "input": "", + "output": "", + "target": info["target"], + } + + dataset["test"][dataset_name]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/base.py b/applications/ColossalEval/colossal_eval/dataset/base.py new file mode 100644 index 0000000000000000000000000000000000000000..45b0151b849f038590521130493207eaaff47543 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/base.py @@ -0,0 +1,24 @@ +from abc import abstractstaticmethod + +from colossal_eval.utils import jdump + + +class BaseDataset: + """ + Base class for dataset wrapper. + + Args: + path: The path to the original dataset. + logger: Logger for the dataset. + """ + + def __init__(self, path, logger, few_shot): + self.dataset = self.load(path, logger, few_shot) + + def save(self, save_path): + """Save the converted dataset""" + jdump(self.dataset, save_path) + + @abstractstaticmethod + def load(path, logger): + """Load the original dataset and convert it into the inference dataset""" diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py new file mode 100644 index 0000000000000000000000000000000000000000..32ec52087bd3f2986f2bcb80e2a590e459fb84f8 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py @@ -0,0 +1,132 @@ +import copy +import csv +import os +from typing import Dict, List + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +ceval_subject_mapping = { + "computer_network": ["Computer Network", "计算机网络", "STEM"], + "operating_system": ["Operating System", "操作系统", "STEM"], + "computer_architecture": ["Computer Architecture", "计算机组成", "STEM"], + "college_programming": ["College Programming", "大学编程", "STEM"], + "college_physics": ["College Physics", "大学物理", "STEM"], + "college_chemistry": ["College Chemistry", "大学化学", "STEM"], + "advanced_mathematics": ["Advanced Mathematics", "高等数学", "STEM"], + "probability_and_statistics": ["Probability and Statistics", "概率统计", "STEM"], + "discrete_mathematics": ["Discrete Mathematics", "离散数学", "STEM"], + "electrical_engineer": ["Electrical Engineer", "注册电气工程师", "STEM"], + "metrology_engineer": ["Metrology Engineer", "注册计量师", "STEM"], + "high_school_mathematics": ["High School Mathematics", "高中数学", "STEM"], + "high_school_physics": ["High School Physics", "高中物理", "STEM"], + "high_school_chemistry": ["High School Chemistry", "高中化学", "STEM"], + "high_school_biology": ["High School Biology", "高中生物", "STEM"], + "middle_school_mathematics": ["Middle School Mathematics", "初中数学", "STEM"], + "middle_school_biology": ["Middle School Biology", "初中生物", "STEM"], + "middle_school_physics": ["Middle School Physics", "初中物理", "STEM"], + "middle_school_chemistry": ["Middle School Chemistry", "初中化学", "STEM"], + "veterinary_medicine": ["Veterinary Medicine", "兽医学", "STEM"], + "college_economics": ["College Economics", "大学经济学", "Social Science"], + "business_administration": ["Business Administration", "工商管理", "Social Science"], + "marxism": ["Marxism", "马克思主义基本原理", "Social Science"], + "mao_zedong_thought": ["Mao Zedong Thought", "毛泽东思想和中国特色社会主义理论体系概论", "Social Science"], + "education_science": ["Education Science", "教育学", "Social Science"], + "teacher_qualification": ["Teacher Qualification", "教师资格", "Social Science"], + "high_school_politics": ["High School Politics", "高中政治", "Social Science"], + "high_school_geography": ["High School Geography", "高中地理", "Social Science"], + "middle_school_politics": ["Middle School Politics", "初中政治", "Social Science"], + "middle_school_geography": ["Middle School Geography", "初中地理", "Social Science"], + "modern_chinese_history": ["Modern Chinese History", "近代史纲要", "Humanities"], + "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "思想道德修养与法律基础", "Humanities"], + "logic": ["Logic", "逻辑学", "Humanities"], + "law": ["Law", "法学", "Humanities"], + "chinese_language_and_literature": ["Chinese Language and Literature", "中国语言文学", "Humanities"], + "art_studies": ["Art Studies", "艺术学", "Humanities"], + "professional_tour_guide": ["Professional Tour Guide", "导游资格", "Humanities"], + "legal_professional": ["Legal Professional", "法律职业资格", "Humanities"], + "high_school_chinese": ["High School Chinese", "高中语文", "Humanities"], + "high_school_history": ["High School History", "高中历史", "Humanities"], + "middle_school_history": ["Middle School History", "初中历史", "Humanities"], + "civil_servant": ["Civil Servant", "公务员", "Other"], + "sports_science": ["Sports Science", "体育学", "Other"], + "plant_protection": ["Plant Protection", "植物保护", "Other"], + "basic_medicine": ["Basic Medicine", "基础医学", "Other"], + "clinical_medicine": ["Clinical Medicine", "临床医学", "Other"], + "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"], + "accountant": ["Accountant", "注册会计师", "Other"], + "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"], + "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"], + "tax_accountant": ["Tax Accountant", "税务师", "Other"], + "physician": ["Physician", "医师资格", "Other"], +} + +default_inference_kwargs = { + "calculate_loss": False, + "all_classes": ["A", "B", "C", "D"], + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32, +} + + +def get_few_shot_data(data: List[Dict]): + few_shot_data = [] + for i in data: + few_shot_data.append(i["input"] + i["target"]) + return few_shot_data + + +class CEvalDataset(BaseDataset): + """ + Dataset class for CEval dataset. + Data source: https://huggingface.co/datasets/ceval/ceval-exam + This dataset class will convert the original dataset into the inference dataset. + """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"dev": {}, "test": {}} + for split in ["dev", "test"]: + files = os.listdir(os.path.join(path, split)) + files.sort() + + for file in files: + subject = file[0 : -len(f"_{split}.csv")] + subject = ceval_subject_mapping[subject][1] + + file_dir = os.path.join(path, split, file) + + dataset[split][subject] = {"data": []} + + # It's been tested that each data sample in one subcategory have same inference arguments. + dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs) + + if split == "test" and few_shot: + dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data( + dataset["dev"][subject]["data"] + ) + + with open(file_dir, encoding="utf-8") as f: + reader = csv.reader(f) + _ = next(reader) + for row in reader: + # Dev split have answer and explanation so len(row) is 8 + # But test split doesn't contain answer and explanation, so len(row) is 6 + assert len(row) >= 6 + choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}" + data_sample = { + "dataset": "ceval", + "split": split, + "category": subject, + "instruction": f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。", + "input": f"题目:{row[1]}\n{choices}\n答案:", + "output": "", + "target": row[6] if split == "dev" else "", + "id": int(row[0]), + } + + dataset[split][subject]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py new file mode 100644 index 0000000000000000000000000000000000000000..51f8ca14e0c87b52bdc58a413b64add5749ec460 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py @@ -0,0 +1,144 @@ +import copy +import csv +import os +from typing import Dict, List + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +cmmlu_subject_mapping = { + "agronomy": "农学", + "anatomy": "解剖学", + "ancient_chinese": "古汉语", + "arts": "艺术学", + "astronomy": "天文学", + "business_ethics": "商业伦理", + "chinese_civil_service_exam": "中国公务员考试", + "chinese_driving_rule": "中国驾驶规则", + "chinese_food_culture": "中国饮食文化", + "chinese_foreign_policy": "中国外交政策", + "chinese_history": "中国历史", + "chinese_literature": "中国文学", + "chinese_teacher_qualification": "中国教师资格", + "clinical_knowledge": "临床知识", + "college_actuarial_science": "大学精算学", + "college_education": "大学教育学", + "college_engineering_hydrology": "大学工程水文学", + "college_law": "大学法律", + "college_mathematics": "大学数学", + "college_medical_statistics": "大学医学统计", + "college_medicine": "大学医学", + "computer_science": "计算机科学", + "computer_security": "计算机安全", + "conceptual_physics": "概念物理学", + "construction_project_management": "建设工程管理", + "economics": "经济学", + "education": "教育学", + "electrical_engineering": "电气工程", + "elementary_chinese": "小学语文", + "elementary_commonsense": "小学常识", + "elementary_information_and_technology": "小学信息技术", + "elementary_mathematics": "初等数学", + "ethnology": "民族学", + "food_science": "食品科学", + "genetics": "遗传学", + "global_facts": "全球事实", + "high_school_biology": "高中生物", + "high_school_chemistry": "高中化学", + "high_school_geography": "高中地理", + "high_school_mathematics": "高中数学", + "high_school_physics": "高中物理学", + "high_school_politics": "高中政治", + "human_sexuality": "人类性行为", + "international_law": "国际法学", + "journalism": "新闻学", + "jurisprudence": "法理学", + "legal_and_moral_basis": "法律与道德基础", + "logical": "逻辑学", + "machine_learning": "机器学习", + "management": "管理学", + "marketing": "市场营销", + "marxist_theory": "马克思主义理论", + "modern_chinese": "现代汉语", + "nutrition": "营养学", + "philosophy": "哲学", + "professional_accounting": "专业会计", + "professional_law": "专业法学", + "professional_medicine": "专业医学", + "professional_psychology": "专业心理学", + "public_relations": "公共关系", + "security_study": "安全研究", + "sociology": "社会学", + "sports_science": "体育学", + "traditional_chinese_medicine": "中医中药", + "virology": "病毒学", + "world_history": "世界历史", + "world_religions": "世界宗教", +} + +default_inference_kwargs = { + "calculate_loss": True, + "all_classes": ["A", "B", "C", "D"], + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32, +} + + +def get_few_shot_data(data: List[Dict]): + few_shot_data = [] + for i in data: + few_shot_data.append(i["input"] + i["target"]) + return few_shot_data + + +class CMMLUDataset(BaseDataset): + """ + Dataset class for CMMLU dataset. + Data source: https://github.com/haonan-li/CMMLU/tree/master/data + This dataset class will convert the original dataset into the inference dataset. + """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"dev": {}, "test": {}} + for split in ["dev", "test"]: + files = os.listdir(os.path.join(path, split)) + files.sort() + + for file in files: + subject = file[0 : -len(".csv")] + subject = cmmlu_subject_mapping[subject] + + file_dir = os.path.join(path, split, file) + + dataset[split][subject] = {"data": []} + + # It's been tested that each data sample in one subcategory have same inference arguments. + dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs) + + if split == "test" and few_shot: + dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data( + dataset["dev"][subject]["data"] + ) + + with open(file_dir, encoding="utf-8") as f: + reader = csv.reader(f) + _ = next(reader) + for row in reader: + assert len(row) == 7 + choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}" + data_sample = { + "dataset": "cmmlu", + "split": split, + "category": subject, + "instruction": f"以下是关于{subject}的单项选择题,请直接给出正确答案的选项。", + "input": f"题目:{row[1]}\n{choices}\n答案:", + "output": "", + "target": row[6], + } + + dataset[split][subject]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/colossalai.py b/applications/ColossalEval/colossal_eval/dataset/colossalai.py new file mode 100644 index 0000000000000000000000000000000000000000..54ea478ae5d65ea9de0c3f84008d31539d25b62c --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/colossalai.py @@ -0,0 +1,70 @@ +from collections import defaultdict +from copy import deepcopy +from typing import Dict, List + +from colossal_eval.utils import jload + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +default_inference_kwargs = { + "calculate_loss": False, + "all_classes": None, + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 256, +} + +# You can add your own subcategory questions and specify whether it is a single-choice question or has target answers and need to calculate loss. +single_choice_question = set() +calculate_loss = set() + + +def get_data_per_category(data): + data_per_category = defaultdict(list) + for item in data: + category = item["category"] + data_per_category[category].append(item) + + return data_per_category + + +class ColossalDataset(BaseDataset): + """ + Dataset class for Colossal dataset. + This dataset class will convert the original dataset into the inference dataset. + """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"test": {}} + data = jload(path) + data_per_category = get_data_per_category(data) + categories = list(data_per_category.keys()) + + for category in categories: + dataset["test"][category] = {"data": []} + category_data = data_per_category[category] + + dataset["test"][category]["inference_kwargs"] = deepcopy(default_inference_kwargs) + + if category in calculate_loss: + dataset["test"][category]["inference_kwargs"]["calculate_loss"] = True + if category in single_choice_question: + dataset["test"][category]["inference_kwargs"]["all_classes"] = ["A", "B", "C", "D"] + + for item in category_data: + data_sample = { + "dataset": "colossal", + "split": "test", + "category": category, + "instruction": item["instruction"], + "input": item["input"], + "output": "", + "target": item["target"], + "id": item["id"], + } + dataset["test"][category]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf0639e48829b230813e0fc7387e6165803abc4 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py @@ -0,0 +1,122 @@ +import json +import os +import re +from copy import deepcopy +from typing import Dict, List + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +multi_choice_datasets = [ + "Chinese Lang and Usage MCQs", + "Chinese Modern Lit", + "English Fill in Blanks", + "English Reading Comp", + "Geography MCQs", + "Physics MCQs", + "English Cloze Test", +] + +chinese_qa_datasets = [ + "Biology MCQs", + "Chemistry MCQs", + "Chinese Lang and Usage MCQs", + "Chinese Modern Lit", + "Geography MCQs", + "History MCQs", + "Math I MCQs", + "Math II MCQs", + "Physics MCQs", + "Political Science MCQs", +] +english_qa_datasets = ["English MCQs", "English Fill in Blanks", "English Reading Comp", "English Cloze Test"] + +default_inference_kwargs = { + "calculate_loss": True, + "all_classes": None, + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32, +} + + +def get_all_classes(instruction: str): + letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + pattern = r"([A-Z]\. |[A-Z].|[A-Z]\.)" + options = sorted(list(set(re.findall(pattern, instruction)))) + options = sorted(list(set([string[0] for string in options]))) + + for i in range(len(options)): + if options[i] == letters[i]: + continue + else: + return options[0:i] + return options + + +class GaoKaoBenchDataset(BaseDataset): + """ + Dataset class for GAOKAO-Bench dataset. + Data source: https://github.com/OpenLMLab/GAOKAO-Bench/tree/main/data + This dataset class will convert the original dataset into the inference dataset. + + A few typos needed to be manually corrected in the origin dataset, some of the following is fixed. + Issue link: https://github.com/OpenLMLab/GAOKAO-Bench/issues/20 + 1. Option C missing in index 111 in 2010-2022_Chemistry_MCQs.json + 2. Option B missing "." after it in index 16 in 2012-2022_English_Cloze_Test.json + 3. Option G missing "." after it in index 23 in 2012-2022_English_Cloze_Test.json + """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"test": {}} + for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]: + files = os.listdir(os.path.join(path, "data", category)) + files.sort() + + for file in files: + subject = file[10:-5].split("_") + subject = " ".join(subject) + dataset["test"][subject] = {"data": []} + + file_dir = os.path.join(path, "data", category, file) + + with open(file_dir, encoding="utf-8") as f: + data = json.load(f) + + # It's been tested that each data sample in one subcategory have same inference arguments. + inference_kwargs = deepcopy(default_inference_kwargs) + if category == "Multiple-choice_Questions" and subject not in multi_choice_datasets: + all_classes = get_all_classes(data["example"][0]["question"]) + inference_kwargs["all_classes"] = all_classes + if subject in english_qa_datasets: + inference_kwargs["language"] = "English" + if subject in chinese_qa_datasets: + inference_kwargs["language"] = "Chinese" + + dataset["test"][subject]["inference_kwargs"] = inference_kwargs + + for sample in data["example"]: + # Convert multi-choice answers to a single string. + # We will convert it back when evaluating. + # We do this because if target is a list, it should be only used for multiple target answers. + if subject in multi_choice_datasets: + sample["answer"] = "".join(sample["answer"]) + + if isinstance(sample["answer"], list) and len(sample["answer"]) == 1: + sample["answer"] = sample["answer"][0] + + data_sample = { + "dataset": "gaokaobench", + "split": "test", + "category": f"{category[:-10]}-{subject}", + "instruction": sample["question"].strip() + "\n答案:", + "input": "", + "output": "", + "target": sample["answer"], + } + + dataset["test"][subject]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/longbench.py b/applications/ColossalEval/colossal_eval/dataset/longbench.py new file mode 100644 index 0000000000000000000000000000000000000000..9ea5e3c7d77fc9f8ee8fd496b67fefcdf625257a --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/longbench.py @@ -0,0 +1,120 @@ +import os +from copy import deepcopy +from typing import Dict, List + +from colossal_eval.utils import get_json_list + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +dataset2prompt = { + "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", + "qasper": 'You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:', + "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", + "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", + "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", + "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", + "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", + "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", + "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", + "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", + "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", + "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", + "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", + "passage_retrieval_en": 'Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like "Paragraph 1", "Paragraph 2", etc.\n\nThe answer is: ', + "passage_retrieval_zh": '以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是"段落1","段落2"等格式\n\n答案是:', + "lcc": "Please complete the code given below. \n{context}Next line of code:\n", + "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n", +} + +dataset2maxlen = { + "narrativeqa": 128, + "qasper": 128, + "multifieldqa_en": 64, + "multifieldqa_zh": 64, + "hotpotqa": 32, + "2wikimqa": 32, + "musique": 32, + "dureader": 128, + "gov_report": 512, + "qmsum": 512, + "multi_news": 512, + "vcsum": 512, + "trec": 64, + "triviaqa": 32, + "samsum": 128, + "lsht": 64, + "passage_count": 32, + "passage_retrieval_en": 32, + "passage_retrieval_zh": 32, + "lcc": 64, + "repobench-p": 64, +} + +default_inference_kwargs = { + "calculate_loss": True, + "all_classes": None, + "language": "Chinese", + "pretrain": False, + "max_new_tokens": 32, +} + + +class LongBenchDataset(BaseDataset): + """ + Dataset class for LongBench dataset. + Data source: https://huggingface.co/datasets/THUDM/LongBench + This dataset class will convert the original dataset into the inference dataset. + + Issue link: https://github.com/THUDM/LongBench/issues/15 (fixed) + There are duplicate target answers in `nq.jsonl`, but this doesn't affect evaluation results. + Also doesn't affect perplexity calculation (the program only need to select the minimum loss). + """ + + @staticmethod + def load(path: str, logger: DistributedLogger) -> List[Dict]: + dataset = {"test": {}} + + files = os.listdir(path) + files.sort() + + for file in files: + category = file[0:-6] + + if category.endswith("_e"): + continue + + dataset["test"][category] = {"data": []} + + file_dir = os.path.join(path, file) + + loaded_jsonl = get_json_list(file_dir) + + # It's been tested that each data sample in one subcategory have same inference arguments. + inference_kwargs = deepcopy(default_inference_kwargs) + if loaded_jsonl[0]["all_classes"] is not None: + inference_kwargs["all_classes"] = loaded_jsonl[0]["all_classes"] + inference_kwargs["max_new_tokens"] = dataset2maxlen[category] + dataset["test"][category]["inference_kwargs"] = inference_kwargs + + for sample in loaded_jsonl: + prompt = dataset2prompt[category].format(**sample) + + data_sample = { + "dataset": "longbench", + "split": "test", + "category": category, + "instruction": prompt, + "input": "", + "output": "", + "target": sample["answers"], + } + + dataset["test"][category]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py new file mode 100644 index 0000000000000000000000000000000000000000..b89c0a13cff1d82d5435ed52f446fd4092bd5093 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py @@ -0,0 +1,73 @@ +import copy +import csv +import os +from typing import Dict, List + +from colossalai.logging import DistributedLogger + +from .base import BaseDataset + +default_inference_kwargs = { + "calculate_loss": True, + "all_classes": ["A", "B", "C", "D"], + "language": "English", + "pretrain": False, + "max_new_tokens": 32, +} + + +def get_few_shot_data(data: List[Dict]): + few_shot_data = [] + for i in data: + few_shot_data.append(i["input"] + i["target"]) + return few_shot_data + + +class MMLUDataset(BaseDataset): + """ + Dataset class for MMLU dataset. + Data source: https://github.com/hendrycks/test + This dataset class will convert the original dataset into the inference dataset. + """ + + @staticmethod + def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + dataset = {"dev": {}, "test": {}} + for split in ["dev", "test"]: + files = os.listdir(os.path.join(path, split)) + files.sort() + + for file in files: + subject = file[0 : -len(f"_{split}.csv")].split("_") + subject = " ".join([word.title() if word != "us" else "US" for word in subject]) + + file_dir = os.path.join(path, split, file) + + dataset[split][subject] = {"data": [], "inference_kwargs": {}} + + # It's been tested that each data sample in one subcategory have same inference arguments. + dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs) + + if split == "test" and few_shot: + dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data( + dataset["dev"][subject]["data"] + ) + + with open(file_dir, encoding="utf-8") as f: + reader = csv.reader(f) + for row in reader: + assert len(row) == 6 + choices = f"A. {row[1]}\nB. {row[2]}\nC. {row[3]}\nD. {row[4]}" + data_sample = { + "dataset": "mmlu", + "split": split, + "category": subject, + "instruction": f"The following is a single-choice question on {subject}. Answer the question by replying A, B, C or D.", + "input": f"Question: {row[0]}\n{choices}\nAnswer: ", + "output": "", + "target": row[5], + } + + dataset[split][subject]["data"].append(data_sample) + + return dataset diff --git a/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md b/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md new file mode 100644 index 0000000000000000000000000000000000000000..37fbda4c8647322c0cd9db0aea7ca7139771b4a4 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md @@ -0,0 +1,248 @@ +# GPT Evaluation +## Table of Contents +- [Overview](#overview) +- [GPT Evaluation](#gpt-evaluation) + - [Evaluation Category](#evaluation-category) + - [Evaluation Category Examples](#evaluation-category-examples) + - [Evaluation Metrics](#evaluation-metrics) +- [Evaluation Process](#evaluation-process) + - [Data Format](#data-format) + - [Prompt](#prompt) + - [Battle Prompt](#battle-prompt) + - [Evaluation Prompt](#evaluation-prompt) + - [Evaluation](#evaluation) + - [Configuration](#configuration) + - [Evaluate](#evaluate) +- [FAQ](#faq) +- [Citations](#citations) + + +## Overview + +In this directory, we introduce how you can evaluate your model using GPTs. It is now available for evaluation of both Chinese and English capability and we provide the following functions: + +* Compare the performance of two different models (battle). +* Rate the model according to pre-defined metrics using prompting design. +* Rate the model according to pre-defined metrics with additional reference answer using prompting design. + +## GPT Evaluation + +### Evaluation Category + +Our evaluation pipeline can examine the model's capability using different categories of questions. The following table includes some example categories. You can add your own questions. + +| Evaluation Category | Description | +| :-----------------: | :----------------------------------------------------------- | +| Brainstorming | Models are asked to generate a range of creative and diverse ideas according to the question. The capability of creativity is required. | +| Chat | Models are asked to continue a multi-round dialogue given the roles involved. The capability of understanding, memorizing previous rounds of the dialogue and answering according to the persona provided is required. | +| Generation | Models are asked to generate an email, letter, article, etc. The capability of generating texts in a high quality and human-written way is required. | +| Open QA | Models are asked to answer an open QA question(without context provided). The capability of answering questions with the models' own knowledge base is required. | +| Roleplay | Models are asked to play the role provided. The capability of engaging in the scenario and effectively interacting with the user is required. | + + +### Evaluation Category Examples +To better understand each evaluation category, here are some example questions provided. Example questions are in the `configs/gpt_evaluation/data` folder. + + +| Evaluation Category | Chinese Example | English Example | +| :-----------------: | :----------------------------------------------------------- | :----------------------------------------------------------- | +| Brainstorming | 列举一些可以促进头发生长的食物。 | How do you properly chop an onion without crying? | +| Chat | 基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。
        小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。
        老李:你好,小张,我很乐意帮助你。你想问些什么?
        小张:我想知道如何确定鸡的品种和性别?
        老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗?
        小张:
        | Complete a dialogue based on the following character information. Alex: A novice writer who is struggling to find inspiration and develop his writing skills. Emma: A successful author with many published works, providing guidance and advice to Alex.
        Alex: Hi Emma, I have been writing for a while now but can't seem to make any progress. Can you give me any advice?
        Emma: Hi Alex, sure. What kind of writing are you doing?
        Alex: I'm trying to write a novel, but I just can't seem to find any inspiration.
        Emma:
        | +| Generation | 请为一家咖啡店编写一篇简短的广告语,吸引更多的顾客。 | Write a set of guidelines for first-time pet owners on how to properly care for a new puppy. | +| Open QA | 解释什么是RNA病毒和DNA病毒。 | Explain the process of osmosis in biological systems. | +| Roleplay | 我要你把我写的句子翻译成表情符号。我会写句子,你会用表情符号表达它。我只是想让你用表情符号来表达它。除了表情符号,我不希望你回复任何内容。当我需要用中文告诉你一些事情时,我会用 {} 这样的大括号括起来。我的第一句话是“{我的职业是消防员。}” | I want you to act as a rapper. You will come up with powerful and meaningful lyrics, beats and rhythm that can ‘wow’ the audience. Your lyrics should have an intriguing meaning and message which people can relate too. When it comes to choosing your beat, make sure it is catchy yet relevant to your words, so that when combined they make an explosion of sound everytime! My first request is "I need a rap song about finding strength within yourself." | + +### Evaluation Metrics + +GPT evaluation uses GPT models to evaluate the prediction of different models and different pre-defined evaluation metrics are applied to different categories. The following table shows the 10 pre-defined evaluation metrics both in Chinese and English: + +| Evaluation Metric | Prompt Words | CoT(Chain-of-Thought) | +| :-------------------: | :----------------------------------------------------------- | :----------------------------------------------------------- | +| 语言组织
        (Language organization) | 语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。

        Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc. | 1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。
        2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说
        3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。
        4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。
        5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。
        6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。

        1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.
        2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.
        3. Determine if the answer is relevant to the question or topic and conveys a clear message.
        4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.
        5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.
        6. Evaluate the linguistic organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good linguistic organization and 1 indicates very poor linguistic organization. | +| 切题
        (Relevance) | 切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。

        Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic. | 1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。
        2. 阅读答案,确认答案是否直接回答了题目所问的问题。
        3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。
        4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。

        1. Read the question to determine what the question asks and what aspects of the question need to be answered.
        2. Read the answers to make sure that they directly answer the question asked.
        3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.
        4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all. | +| 创意性
        (Creativity) | 创意性(1-5):某些头脑风暴问题可能需要答案具有创意,提出新的思路。

        Creativity (1-5): Some brainstorming questions may require answers that are creative and suggest new ideas. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。
        2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则创意性评分可能会受到影响。
        3. 考虑答案中是否包含新颖的想法或独特的思路。答案可能与已知的解决方案有所重叠,但仍然可以被认为是有创意的,只要它提供了新的角度或方法来解决问题。
        4. 根据答案的创意性,给出一个1到5的评分。如果答案缺乏创意,则应给出一个较低的评分。如果答案具有创意并提供了新的思路,应给出一个较高的评分。

        1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.
        2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the creativity score may be affected.
        3. Consider whether the answer contains novel ideas or unique thoughts. An answer may overlap with a known solution and still be considered creative, as long as it offers a new perspective or approach to the problem.
        4. Give a score of 1 to 5 depending on the creativity of the answer. If the answer lacks creativity, a lower score should be given. If the answer is creative and provides a new idea, a higher score should be given. | +| 实用性
        (Practicality) | 实用性(1-5):某些头脑风暴问题可能需要答案提出实用的建议或解决方法。

        Practicality (1-5): Some brainstorming questions may require answers to suggest practical suggestions or solutions. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。
        2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则实用性评分可能会受到影响。
        3. 考虑答案中提出的建议或解决方法是否实用并可行。答案可能看起来很好,但如果无法实现或应用,则实用性评分可能会受到影响。
        4. 根据答案的实用性,给出一个1到5的评分。如果答案缺乏实用性,则应给出一个较低的评分。如果答案提出了实用的建议或解决方法,并且可以很好地解决问题,则应给出一个较高的评分。

        1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.
        2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the practicality score may be affected.
        3. Consider whether the suggestions or solutions presented in the answer are practical and workable. The answer may look good, but if it cannot be implemented or applied, the practicality score may be affected.
        4. Give a score of 1 to 5 depending on the practicality of the answer. If the answer lacks practicality, a lower score should be given. If the answer makes a practical suggestion or solution and solves the problem well, a higher score should be given. | +| 正确性
        (Correctness) | 正确性(1-5):正确性(1-5):答案是否正确。

        Correctness (1-5): whether the answer is correct or not. | 1. 仔细阅读题目,尝试自己回答该问题。
        2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。

        1. Read the question carefully and try to answer the question yourself.
        2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be given. If the answer is completely incorrect, only 1 point is awarded. | +| 自然
        (Naturalness) | 自然(1-5):答案是否自然,并且符合问题给定的身份。

        Naturalness (1-5): whether the answer is natural and fits the identity given by the question. | 1. 阅读题目,确定题目提供的身份信息。
        2. 检查答案内容是否符合题目给定的身份。
        3. 根据以上因素,对该回答的自然性进行打分,分数从1到5,其中1表示不自然,5表示非常自然,并符合问题给定的身份。

        1. Read the question and determine the identity information provided in the question.
        2. Check whether the content of the answer matches the identity given in the question.
        3. Based on the above factors, score the naturalness of the response on a scale from 1 to 5, where 1 means unnatural and 5 means very natural and in accordance with the identity given in the question. | +| 参与感
        (Engagingness) | 参与感(1-5):答案是否对前面的对话内容做出了恰当的反应,是否理解对话的语境和背景。

        Engagingness (1-5): whether the answer responds appropriately to the content of the preceding conversation and whether it understands the context and background of the conversation. | 1. 阅读题目,确定对话的语境和背景。
        2. 检查答案是否充分理解对话的语境和背景,能否自然地融入到对话中而不显得突兀。
        3. 根据以上因素,对该回答的参与感进行打分,分数从1到5,其中1表示没有参与感,5表示非常有参与感,并且恰当地理解了对话的语境和背景。

        1. Read the questions to determine the context and background of the dialogue.
        2. Check that the answer fully understands the context and background of the conversation and that it fits naturally into the conversation without seeming abrupt.
        3. Based on the above factors, rate the response's engagement on a scale from 1 to 5, where 1 means not engaged and 5 means very engaged and appropriately understands the context and background of the conversation. | +| 合理性
        (Reasonableness) | 合理性(1-5):答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。

        Reasonableness (1-5): Whether the answer can form a logical connection with the content of the previous dialogue, whether it is consistent with common sense, and whether it can reasonably exist in this context. | 1. 阅读题目,确定对话的主题以及问题期望的回答方向。
        2. 判断答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。
        3. 根据以上因素,对该回答的合理性进行打分,分数从1到5,其中1表示不合理,5表示非常合理,并且能够与前面的对话内容形成逻辑上的衔接,并符合常理。

        1. Read the question and determine the topic of the conversation and the direction the question expects the answer to go.
        2. Determine whether the answer can be logically connected to the preceding conversation, whether it makes common sense, and whether it can reasonably exist in this context.
        3. Based on the above factors, rate the reasonableness of the answer on a scale from 1 to 5, where 1 means unreasonable and 5 means very reasonable and able to form a logical connection with the preceding dialogue content and consistent with common sense. | +| 多样性
        (Diversity) | 多样性(1-5):答案使用语言是否优美,具有有一定的创造性和想象力。然而,回答也应该保持合理和适度,不要过于夸张或离题。

        Diversity (1-5): Whether the answers use beautiful language and have some creativity and imagination. However, answers should also be kept reasonable and moderate, not overly exaggerated or off-topic. | 1. 仔细阅读整个回答,确保完全理解回答所表达的内容和主题。
        2. 在阅读回答的同时,注意语言的质量,例如措辞是否正确,语言是否生动等。
        3. 检查回答的创造性和想象力,看看回答是否能够吸引人阅读下去。
        4. 检查回答的合理性和适度,看看回答是否夸张或离题。5. 将多样性的评分打分在1到5之间,5分表示回答的质量很好,能够吸引人阅读,1分表示回答的内容生硬或者有离题的问题。

        1. Read the entire response carefully to ensure that you fully understand the content and theme expressed in the response.
        2. While reading the response, pay attention to the quality of the language, such as whether the wording is correct and the language is vivid.
        3. Check the creativity and imagination of the response to see if the response is engaging to read on.
        4. Check the reasonableness and appropriateness of the responses to see if the responses are exaggerated or off-topic.
        5. Rate the diversity on a scale of 1 to 5, with a 5 indicating a good quality response that is engaging to read and a 1 indicating a raw response or a question that is off-topic. | +| 保真度
        (Fidelity) | 保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。

        Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting. | 1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。
        阅读题目的请求,确认回答请求时需要注意的细节。
        3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。
        4. 结合以上评估结果给出保真度的评分,范围从1到5分,其中1分表示回答与角色设定完全不符,5分表示回答完全符合角色设定且满足给定请求。

        1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.
        2. Read the question's request and confirm the details that need to be taken into account when answering the request.
        3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.
        4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request. | + +GPT models evaluate the quality of model predictions based on the given prompt words and gives a score between 1-5. + +> **NOTE 1:** You can find all the prompt words and CoT(Chain-of-Thought) in `configs/gpt_evaluation/prompt/evaluation_prompt`. + +> **NOTE 2:** To add customized metrics, you can refer to [FAQ](#faq). + +## Evaluation Process + +### Data Format + +A JSON file contains one list. Each element in the list is a target answer / prediction record for one instruction / question. +An element should have the following fields: + +* `category` (str, compulsory): The category of the instruction / question. +* `instruction` (str, compulsory): The instruction / question for the LLM. +* `input` (str, optional): The additional context of the instruction / question. +* `output` (str, optional): The model output of the instruction, models will fill in this field during inference time. +* `target` (str, optional): The target answer for the instruction. +* `id` (int, compulsory): The ID of the instruction / question. + +Example: + +```json +[ + { + "category": "brainstorming", + "instruction": "请问如何制作一份美味的西红柿炒鸡蛋?", + "input": "", + "output": "", + "target": "", + "id": 1 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。", + "input": "小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。 老李:你好,小张,我很乐意帮助你。你想问些什么? 小张:我想知道如何确定鸡的品种和性别? 老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗? 小张:", + "output": "", + "target": "", + "id": 2 + } +] +``` + +### Prompt + +#### Battle Prompt + +The following is the Chinese battle prompt. In the battle prompt, the question and answers from two different models are fed into the prompt template. You can find example battle prompt files for Chinese and English in `configs/gpt_evaluation/prompt/battle_prompt`. + +```json +{ + "id": 1, + "system_prompt": "你是一个检查回答质量的好助手。", + "prompt_template": "[问题]\n{question}\n\n[1号AI助手的答案]\n{answer_1}\n\n[1号AI助手答案终止]\n\n[2号AI助手的答 案]\n{answer_2}\n\n[2号AI助手答案终止]\n\n[要求]\n{prompt}\n\n", + "prompt": "我们需要你评价这两个AI助手回答的性能。\n请对他们的回答的有用性、相关性、准确性、详细程度进行评分。每个AI助手都会得到一个1到10分的总分,分数越高表示整体表现越好。\n请首先输出一行,该行只包含两个数值,分别表示1号和2号AI助手的分数。这两个分数之间要有一个空格。在随后的一行中,请对你的评价作出全面的解释,避免任何潜在的偏见,并确保AI助手回答的顺序不会影响您的判断。" +} +``` + +#### Evaluation Prompt + +The following is an example of a Chinese GPT evaluation prompt. In an evaluation prompt, you should define your metrics in `metrics` and provide CoT(Chain-of-Thought) in `CoT`. You can find example evaluation prompt files for Chinese and English in `configs/gpt_evaluation/prompt/evaluation_prompt`. + +```json +{ + "brainstorming": { + "id": 1, + "category": "brainstorming", + "metrics": { + "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。" + }, + "CoT": { + "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:" + }, + "prompt": "你是一个好助手。请你为下面“头脑风暴”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" + } +} +``` + +`"metrics"`: the metrics that can be used in GPT evaluation. This field determines which metrics can be added to your config file. + +`"CoT"`: evaluation steps you prompt to GPT models for each metric defined in `"metrics"`. + +### Evaluation + +#### Configuration + +The following is an example of a Chinese config file. The configuration file can control how the pipeline evaluates the model. You need to specify GPT evaluation metrics in key `GPT`. You can find an example English config file in `configs/gpt_evaluation/config/config_en.json`. + +```json +{ + "language": "cn", + "category": { + "brainstorming": { + "GPT": [ + "language organization", + "relevance", + "creativity", + "practicality", + "reasonableness" + ] + } + } +} +``` + +`"language"`: the language used to evaluate the model capability. We only support Chinese `"cn"` for now. + +`"category"`: the category/categories needed to evaluate the model capability. + +`"GPT"`: the metrics you want to use for GPT evaluation. + + +#### Evaluate + +After setting the configuration file, you can evaluate the model using `examples/gpt_evaluation/eval.py`. If you want to make comparisons between answers of two different models, you should specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list`. If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1 and the program will perform evaluation using automatic metrics and GPT models. + +An example script is provided as follows: + +```shell +python eval.py \ + --config_file "path to the config file" \ + --battle_prompt_file "path to the prompt file for battle" \ + --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \ + --target_file "path to the target answer file" \ + --answer_file_list "path to the answer files of at most 2 models" \ + --model_name_list "the names of at most 2 models" \ + --gpt_model "which GPT model to use for evaluation" \ + --save_path "path to save results" \ + --openai_key "your openai key" \ +``` + +If you want GPT evaluation with reference, you can add an argument `--gpt_with_reference`, but make sure the reference file have target answers. + +## FAQ + +
        How can I add a new GPT evaluation metric? + +For example, if you want to add a new metric `persuasiveness` into category `brainstorming`, you should add the metric definition and its corresponding CoT(Chain-of-thought) in the evaluation prompt file in `prompt/evaluation_promt`. The CoT can be generated using ChatGPT. You can prompt ChatGPT to generate evaluation steps for the new metric. + +```json +{ + "brainstorming": { + "id": 1, + "category": "brainstorming", + "metrics": { + "persuasiveness": "persuasiveness(1-5):a short description for persuasiveness" + }, + "CoT": { + "persuasiveness": "CoT for persuasiveness\n\npersuasiveness:" + }, + "prompt": "You are a good assistant. Please rate the given answer to the \"brainstorming\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" + } +} +``` + +
        + +## Citations + +```bibtex +@misc{vicuna2023, + title = {Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90\%* ChatGPT Quality}, + url = {https://vicuna.lmsys.org}, + author = {Chiang, Wei-Lin and Li, Zhuohan and Lin, Zi and Sheng, Ying and Wu, Zhanghao and Zhang, Hao and Zheng, Lianmin and Zhuang, Siyuan and Zhuang, Yonghao and Gonzalez, Joseph E. and Stoica, Ion and Xing, Eric P.}, + month = {March}, + year = {2023} +} + +@misc{liu2023geval, + title={G-Eval: NLG Evaluation using GPT-4 with Better Human Alignment}, + author={Yang Liu and Dan Iter and Yichong Xu and Shuohang Wang and Ruochen Xu and Chenguang Zhu}, + year={2023}, + eprint={2303.16634}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` diff --git a/tests/test_layers/test_1d/checks_1d/__init__.py b/applications/ColossalEval/colossal_eval/evaluate/__init__.py similarity index 100% rename from tests/test_layers/test_1d/checks_1d/__init__.py rename to applications/ColossalEval/colossal_eval/evaluate/__init__.py diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5df09a6909b0b5fdc6c726e2543012b94a4410 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py @@ -0,0 +1,3 @@ +from .dataset_evaluator import DatasetEvaluator + +__all__ = ["DatasetEvaluator"] diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..c70988707a154cc50126797c5b13949acef3b97e --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py @@ -0,0 +1,269 @@ +from typing import Dict, List + +import colossal_eval.evaluate.dataset_evaluator.metrics as metric_helper +import numpy as np +import tqdm + +LabelBasedMetrics = ["first_token_accuracy", "matthews_correlation"] +LossBasedMetrics = ["perplexity", "ppl_score", "ppl_score_over_choices", "per_byte_perplexity", "per_byte_ppl_score"] +CombinedMetrics = ["combined_single_choice_accuracy"] +OtherMetrics = [ + "f1_score", + "f1_zh_score", + "rouge_score", + "rouge_zh_score", + "retrieval_score", + "retrieval_zh_score", + "classification_score", + "code_sim_score", + "count_score", + "multi_choice_accuracy", + "math_equivalence", + "single_choice_accuracy", +] + + +class DatasetEvaluator(object): + """ + Dataset evaluator. + + """ + + def __init__(self): + pass + + def _calculate_label_metrics(self, metric: str, category: str): + """Calculate label-based metrics.""" + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + + str_label_map = { + choice: idx for idx, choice in enumerate(self.data[category]["inference_kwargs"]["all_classes"]) + } + + references = [str_label_map[sample["target"]] for sample in self.data[category]["data"]] + [sample["output"] for sample in self.data[category]["data"]] + + flag = False + softmaxs = [] + for i, sample in enumerate(self.data[category]["data"]): + if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))): + if not flag: + print( + f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}." + ) + flag = True + score = 0 + for ref in sample["target"]: + score = max( + score, + metric_helper.single_choice_accuracy( + sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"] + ), + ) + softmaxs.append(references[i] if score == 1 else -1) + else: + softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values())))) + + references = np.array(references) + softmaxs = np.array(softmaxs) + scores = np.sum(references == softmaxs) / len(self.data[category]["data"]) * 100 + + self.evaluation_results[metric][category] = (scores, len(self.data[category]["data"])) + self.evaluation_results[metric]["ALL"] += scores * weight + + def _calculate_combined_metrics(self, metric: str, category: str): + """Calculate combined metrics.""" + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + + references = [sample["target"] for sample in self.data[category]["data"]] + predictions = [sample["output"] for sample in self.data[category]["data"]] + + str_label_map = { + choice: idx for idx, choice in enumerate(self.data[category]["inference_kwargs"]["all_classes"]) + } + + references_labels = [str_label_map[sample["target"][0]] for sample in self.data[category]["data"]] + predictions = [sample["output"] for sample in self.data[category]["data"]] + + flag = False + softmaxs = [] + for i, sample in enumerate(self.data[category]["data"]): + if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))): + if not flag: + print( + f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}." + ) + flag = True + score = 0 + for ref in sample["target"]: + score = max( + score, + metric_helper.single_choice_accuracy( + sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"] + ), + ) + softmaxs.append(references[i] if score == 1 else -1) + else: + softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values())))) + + metric_method = eval("metric_helper." + metric) + + total_score = 0.0 + for prediction, reference, references_label, softmax in zip( + predictions, references, references_labels, softmaxs + ): + score = 0.0 + + for ref in reference: + score = max( + score, + metric_method(prediction, ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]), + ) + if references_label == softmax: + score = 1 + + total_score += score + total_score = total_score * 100 / len(self.data[category]["data"]) + + self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"])) + self.evaluation_results[metric]["ALL"] += total_score * weight + + def _calculate_other_metrics(self, metric: str, category: str): + """Calculate other metrics.""" + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + + references = [sample["target"] for sample in self.data[category]["data"]] + predictions = [sample["output"] for sample in self.data[category]["data"]] + + metric_method = eval("metric_helper." + metric) + + total_score = 0.0 + for prediction, reference in zip(predictions, references): + score = 0.0 + for ref in reference: + score = max( + score, + metric_method(prediction, ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]), + ) + total_score += score + total_score = total_score * 100 / len(predictions) + + self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"])) + self.evaluation_results[metric]["ALL"] += total_score * weight + + def _calculate_loss_metrics(self, metric: str, category: str): + """Calculate perplexity.""" + if metric == "perplexity": + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + losses = [min(sample["loss"]) for sample in self.data[category]["data"]] + perplexity = np.mean(np.exp(np.array(losses))) + + self.evaluation_results["perplexity"][category] = (perplexity, len(self.data[category]["data"])) + self.evaluation_results["perplexity"]["ALL"] += perplexity * weight + elif metric == "ppl_score": + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + losses = [min(sample["loss"]) for sample in self.data[category]["data"]] + perplexity_score = np.mean(np.exp(-np.array(losses))) * 100 + + self.evaluation_results["ppl_score"][category] = (perplexity_score, len(self.data[category]["data"])) + self.evaluation_results["ppl_score"]["ALL"] += perplexity_score * weight + elif metric == "ppl_score_over_choices" and self.data[category]["inference_kwargs"]["all_classes"] is not None: + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + loss_over_choices = [sample["loss_over_choices"] for sample in self.data[category]["data"]] + perplexity_score_over_choices = np.mean(np.exp(-np.array(loss_over_choices))) * 100 + + self.evaluation_results["ppl_score_over_choices"][category] = ( + perplexity_score_over_choices, + len(self.data[category]["data"]), + ) + self.evaluation_results["ppl_score_over_choices"]["ALL"] += perplexity_score_over_choices * weight + elif metric == "per_byte_perplexity": + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + losses = [min(sample["loss_sum"]) for sample in self.data[category]["data"]] + perplexity = np.mean(np.exp(np.array(losses) / np.array(self.N_bytes[category]))) + + self.evaluation_results["per_byte_perplexity"][category] = perplexity + self.evaluation_results["per_byte_perplexity"]["ALL"] += perplexity * weight + elif metric == "per_byte_ppl_score": + weight = len(self.data[category]["data"]) / self.metric_total_length[metric] + losses = [min(sample["loss_sum"]) for sample in self.data[category]["data"]] + perplexity_score = np.mean(np.exp(-np.array(losses) / np.array(self.N_bytes[category]))) * 100 + + self.evaluation_results["per_byte_ppl_score"][category] = perplexity_score + self.evaluation_results["per_byte_ppl_score"]["ALL"] += perplexity_score * weight + + def _evaluate(self): + """Calculate and return evaluation results""" + + for metric in self.metrics: + pbar = tqdm.tqdm( + desc=f"{self.dataset_name}-{metric}-{self.model_name}", total=len(self.suggested_categories[metric]) + ) + + if metric in LabelBasedMetrics: + for category in self.suggested_categories[metric]: + self._calculate_label_metrics(metric, category) + pbar.update(1) + elif metric in LossBasedMetrics: + for category in self.suggested_categories[metric]: + self._calculate_loss_metrics(metric, category) + pbar.update(1) + elif metric in CombinedMetrics: + for category in self.suggested_categories[metric]: + self._calculate_combined_metrics(metric, category) + pbar.update(1) + elif metric in OtherMetrics: + for category in self.suggested_categories[metric]: + self._calculate_other_metrics(metric, category) + pbar.update(1) + + return self.evaluation_results + + def get_evaluation_results(self, data: List[Dict], dataset_name: str, model_name: str, metrics: List[str]): + """ + Evaluate inference data on the given metrics. + + Args: + data: Data to be evaluated. + dataset_name: Name of the dataset + model_name: Name of the model + metrics: Metrics used to evaluate. + + """ + self.data = data + self.dataset_name = dataset_name + self.model_name = model_name + self.categories = list(data.keys()) + self.metrics = metrics + + self.evaluation_results = { + metric: {category: 0 for category in (["ALL"] + self.categories)} for metric in self.metrics + } + + self.total_length = 0 + self.total_single_choices = 0 + for value in self.data.values(): + self.total_length += len(value["data"]) + if value["inference_kwargs"]["all_classes"] is not None: + self.total_single_choices += len(value["data"]) + + self.metric_total_length = {metric: 0 for metric in self.metrics} + self.suggested_categories = {metric: [] for metric in self.metrics} + + for metric in self.metrics: + self.suggested_categories[metric] = metric_helper.metrics4subcategory[self.dataset_name][metric] + if "ALL" in self.suggested_categories[metric]: + self.suggested_categories[metric] = self.categories + self.metric_total_length[metric] = self.total_length + continue + for category in self.suggested_categories[metric]: + self.metric_total_length[metric] += len(self.data[category]["data"]) + + if "per_byte_perplexity" in self.metrics or "per_byte_ppl_score" in self.metrics: + self.N_bytes = {category: [] for category in self.categories} + for category in self.categories: + samples = self.data[category]["data"] + for sample in samples: + self.N_bytes[category].append(sample["byte_num"][0]) + + return self._evaluate() diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..914465478dec92be62117621989650eb8eed3904 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py @@ -0,0 +1,623 @@ +# Code adapted from https://github.com/THUDM/LongBench/blob/main/metrics.py +# Code adapted from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py +# Code adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/evaluation.py + +import difflib +import re +import string +from collections import Counter + +import jieba +from fuzzywuzzy import fuzz +from rouge import Rouge + +metrics4subcategory = { + "pretrain": { + "perplexity": ["ALL"], + "ppl_score": ["ALL"], + "per_byte_perplexity": ["ALL"], + "per_byte_ppl_score": ["ALL"], + }, + # The commented are non 4-choice questions. + "agieval": { + "combined_single_choice_accuracy": [ + # "lsat-ar", + # "lsat-lr", + # "lsat-rc", + "logiqa-en", + "sat-math", + "sat-en", + # "aqua-rat", + "sat-en-without-passage", + "gaokao-english", + "logiqa-zh", + "gaokao-chinese", + "gaokao-geography", + "gaokao-history", + "gaokao-biology", + "gaokao-chemistry", + ], + "first_token_accuracy": [ + # "lsat-ar", + # "lsat-lr", + # "lsat-rc", + "logiqa-en", + "sat-math", + "sat-en", + # "aqua-rat", + "sat-en-without-passage", + "gaokao-english", + "logiqa-zh", + "gaokao-chinese", + "gaokao-geography", + "gaokao-history", + "gaokao-biology", + "gaokao-chemistry", + ], + "single_choice_accuracy": [ + # "lsat-ar", + # "lsat-lr", + # "lsat-rc", + "logiqa-en", + "sat-math", + "sat-en", + # "aqua-rat", + "sat-en-without-passage", + "gaokao-english", + "logiqa-zh", + "gaokao-chinese", + "gaokao-geography", + "gaokao-history", + "gaokao-biology", + "gaokao-chemistry", + ], + "multi_choice_accuracy": ["jec-qa-kd", "jec-qa-ca", "gaokao-physics", "gaokao-mathqa"], + "math_equivalence": ["gaokao-mathcloze", "math"], + "perplexity": ["ALL"], + "ppl_score_over_choices": [ + "lsat-ar", + "lsat-lr", + "lsat-rc", + "logiqa-en", + "sat-math", + "sat-en", + "aqua-rat", + "sat-en-without-passage", + "gaokao-english", + "logiqa-zh", + "jec-qa-kd", + "jec-qa-ca", + "gaokao-chinese", + "gaokao-geography", + "gaokao-history", + "gaokao-biology", + "gaokao-chemistry", + "gaokao-physics", + "gaokao-mathqa", + ], + "ppl_score": ["ALL"], + }, + "cmmlu": { + "first_token_accuracy": ["ALL"], + "single_choice_accuracy": ["ALL"], + "perplexity": ["ALL"], + "ppl_score_over_choices": ["ALL"], + "ppl_score": ["ALL"], + }, + "gaokaobench": { + "combined_single_choice_accuracy": [ + "English MCQs", + "Biology MCQs", + "Chemistry MCQs", + "History MCQs", + "Math I MCQs", + "Math II MCQs", + "Political Science MCQs", + ], + "first_token_accuracy": [ + "English MCQs", + "Biology MCQs", + "Chemistry MCQs", + "History MCQs", + "Math I MCQs", + "Math II MCQs", + "Political Science MCQs", + ], + "single_choice_accuracy": [ + "English MCQs", + "Biology MCQs", + "Chemistry MCQs", + "History MCQs", + "Math I MCQs", + "Math II MCQs", + "Political Science MCQs", + ], + "multi_choice_accuracy": [ + "Chinese Lang and Usage MCQs", + "Chinese Modern Lit", + "English Fill in Blanks", + "English Reading Comp", + "Geography MCQs", + "Physics MCQs", + "English Cloze Test", + ], + "math_equivalence": ["Math I Fill-in-the-Blank", "Math II Fill-in-the-Blank"], + "rouge_score": ["English Language Cloze Passage"], + "rouge_zh_score": [ + "Chinese Language Famous Passages and Sentences Dictation", + "Chemistry Open-ended Questions", + "History Open-ended Questions", + "Biology Open-ended Questions", + "Political Science Open-ended Questions", + "English Language Error Correction", + "Chinese Language Language and Writing Skills Open-ended Questions", + "Math II Open-ended Questions", + "Chinese Language Literary Text Reading", + "Chinese Language Ancient Poetry Reading", + "Chinese Language Classical Chinese Reading", + "Physics Open-ended Questions", + "Math I Open-ended Questions", + "Geography Open-ended Questions", + "Chinese Language Practical Text Reading", + ], + "perplexity": ["ALL"], + "ppl_score_over_choices": ["ALL"], + "ppl_score": ["ALL"], + }, + "longbench": { + "f1_score": ["hotpotqa", "2wikimqa", "musique", "narrativeqa", "qasper", "multifieldqa_en", "triviaqa"], + "f1_zh_score": ["multifieldqa_zh"], + "rouge_score": ["gov_report", "qmsum", "multi_news", "samsum"], + "rouge_zh_score": ["dureader", "vcsum"], + "retrieval_score": ["passage_retrieval_en"], + "retrieval_zh_score": ["passage_retrieval_zh"], + "classification_score": ["trec", "lsht"], + "code_sim_score": ["lcc", "repobench-p"], + "count_score": ["passage_count"], + "perplexity": ["ALL"], + "ppl_score": ["ALL"], + }, + "mmlu": { + "first_token_accuracy": ["ALL"], + "single_choice_accuracy": ["ALL"], + "accuracy": ["ALL"], + "perplexity": ["ALL"], + "ppl_score_over_choices": ["ALL"], + "ppl_score": ["ALL"], + }, +} + + +def _fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def _fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except: + return string + + +def _remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + +def _fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def _strip_string(string): + # linebreaks + string = string.replace("\n", "") + # print(string) + + # remove inverse spaces + string = string.replace("\\!", "") + # print(string) + + # replace \\ with \ + string = string.replace("\\\\", "\\") + # print(string) + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + # print(string) + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + # print(string) + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = _remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = _fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string + + +def parse_math_answer(raw_string): + def remove_boxed(s): + left = "\\boxed{" + try: + assert s[: len(left)] == left + assert s[-1] == "}" + answer = s[len(left) : -1] + if "=" in answer: + answer = answer.split("=")[-1].lstrip(" ") + return answer + except: + return None + + def last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx == None: + retval = None + else: + retval = string[idx : right_brace_idx + 1] + + return retval + + def get_answer_with_dollar_sign(s): + first_pattern = "\$(.*)\$" + last_match = None + matches = re.findall(first_pattern, s) + if matches: + last_match = matches[-1] + if "=" in last_match: + last_match = last_match.split("=")[-1].lstrip(" ") + return last_match + + def get_answer_without_dollar_sign(s): + last_match = None + if "=" in s: + last_match = s.split("=")[-1].lstrip(" ").rstrip(".") + if "\\n" in last_match: + last_match = last_match.split("\\n")[0] + else: + pattern = "(?:\\$)?\d+(?:\.\d+)?(?![\w\d])" + matches = re.findall(pattern, s) + if matches: + last_match = matches[-1] + return last_match + + if "\\boxed" in raw_string: + answer = remove_boxed(last_boxed_only_string(raw_string)) + else: + answer = get_answer_with_dollar_sign(raw_string) + if not answer: + answer = get_answer_without_dollar_sign(raw_string) + return answer + + +def math_equivalence(prediction, reference, **kwargs): + prediction = parse_math_answer(prediction) + + if prediction is None and reference is None: + print("WARNING: Both None") + return False + + if prediction is None or reference is None: + return False + + try: + ss1 = _strip_string(prediction) + ss2 = _strip_string(reference) + return ss1 == ss2 + except: + return prediction == reference + + +def multi_choice_accuracy(prediction, reference, **kwargs): + # Only find uppercase letters not surrounded by lowercase letters + all_classes = kwargs.get("all_classes", None) + if all_classes: + pattern = f"(? highest_similarity: + highest_similarity = similarity + best_match = string + score = float(best_match == reference) + return score + + +def rouge_score(prediction, reference, **kwargs): + rouge = Rouge() + try: + scores = rouge.get_scores([prediction], [reference], avg=True) + except: + return 0.0 + return scores["rouge-l"]["f"] + + +def rouge_zh_score(prediction, reference, **kwargs): + prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) + reference = " ".join(list(jieba.cut(reference, cut_all=False))) + score = rouge_score(prediction, reference) + return score + + +def _f1_score(prediction, reference, **kwargs): + common = Counter(prediction) & Counter(reference) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction) + recall = 1.0 * num_same / len(reference) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def f1_score(prediction, reference, **kwargs): + normalized_prediction = normalize_answer(prediction) + normalized_ground_truth = normalize_answer(reference) + + prediction_tokens = normalized_prediction.split() + ground_truth_tokens = normalized_ground_truth.split() + return _f1_score(prediction_tokens, ground_truth_tokens) + + +def f1_zh_score(prediction, reference, **kwargs): + prediction_tokens = list(jieba.cut(prediction, cut_all=False)) + ground_truth_tokens = list(jieba.cut(reference, cut_all=False)) + prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens] + ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens] + prediction_tokens = [token for token in prediction_tokens if len(token) > 0] + ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] + return _f1_score(prediction_tokens, ground_truth_tokens) diff --git a/applications/ColossalEval/colossal_eval/evaluate/evaluator.py b/applications/ColossalEval/colossal_eval/evaluate/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..11e204b504c5be370285bef9c5ea2a089e403376 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/evaluator.py @@ -0,0 +1,110 @@ +import os +from typing import Any, Dict, List + +import colossal_eval.evaluate.gpt_evaluate as gpt_evaluate + +from .utils import get_data_per_category + + +class Evaluator(object): + """ + A class named Evaluator includes GPT-3.5/GPT-4 evaluation + + """ + + def __init__( + self, + params: Dict[str, Any], + battle_prompt: Dict[str, Any], + gpt_evaluation_prompt: Dict[str, Any], + gpt_model: str, + language: str, + gpt_with_reference: bool, + ) -> None: + self.params = params + self.battle_prompt = battle_prompt + self.gpt_evaluation_prompt = gpt_evaluation_prompt + self.gpt_model = gpt_model + self.language = language + self.gpt_with_reference = gpt_with_reference + self.gpt_evaluation_results = dict() + self.battle_results = [] + + def battle(self, answers1: List[Dict], answers2: List[Dict]) -> None: + """ + Comparison between two models using GPT-4 as the reviewer. + """ + + self.battle_results = gpt_evaluate.battle(answers1, answers2, self.battle_prompt) + + def evaluate(self, answers: List[Dict], targets: List[Dict], save_path: str, model_name: str) -> None: + """ + A comprehensive evaluation of the answers from the model. + The function evaluates the model's performance from different perspectives + using GPT-3.5, GPT-4, and off-the-shelf evaluation metrics. + + The metrics will be decided by the config file. + + """ + + answers_per_category = get_data_per_category(answers, list(self.params.keys())) + targets_per_category = get_data_per_category(targets, list(self.params.keys())) + + # gpt evaluation + for category in self.params: + if len(answers_per_category[category]) == 0: + print(f"Category {category} specified in your config doesn't have corresponding answers!") + continue + + if self.params[category].get("GPT", None) is None: + continue + + category_metrics = self.params[category]["GPT"] + + prompt = self.gpt_evaluation_prompt.get(category, None) + if prompt is None: + print(f"No prompt for category {category}! Use prompt for category general now.") + prompt = self.gpt_evaluation_prompt["general"] + + self.gpt_evaluation_results[category] = gpt_evaluate.evaluate( + answers_per_category[category], + prompt, + category_metrics, + category, + save_path, + model_name, + self.gpt_model, + self.language, + references=targets_per_category[category] if self.gpt_with_reference else None, + ) + + def save(self, path: str, model_name_list: List[str]) -> None: + """ + Save evaluation results of GPT-3.5, GPT-4, and off-the-shelf evaluation metrics. + + """ + + if len(model_name_list) == 2: + save_path = os.path.join(path, "gpt_evaluate", "battle_results") + gpt_evaluate.save_battle_results(self.battle_results, model_name_list[0], model_name_list[1], save_path) + else: + if self.gpt_evaluation_results: + # Save evaluation results for GPT evaluation metrics. + gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results") + gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results") + + all_evaluations = gpt_evaluate.save_gpt_evaluation_results( + model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path + ) + + # Start to calculate scores and save statistics. + gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics") + gpt_evaluate.save_gpt_evaluation_statistics( + model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path + ) + + # Save charts and csv. + gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses") + gpt_evaluate.analyze_gpt_evaluation_statistics( + gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path + ) diff --git a/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b1ed1143f037f22bd6a3434a9985c76e930f79 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py @@ -0,0 +1,852 @@ +import concurrent.futures +import os +import re +import time +from copy import deepcopy +from typing import Any, Dict, List + +import matplotlib.pyplot as plt +import numpy as np +import openai +import pandas as pd +import seaborn as sns +import tqdm +from colossal_eval.utils import jdump, jload + +ref_step_template = { + "en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n", + "cn": "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n", +} + +ref_answer_template_general = { + "en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n", + "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n", +} + +ref_answer_template_correctness = { + "en": "\nA correct answer is as follows:\n\n{answer}\n\n", + "cn": "\n标准答案如下:\n\n{answer}\n\n", +} + + +def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: int = 2048) -> Dict[str, Any]: + """ + Get battle evaluation from GPT-4. + + Args: + sys_prompt: prompt for the system. + user_prompt: prompt for the user. + id: id of the answers for comparison. + max_tokens: the maximum number of tokens to generate in the chat completion. + + Returns: + An evaluation of one comparison. + """ + + MAX_API_RETRY = 3 + for _ in range(MAX_API_RETRY): + try: + response = openai.ChatCompletion.create( + model="gpt-4", + messages=[ + {"role": "system", "content": sys_prompt}, + { + "role": "user", + "content": user_prompt, + }, + ], + temperature=0.2, + max_tokens=max_tokens, + ) + evaluation = response["choices"][0]["message"]["content"] + return {"evaluation": evaluation, "id": id} + except Exception as e: + print(e) + time.sleep(1) + print(f"Evaluation {id} failed after {MAX_API_RETRY} retries.") + return {"evaluation": "", "id": id} + + +def parse_battle_score(evaluation: str) -> List[float]: + """ + Parse evaluation from GPT-4 and get the scores of model 1 and 2. + + Args: + evaluation: evaluation from GPT-4. + + Returns: + A score pair of two different model answers. + """ + + try: + pattern = re.compile("([0-9]|10) out of 10") + sp = re.findall(pattern, evaluation) + if len(re.findall(pattern, evaluation)) == 2: + return [float(sp[0]), float(sp[1])] + + pattern = re.compile("a score of ([0-9]|10)") + sp = re.findall(pattern, evaluation) + if len(re.findall(pattern, evaluation)) == 2: + return [float(sp[0]), float(sp[1])] + + pattern = re.compile("([0-9]|10)/10") + sp = re.findall(pattern, evaluation) + if len(re.findall(pattern, evaluation)) == 2: + return [float(sp[0]), float(sp[1])] + + score_pair = evaluation.split("\n")[0] + score_pair = score_pair.replace(",", " ") + sp = score_pair.split(" ") + if len(sp) == 2: + return [float(sp[0]), float(sp[1])] + else: + raise Exception(f"Invalid score pair. Got {evaluation}.") + except Exception: + return [-1, -1] + + +def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]) -> List[Dict]: + """ + Use GPT-4 to compare answers of two different models. + + Args: + answer1: answers of model 1. + answer2: answers of model 2. + prompt_dict: prompt for battle. + + Returns: + Evaluations of all comparison pairs. + """ + + assert len(answer1) == len(answer2) + + total_len = len(answer1) + question_idx_list = list(range(total_len)) + + print(f" Total number of answers: {len(answer1)}.") + + evaluations = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = [] + for i in question_idx_list: + assert answer1[i]["id"] == answer2[i]["id"] + answer_id = answer1[i]["id"] + + ques = ( + answer1[i]["instruction"] + if answer1[i]["input"] == "" + else answer1[i]["instruction"] + " " + answer1[i]["input"] + ) + answer1[i]["category"] + ans1 = answer1[i]["output"] + ans2 = answer2[i]["output"] + + sys_prompt = prompt_dict["system_prompt"] + prompt_template = prompt_dict["prompt_template"] + prompt = prompt_template.format( + question=ques, + answer_1=ans1, + answer_2=ans2, + prompt=prompt_dict["prompt"], + ) + + future = executor.submit(get_battle_result, sys_prompt, prompt, answer_id, 2048) + futures.append(future) + + for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): + evaluations.append(future.result()) + + evaluations.sort(key=lambda x: x["id"]) + + return evaluations + + +def save_battle_results(evaluations: List[Dict], name1: str, name2: str, save_path: str) -> None: + """ + Save evaluation results (model 1 vs model 2) from GPT-4. + + Args: + evaluations: evaluation results from GPT-4. + name1: model 1 's name. + name2: model 2 's name. + save_path: path to save battle results. + """ + + evaluation_file = deepcopy(evaluations) + + ans1_score = 0 + ans2_score = 0 + better_count = 0 + worse_count = 0 + tie_count = 0 + invalid_count = 0 + + better_file = [] + worse_file = [] + tie_file = [] + invalid_file = [] + + for idx, evaluation in enumerate(evaluations): + scores = parse_battle_score(evaluation["evaluation"]) + evaluation_file[idx]["score"] = scores + + if scores[0] == -1 and scores[1] == -1: + invalid_count += 1 + invalid_file.append(evaluation_file[idx]) + print(f'Invalid score pair: {evaluation_file[idx]["id"]}.') + else: + if scores[0] > scores[1]: + worse_count += 1 + worse_file.append(evaluation_file[idx]) + elif scores[0] < scores[1]: + better_count += 1 + better_file.append(evaluation_file[idx]) + else: + tie_count += 1 + tie_file.append(evaluation_file[idx]) + ans1_score += scores[0] + ans2_score += scores[1] + + prefix = f"{name1}_vs_{name2}" + + if not os.path.exists(save_path): + os.makedirs(save_path) + + jdump(better_file, os.path.join(save_path, prefix, f"{name2}_better.json")) + jdump(worse_file, os.path.join(save_path, prefix, f"{name2}_worse.json")) + jdump(tie_file, os.path.join(save_path, prefix, f"{prefix}_tie.json")) + jdump(invalid_file, os.path.join(save_path, prefix, f"{prefix}_invalid.json")) + jdump(evaluation_file, os.path.join(save_path, prefix, f"{prefix}_evaluations.json")) + + if os.path.exists(os.path.join(save_path, "battle_results.json")): + results = jload(os.path.join(save_path, "battle_results.json")) + else: + results = {} + + results[prefix] = { + "model": [name1, name2], + "better": better_count, + "worse": worse_count, + "tie": tie_count, + "win_rate": better_count / (len(evaluations) - invalid_count), + "score": [ + ans1_score / (len(evaluations) - invalid_count), + ans2_score / (len(evaluations) - invalid_count), + ], + } + jdump(results, os.path.join(save_path, "battle_results.json")) + + print(f"Total {invalid_count} invalid score pair(s).") + print(f"Model {name2} has {better_count} better answer(s).") + print(f"Model {name2} has {worse_count} worse answer(s).") + print(f"{tie_count} answer(s) play(s) to a tie.") + print(f"Win rate of model {name2}: {better_count/(len(evaluations)-invalid_count):.2f}") + print(f"Model {name1} average score: {ans1_score/(len(evaluations)-invalid_count):.2f}") + print(f"Model {name2} average score: {ans2_score/(len(evaluations)-invalid_count):.2f}") + + +def reference_template(metric: str, language: str, reference: Dict[str, Any]) -> str: + """ + Get prompt template for GPT evaluation with reference. + + Different languages have different prompt templates. + + Args: + metric: metric used in GPT evaluation with reference. + language: language for the template. + reference: the instruction that contains target answer. + + Returns: + Prompt template for GPT evaluation with reference. + """ + + step_to_add = ref_step_template[language] + + for_the_given_answer = ( + "{metric} (1-5) (directly give the score for the given answer):" + if language == "en" + else "{metric} (1-5) (直接对给定答案打分)" + ) + + # adjective is used to describe the word "answer" in the prompt. + adjective = "example" if language == "en" else "示例" + answer_to_add = ref_answer_template_general[language] + + # Only for correctness, we will provide a correct answer and so the adjective for "answer" will be "correct". The prompt words will be "a correct answer". + # In other cases, the prompt words will be "an example answer with good quality" by default. + if metric.lower() == "correctness": + adjective = "correct" if language == "en" else "标准" + answer_to_add = ref_answer_template_correctness[language] + + answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"]) + step_to_add = step_to_add.format(metric=metric.lower(), adjective=adjective) + for_the_given_answer.format( + metric=metric + ) + + return answer_to_add + step_to_add + + +def fill_in_message(role: str, content: str) -> Dict[str, str]: + """ + Generate one formatted message to send through chat completion. + + Args: + role: the role of the author of this message. + content: the contents of the message. + + Returns: + One message to send through chat completion. + """ + + return {"role": role, "content": content} + + +def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens: int = 1, turns=2) -> Dict[str, Any]: + """ + Do multi-turn chat completion. + + When turns == 1, it is a one-turn conversation for normal GPT evaluation. + When turns == 2, it is a two-turn conversation which is used for GPT evaluation with reference answers. + + Args: + user_messages: messages user wants to send. + model: the model used to evaluate answers. + max_tokens: the maximum number of tokens to generate in the chat completion. + turns: the number of turns for conversation. + + Returns: + Last turn's response. + """ + + if len(user_messages) != turns: + raise Exception("The length of user messages should be equal to the turn number!") + + assistant_responses = [] + + for i in range(turns): + messages_to_send = [] + + for j in range(i): + messages_to_send.append(fill_in_message("user", user_messages[j])) + messages_to_send.append( + fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"]) + ) + + # Length of user messages == Length of assistant messages + 1 + # Because we always expect the api to response + messages_to_send.append(fill_in_message("user", user_messages[i])) + + response = openai.ChatCompletion.create( + model=model, + messages=messages_to_send, + temperature=0, + max_tokens=max_tokens, + ) + + # Avoid exceeding rate limits. + # You can comment this line if your request doesn't contain many tokens. + time.sleep(1) + + assistant_responses.append(response) + + return assistant_responses[-1] + + +def get_gpt_evaluation_without_logprobs( + prompt: Dict[str, Any], + inst: Dict[str, Any], + metrics: List[str], + language: str, + reference: Dict[str, Any] = None, + model: str = "gpt-3.5-turbo", + max_tokens: int = 2048, +) -> Dict[str, Any]: + """ + Use chat models(gpt-3.5-turbo or gpt-4) to evaluate one model answer. + + Temprature is set to 0 to make the model more deterministic. + + Args: + prompt: a dictionary including prompt template, CoT and metrics. + inst: the instruction that is needed to be evaluated. + metrics: the metrics for evaluation. + language: language used to change the CoT(add one more step about comparing the given answer and reference) if reference is not None. + reference: the reference answer. + model: the model used to evaluate answers. + max_tokens: the maximum number of tokens to generate in the chat completion. + + Returns: + An evaluation of one answer. + """ + + MAX_API_RETRY = 3 + + question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"] + answer = inst["output"] + inst["evaluation"] = {} + + for metric in metrics: + if prompt["metrics"].get(metric, None) is None: + raise Exception( + f"Unsupported metric {metric} for category {inst['category']}! You should add this metric in the prompt file!" + ) + for i in range(MAX_API_RETRY): + try: + prompt_reference = "" if reference is None else reference_template(metric, language, reference) + + prompt_1st_round = prompt["prompt"].format( + question=question, + answer=answer, + metric=prompt["metrics"][metric], + steps=prompt["CoT"][metric], + ) + + if prompt_reference and (reference["target"] or reference["output"]): + # Do a 2-round conversation + response = multiturn_chat_completion( + [prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2 + ) + else: + response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1) + + inst["evaluation"][metric] = { + "response": response["choices"][0]["message"]["content"], + "logprobs": None, + } + + # Prevent exceeding rate limits because we have multiple workers. + # But this will slow down the evaluation process. + # You can comment this line if your request doesn't contain many tokens. + time.sleep(len(metrics) * 0.5) + + break + except Exception as e: + print(e) + time.sleep(1) + if metric not in inst["evaluation"]: + print(f"Evaluation {inst['id']} for metric {metric} failed after {MAX_API_RETRY} retries.") + inst["evaluation"][metric] = {} + return inst + + +def get_gpt_evaluation_with_logprobs( + prompt: Dict[str, Any], inst: Dict[str, Any], metrics: List[str], max_tokens: int = 2048 +) -> Dict[str, Any]: + """ + Use completion model(text-davinci-003) to evaluate one model answer. + Only completion models can return log probabilities. + + Temprature is set to 0 to make the model more deterministic. + + Args: + prompt: a dictionary including prompt template, CoT and metrics. + inst: the instruction that is needed to be evaluated. + metrics: the metrics for evaluation. + max_tokens: the maximum number of tokens to generate in the completion. + + Returns: + An evaluation of one answer. + """ + + MAX_API_RETRY = 3 + + question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"] + answer = inst["output"] + inst["evaluation"] = {} + + for metric in metrics: + if prompt["metrics"].get(metric, None) is None: + raise Exception( + f"Unsupported metric {metric} for category {inst['category']}! You should add this metric in the prompt file!" + ) + for i in range(MAX_API_RETRY): + try: + response = openai.Completion.create( + model="text-davinci-003", + prompt=prompt["prompt"].format( + question=question, + answer=answer, + metric=prompt["metrics"][metric], + steps=prompt["CoT"][metric], + ), + logprobs=5, + temperature=0, + max_tokens=max_tokens, + ) + inst["evaluation"][metric] = { + "response": response["choices"][0]["text"], + "logprobs": response["choices"][0]["logprobs"]["top_logprobs"], + } + + # Prevent exceeding rate limits because we have multiple workers. + # But this will slow down the evaluation process. + # You can comment this line if your request doesn't contain many tokens. + time.sleep(len(metrics) * 0.5) + + break + except Exception as e: + print(e) + time.sleep(1) + if metric not in inst["evaluation"]: + print(f"Evaluation {inst['id']} for metric {metric} failed after {MAX_API_RETRY} retries.") + inst["evaluation"][metric] = {} + return inst + + +def evaluate( + answers: List[Dict], + prompt: Dict[str, Any], + metrics: List[str], + category: str, + save_path: str, + model_name: str, + model: str, + language: str, + references: List[Dict] = None, +) -> List[Dict]: + """ + Use GPT models to evaluate model answers and save evaluation results. + + Args: + answers: model answers. + prompt: prompt for GPT evaluation. + metrics: metrics for GPT evaluation. + category: the category of the model answers for evaluation. + model: the specific GPT model used to evaluate answers. + language: language used in GPT evaluation + references: references for GPT evaluation + + Returns: + Evaluations of the given answers. + """ + + print(f"The number of instances of category {category}'s is {len(answers)}.") + + evaluations = [] + + metrics_str = ", ".join(x for x in metrics) + print(f"Category {category}'s metrics are {metrics_str}.") + + gpt_base_save_path = os.path.join(save_path, "gpt_evaluate", "gpt_evaluate_results") + gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results") + category_file = os.path.join(gpt_evaluation_results_save_path, model_name, f"{category}_evaluation_results.json") + + if os.path.exists(category_file): + print(f"Evaluation results for category {category}, model {model_name} already exists.") + print("Skip evaluating.") + + evaluations = jload(category_file) + + retry = [] + evaluations_copy = deepcopy(evaluations) + + success = [] + for idx, e in enumerate(evaluations_copy): + keys = list(e["evaluation"].keys()) + for key in keys: + if e["evaluation"][key] == {}: + retry.append(e["id"]) + print(f"Re-evaluate id {e['id']} now.") + break + if e["id"] not in retry: + success.append(e) + + if len(retry) == 0: + evaluations.sort(key=lambda x: x["id"]) + print(f"{category} done.") + return evaluations + + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = [] + for idx, inst in enumerate(answers): + if not inst["id"] in retry: + continue + # Completion models can return log probabilities. + if model == "text-davinci-003": + future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1) + else: + future = executor.submit( + get_gpt_evaluation_without_logprobs, + prompt, + inst, + metrics, + language, + reference=None if references is None else references[idx], + model=model, + max_tokens=1, + ) + + futures.append(future) + + for future in tqdm.tqdm( + concurrent.futures.as_completed(futures), + desc=f"{category}: ", + total=len(futures), + ): + success.append(future.result()) + + success.sort(key=lambda x: x["id"]) + + print(f"Saving evaluation results for category {category}, model {model_name}.") + + jdump(success, category_file) + + return success + + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = [] + for idx, inst in enumerate(answers): + # Completion models can return log probabilities. + if model == "text-davinci-003": + future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1) + else: + future = executor.submit( + get_gpt_evaluation_without_logprobs, + prompt, + inst, + metrics, + language, + reference=None if references is None else references[idx], + model=model, + max_tokens=1, + ) + + futures.append(future) + + for future in tqdm.tqdm( + concurrent.futures.as_completed(futures), + desc=f"{category}: ", + total=len(futures), + ): + evaluations.append(future.result()) + + evaluations.sort(key=lambda x: x["id"]) + + print(f"{category} done.") + + print(f"Saving evaluation results for category {category}, model {model_name}.") + + jdump(evaluations, category_file) + + return evaluations + + +def calculate_scores_form_logprobs(logprobs: Dict[str, Any]) -> float: + """ + Calculate the score according to log probabilities returned by text-davinci-003. + + Calculation formula: + score = sum(score_i * exp(value)) where score_i is the score which corresponds to the key(predicted token) and value is its log probability. + + Ref: https://arxiv.org/abs/2303.16634 + This paper proposes NLG evaluation methods using text-davinci-003(log probabilities returned by completion models) and GPT-4(probabilities obtained by sampling). + + Args: + logprobs: logprobs returned by openai.Completion. + + Returns: + The score of one answer. + """ + + # GPT-3.5 only returns score of 1 to 5. + prob = np.zeros(5) + + for key, value in logprobs.items(): + # Sometimes the key will be one byte of a unicode character which takes the form of "bytes:\\xe7". + # It is meaningless and thus we don't calculate probability. + if "bytes" in key: + continue + # results[0] is the score which corresponds to the key(predicted token). + # For example, key "5" corresponds to score 5. + results = re.findall(r"\d", key) + if len(results) == 1: + prob[int(results[0]) - 1] = prob[int(results[0]) - 1] + np.exp(value) + + score = np.dot(np.arange(1, 6), prob) + + return score + + +def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) -> int: + """ + Calculate the score from the response returned by gpt-3.5-turbo or gpt-4. + Different from text-davinci-003, this fuction directly calculates the score according to the plain response returned by gpt-3.5-turbo or gpt-4. + Although text-davinci-003 can return log probabilities, it costs ten times as much as gpt-3.5-turbo. + + Args: + response: logprobs returned by openai.Completion. + evaluation: the evaluation corresponds to the question. + + Returns: + The score of one answer. + """ + + try: + results = re.findall(r"\d", response) + if len(results) == 1: + return int(results[0]) + else: + raise Exception(f"Invalid score pair. Got {evaluation}.") + except Exception: + return 0 + + +def save_gpt_evaluation_results( + model_name: str, gpt_evaluation_results: Dict[str, Any], save_path: str +) -> Dict[str, Any]: + """ + Save evaluation results for different categories for one model. + + Args: + model_name: name of the model for saving evaluation results. + gpt_evaluation_results: evaluations results for all of the model answers. + save_path: path to save GPT evaluation statistics. + """ + + all_evaluations = [] + for category, evaluations in gpt_evaluation_results.items(): + jdump(evaluations, os.path.join(save_path, model_name, f"{category}_evaluation_results.json")) + all_evaluations.extend(evaluations) + + jdump(all_evaluations, os.path.join(save_path, f"{model_name}_evaluation_results.json")) + + return all_evaluations + + +def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], save_path: str) -> None: + """ + Generate statistics for one model. + + Args: + model_name: name of the model for saving statistics. + evaluations: evaluations for all of the model answers. + save_path: path to save GPT evaluation statistics. + """ + + if not os.path.exists(save_path): + os.makedirs(save_path) + + data_per_category = {} + for evaluation in evaluations: + category = evaluation["category"] + if evaluation["category"] in data_per_category.keys(): + data_per_category[category].append(evaluation) + else: + data_per_category[category] = [evaluation] + + all_statistics = {} + for category, data in data_per_category.items(): + metrics = data[0]["evaluation"].keys() + scores = {metric: [] for metric in metrics} + for evaluation in data: + for metric in metrics: + if evaluation["evaluation"][metric] == {}: + # This means after 3 retries, the server still returns an error and we set the score to 0. + scores[metric].append(0) + elif evaluation["evaluation"][metric]["logprobs"] is not None: + scores[metric].append( + calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0]) + ) + else: + scores[metric].append( + calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation) + ) + + statistics = {} + for metric in metrics: + arg_sort = np.argsort(scores[metric]) + statistics[metric] = {} + statistics[metric]["avg_score"] = sum(scores[metric]) / len(data) + statistics[metric]["best_3"] = {data[i]["id"]: scores[metric][i] for i in arg_sort[-3:][::-1]} + statistics[metric]["worst_3"] = {data[i]["id"]: scores[metric][i] for i in arg_sort[:3]} + + all_statistics[category] = statistics + + jdump( + all_statistics, + os.path.join(save_path, f"{model_name}_evaluation_statistics.json"), + ) + + +def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> None: + """ + Analyze and visualize all GPT evaluation statistics in the given directory. + + Args: + statistics_path: path to all the models' statistics. + save_path: path to save table and visualization results. + """ + + if not os.path.exists(statistics_path): + raise Exception(f'The given directory "{statistics_path}" doesn\'t exist! No statistics found!') + + all_statistics = {} + + for file_name in os.listdir(statistics_path): + if file_name.endswith("_evaluation_statistics.json"): + model_name = file_name.split("_evaluation_statistics.json")[0] + all_statistics[model_name] = jload(os.path.join(statistics_path, file_name)) + + if len(list(all_statistics.keys())) == 0: + raise Exception(f'There are no statistics in the given directory "{statistics_path}"!') + + frame_all = { + "model": [], + "category": [], + "metric": [], + "avg_score": [], + "best_3": [], + "worst_3": [], + } + frame_per_category = {} + for model_name, model_statistics in all_statistics.items(): + for category, category_statistics in model_statistics.items(): + if frame_per_category.get(category) is None: + frame_per_category[category] = { + "model": [], + "metric": [], + "avg_score": [], + "best_3": [], + "worst_3": [], + } + + for metric, metric_statistics in category_statistics.items(): + frame_all["model"].append(model_name) + frame_all["category"].append(category) + frame_all["metric"].append(metric) + frame_all["avg_score"].append(metric_statistics["avg_score"]) + frame_all["best_3"].append(metric_statistics["best_3"]) + frame_all["worst_3"].append(metric_statistics["worst_3"]) + + frame_per_category[category]["model"].append(model_name) + frame_per_category[category]["metric"].append(metric) + frame_per_category[category]["avg_score"].append(metric_statistics["avg_score"]) + frame_per_category[category]["best_3"].append(metric_statistics["best_3"]) + frame_per_category[category]["worst_3"].append(metric_statistics["worst_3"]) + + if not os.path.exists(save_path): + os.makedirs(save_path) + + frame_all = pd.DataFrame(frame_all) + frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv")) + + for category in tqdm.tqdm( + frame_per_category.keys(), + desc=f"GPT evaluation: ", + total=len(frame_per_category.keys()), + ): + data = pd.DataFrame(frame_per_category[category]) + + sns.set() + fig = plt.figure(figsize=(16, 10)) + plt.ylim((0, 5)) + + fig = sns.barplot(x="metric", y="avg_score", hue="model", data=data, dodge=True) + fig.set_title(f"Comparison between Different Models for Category {category.title()}") + plt.xlabel("Evaluation Metric") + plt.ylabel("Average Score") + + figure = fig.get_figure() + figure.savefig(os.path.join(save_path, f"{category}.png"), dpi=400) + + plt.close() diff --git a/applications/ColossalEval/colossal_eval/evaluate/utils.py b/applications/ColossalEval/colossal_eval/evaluate/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69fec46705ab63db75e919b38bdfb334817984b4 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/evaluate/utils.py @@ -0,0 +1,8 @@ +def get_data_per_category(data, categories): + data_per_category = {category: [] for category in categories} + for item in data: + category = item["category"] + if category in categories: + data_per_category[category].append(item) + + return data_per_category diff --git a/applications/ColossalEval/colossal_eval/models/__init__.py b/applications/ColossalEval/colossal_eval/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6c9b414145d3d78f1e684219b5de8c4b8f2b16 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/models/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseModel +from .chatglm import ChatGLM2Model, ChatGLMModel +from .huggingface import HuggingFaceCausalLM, HuggingFaceModel + +__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"] diff --git a/applications/ColossalEval/colossal_eval/models/base.py b/applications/ColossalEval/colossal_eval/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..aae796c1d56e2b006ff3cd389be052484ce597c9 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/models/base.py @@ -0,0 +1,78 @@ +from abc import abstractclassmethod +from typing import Dict, List + +from colossal_eval.utils import Conversation, prompt_templates + +from colossalai.logging import DistributedLogger + + +class BaseModel: + """ + Base class for model wrapper. + + Args: + path: The path to the model. + model_max_length: The maximum sequence length of the model. + prompt_template: The model's prompt template. + batch_size: Batch size for inference. + logger: Logger for the model. + """ + + def __init__( + self, + path: str, + model_max_length: int = 2048, + prompt_template: Conversation = None, + batch_size: int = 1, + logger: DistributedLogger = None, + ): + self.path = path + self.model_max_length = model_max_length + + if prompt_template: + self.prompt_template = prompt_template + else: + self.prompt_template = prompt_templates["plain"] + + self.batch_size = batch_size + self.logger = logger + + @abstractclassmethod + def inference(self, data: List[Dict]) -> None: + """ + Infer the given data. + This function will call self.generate() to get model outputs and also self.model(input) to get logits. + + Args: + data: The data for inference. + """ + + @abstractclassmethod + def generate(self, inputs: List[str], max_new_tokens: int) -> List[str]: + """ + Generate results given a list of inputs. + + Args: + inputs: A list of strings. + max_new_tokens: The maximum length of the output. + + Returns: + A list of generated strings. + """ + + @abstractclassmethod + def get_loss(self, batch: List[str], batch_target: List[str]) -> List[float]: + """ + Get loss given batch and batch with target. + Use their length difference after tokenization to mask the loss and only compute loss at target tokens. + + Args: + batch: batch prompt without target answer. + batch_target: batch prompt with target answer. + + Returns: + A list of loss. + """ + + def to(self, device): + self.model.to(device) diff --git a/applications/ColossalEval/colossal_eval/models/chatglm.py b/applications/ColossalEval/colossal_eval/models/chatglm.py new file mode 100644 index 0000000000000000000000000000000000000000..f293c4f699cda91f98dbd3cc9ddd7821cf8c2753 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/models/chatglm.py @@ -0,0 +1,303 @@ +import copy +from typing import List + +import torch + +from .huggingface import HuggingFaceModel + +IGNORE_INDEX = -100 + + +class ChatGLMModel(HuggingFaceModel): + def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]: + truncated_inputs = copy.deepcopy(inputs) + # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187 + for i, input in enumerate(inputs): + a_ids = self.tokenizer.encode(text=input, truncation=False, add_special_tokens=False) + + if len(a_ids) > self.model_max_length - max_new_tokens: + half = (self.model_max_length - max_new_tokens) // 2 + prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode( + a_ids[-half:], skip_special_tokens=True + ) + truncated_inputs[i] = prompt + + return truncated_inputs + + @torch.no_grad() + def get_loss( + self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False + ) -> List[List[float]]: + """ + Calculate loss only on target tokens. + + Args: + batch: A batch of prompt without target answer. + batch_target: A batch of target answer. Sometimes one question can have multiple target answers. + + Returns: + Loss. + + """ + + # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss. + # We don't need to generate new tokens. + # Target answer's length is usually << model_max_length, but we still call it in case. + # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens. + batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target] + + # Get the number of target answers for different questions + batch_target_nums = [len(prompt_target) for prompt_target in batch_target] + + labels_list = [] + input_ids_list = [] + + for input, targets in zip(batch_prompt, batch_target): + for target in targets: + # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187 + # If there is no history, the prompt is just the query. + # We don't need to override self.generate() in ChatGLM-6B but need to override it in ChatGLM2-6B. + # See https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1276 + target_tokenized = self.tokenizer.encode(text=target, add_special_tokens=False) + + # Get prompt with length model_max_length - len(target_tokenized). + # Reserve some space for target answer tokens using max_new_tokens. + # This will generate the correct start_idx and end_idx. + max_new_tokens = len(target_tokenized) + + # Here 3 tokens are reserved for [gmask_id, bos_token, eos_id]. So we reserve max_new_tokens + 3 tokens. + # See https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py#L323 + prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens + 3)[0] + input_tokenized = self.tokenizer.encode(prompt_with_correct_length, add_special_tokens=False) + + input_ids = self.tokenizer.build_inputs_with_special_tokens(input_tokenized, target_tokenized) + + context_length = input_ids.index(self.tokenizer.bos_token_id) + context_length - 1 + + target_ids = [IGNORE_INDEX] * len(input_ids) + + # -1 is for eos_token, we don't want to calculate loss on eos token. + target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1] + + input_ids_list.append(torch.LongTensor(input_ids)) + labels_list.append(torch.LongTensor(target_ids)) + + # Because of multiple target answers, the final batch size may be greater than self.batch_size. + # We will generate new batches. + losses = [] + target_token_nums = [] + + batched_input_ids = [ + input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size) + ] + batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)] + + for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels): + losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels) + losses.extend(losses_per_batch) + target_token_nums.extend(target_token_num_per_batch) + + start_indice = 0 + losses_per_sample = [] + + target_token_nums_per_sample = [] + for length in batch_target_nums: + losses_per_sample.append(losses[start_indice : start_indice + length]) + target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length]) + start_indice += length + + return losses_per_sample, target_token_nums_per_sample, None + + def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> List[float]: + """ + Calculate loss only on target tokens. + Hugging Face generate() function can't return per sample loss. + It will only return the mean of the loss in a batch. + In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss. + + Args: + input_ids_list: A batch of input token ids. + labels: A batch of labels. + + Returns: + A list of loss. + + """ + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id + ).to(torch.cuda.current_device()) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to( + torch.cuda.current_device() + ) + + outputs = self.model(input_ids)[0] + + shift_logits = outputs[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size()) + + lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy() + + loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy() + return loss_sum.tolist(), lens.tolist() + + +class ChatGLM2Model(ChatGLMModel): + def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]: + truncated_inputs = copy.deepcopy(inputs) + # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180 + for i, input in enumerate(inputs): + a_ids = self.tokenizer.encode(text=input, add_special_tokens=True, truncation=False) + + if len(a_ids) > self.model_max_length - max_new_tokens: + half = (self.model_max_length - max_new_tokens) // 2 + prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode( + a_ids[-half:], skip_special_tokens=True + ) + truncated_inputs[i] = prompt + + return truncated_inputs + + @torch.no_grad() + def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]: + """Generate results given a list of inputs and get logits of the first new token over choices. + + Args: + inputs: A list of strings. + max_new_tokens: Max new tokens for generation. + kwargs: Key arguments for generation + + Returns: + A list of generated strings and logits over choices. + + Note: + Currently the function only returns the logits of the first new token. + It is used for single choice question. + For multiple choices question, please avoid using the loss over choices. + You should set argument choices as None in self.inference(). + + """ + # Follow the process of model.chat() method in modeling_chatglm2.py + # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1020 + # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1001 + + query = [] + for input in inputs: + prompt = self.tokenizer.build_prompt(input, None) + query.append(prompt) + + truncated_query = self._get_truncated_prompts(query, max_new_tokens) + + encoded_inputs = self.tokenizer( + truncated_query, + padding=True, + truncation=True, + return_tensors="pt", + max_length=self.model_max_length - max_new_tokens, + ).to(torch.cuda.current_device()) + + # Set output_scores=True to get prediction scores. + outputs = self.model.generate( + **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs + ) + + # We only need to decode predicted tokens. + sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :] + + scores = [] + if self.indices_for_choices: + # If the question is a single-choice question, we will return the scores of specific indices for first predicted token. + # The indices are the tokenization results of the options for the single-choice question. + # For example, if the options of the question are A, B, C and D, we only returns scores at indices of A, B, C and D. + for option_indices in self.indices_for_choices: + scores.append(outputs.scores[0][:, option_indices].detach().cpu()) + + scores = torch.max(torch.stack(scores), dim=0)[0] + + decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + return decoded_sequences, scores + + @torch.no_grad() + def get_loss( + self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False + ) -> List[List[float]]: + """ + Calculate loss only on target tokens. + + Args: + batch: A batch of prompt without target answer. + batch_target: A batch of target answer. Sometimes one question can have multiple target answers. + + Returns: + Loss. + + """ + + # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss. + # We don't need to generate new tokens. + # Target answer's length is usually << model_max_length, but we still call it in case. + # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens. + batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target] + + # Get the number of target answers for different questions + batch_target_nums = [len(prompt_target) for prompt_target in batch_target] + + labels_list = [] + input_ids_list = [] + + for input, targets in zip(batch_prompt, batch_target): + for target in targets: + # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180 + prompt = self.tokenizer.build_prompt(input, None) + + target_tokenized = self.tokenizer.encode( + text=target, add_special_tokens=False, truncation=True, max_length=self.model_max_length + ) + + max_new_tokens = len(target_tokenized) + prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0] + input_tokenized = self.tokenizer.encode( + prompt_with_correct_length, + add_special_tokens=True, + truncation=True, + max_length=self.model_max_length, + ) + + input_ids = input_tokenized + target_tokenized + [self.tokenizer.eos_token_id] + target_ids = [IGNORE_INDEX] * len(input_ids) + + # -1 is for "eos" + target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1] + + input_ids_list.append(torch.LongTensor(input_ids)) + labels_list.append(torch.LongTensor(target_ids)) + + # Because of multiple target answers, the final batch size may be greater than self.batch_size. + # We will generate new batches. + losses = [] + target_token_nums = [] + + batched_input_ids = [ + input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size) + ] + batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)] + + for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels): + losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels) + losses.extend(losses_per_batch) + target_token_nums.extend(target_token_num_per_batch) + + start_indice = 0 + losses_per_sample = [] + + target_token_nums_per_sample = [] + for length in batch_target_nums: + losses_per_sample.append(losses[start_indice : start_indice + length]) + target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length]) + start_indice += length + + return losses_per_sample, target_token_nums_per_sample, None diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..9f785a6aa9d1d1eb9b7986780df1ac4ca1ae7f4c --- /dev/null +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -0,0 +1,561 @@ +import copy +import math +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0 +from peft import PeftModel +from tqdm import tqdm +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer + +from colossalai.logging import DistributedLogger + +from .base import BaseModel + +IGNORE_INDEX = -100 + + +class HuggingFaceModel(BaseModel): + """ + Model wrapper around HuggingFace AutoModel models. + + Args: + path: The path to a HuggingFace model. + model_max_length: The maximum sequence length of the model. + tokenizer_path: The path to the tokenizer. + tokenizer_kwargs: Keyword arguments for the tokenizer. + peft_path: The name or path to the HuggingFace's PEFT model. + model_kwargs: Keyword arguments for the model. + prompt_template: The model's prompt template. + batch_size: Batch size for inference. + logger: Logger for the model. + + """ + + def __init__( + self, + path: str, + model_max_length: int = 2048, + tokenizer_path: Optional[str] = None, + tokenizer_kwargs: dict = dict(), + peft_path: Optional[str] = None, + model_kwargs: Dict = None, + prompt_template: Conversation = None, + batch_size: int = 1, + logger: DistributedLogger = None, + ): + super().__init__( + path=path, + model_max_length=model_max_length, + prompt_template=prompt_template, + batch_size=batch_size, + logger=logger, + ) + self._load_tokenizer(path=path, tokenizer_path=tokenizer_path, tokenizer_kwargs=tokenizer_kwargs) + + self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path) + + def _get_choices_indices(self, language: str): + """ + Get indices for each choice + + Some tokenizer will insert BOS if you don't specify add_special_tokens=False such as Llama-2. + The indices for choices may be different given the context. For example, for Llama-2 tokenizer, for Chinese context like "答案:{choice}", indices for choices A, B, C and D are 29909, 29933, 29907 and 29928, for English context like "Answer: {choice}", indices for choices A, B, C and D are 319, 350, 315 and 360. + print(self.tokenizer("答案:A")) to see + print(self.tokenizer("Answer: A")) to see + + """ + + # A trick for get "all" tokens ids related to given choices. + self.indices_for_choices = [[] for _ in range(2)] + for choice in self.choices: + self.indices_for_choices[0].append( + self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1] + ) + self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]) + + def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict): + """ + Load tokenizer. + + Args: + path: The path to the model. Usually it also serves as the path to the tokenizer. + tokenizer_path: The path to the tokenzier. + tokenizer_kwargs: Keyword arguments for the tokenizer. + + """ + + if self.batch_size > 1: + tokenizer_kwargs.update({"padding_side": "left"}) + tokenizer_kwargs.update({"truncation_side": "left"}) + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path if tokenizer_path else path, **tokenizer_kwargs) + + if self.tokenizer.pad_token_id is None: + self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.") + if self.tokenizer.eos_token: + self.tokenizer.pad_token = self.tokenizer.eos_token + elif self.tokenizer.eod_id: + # Qwen has an eod token "<|endoftext|>". + self.tokenizer.pad_token_id = self.tokenizer.eod_id + + def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None): + """ + Load model. + + Args: + path: The path to the model. + model_kwargs: Keyword arguments for the model. + peft_path: The path to the peft model. + + """ + + if "torch_dtype" in model_kwargs: + model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"]) + + model_kwargs.setdefault("torch_dtype", torch.float16) + + self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device()) + if peft_path is not None: + self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False) + self.model.eval() + + def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> Tuple[List]: + """ + Calculate loss only on target tokens. + Hugging Face generate() function can't return per sample loss. + It will only return the mean of the loss in a batch. + In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss. + + Args: + input_ids_list: A batch of input token ids. + labels: A batch of labels. + + Returns: + A list of loss. + + """ + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id + ).to(torch.cuda.current_device()) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to( + torch.cuda.current_device() + ) + attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device()) + + outputs = self.model(input_ids, attention_mask=attention_mask)[0] + + shift_logits = outputs[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size()) + + lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy() + + loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy() + return loss_sum.tolist(), lens.tolist() + + def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]: + """ + Truncate the input sequence to fit model_max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) + https://github.com/THUDM/LongBench/blob/main/pred.py#L16 + + Args: + inputs: A batch of input prompts. + max_new_tokens: Max new tokens for model to generate. + + Returns: + Truncated prompts. + + """ + + truncated_inputs = copy.deepcopy(inputs) + for i, input in enumerate(inputs): + tokenized_prompt = self.tokenizer(input, truncation=False, return_tensors="pt").input_ids[0] + if len(tokenized_prompt) > self.model_max_length - max_new_tokens: + half = (self.model_max_length - max_new_tokens) // 2 + prompt = self.tokenizer.decode( + tokenized_prompt[:half], skip_special_tokens=True + ) + self.tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) + truncated_inputs[i] = prompt + + return truncated_inputs + + def _get_input_ids_and_labels_pretrain(self, batch_prompt: List[str]) -> Tuple[List[torch.LongTensor]]: + """ + Get input_ids and labels for pretrain data. + We only need batch_prompt because for pretain dataset, we don't need to predict new tokens. + + Args: + batch_prompt: A batch of prompt. + + Returns: + Input_ids and labels for the given batch. + + """ + input_ids_list = [] + labels_list = [] + bytes_list = [] + + for input in batch_prompt: + # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process. + # Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels. + # After all, the rest of the original string doesn't need to be tokenized at the first place. + ratio = [16, 8, 4, 2, 1] + tokenized = None + for r in ratio: + tokenized = self.tokenizer( + [input[0 : len(input) // r]], truncation=True, max_length=self.model_max_length, return_tensors="pt" + ) + if tokenized.input_ids.size(1) >= self.model_max_length: + break + + input_ids = copy.deepcopy(tokenized["input_ids"])[0] + target_ids = copy.deepcopy(input_ids) + + string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True) + + bytes_list.append(len(string.encode("utf-8"))) + + input_ids_list.append(input_ids) + labels_list.append(target_ids) + + return input_ids_list, labels_list, bytes_list + + def _get_input_ids_and_labels( + self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool + ) -> Tuple[List[torch.LongTensor]]: + """ + Get input_ids and labels for the given data. + + Args: + batch_prompt: A batch of prompt. + batch_target: A batch of target. + + Returns: + Input_ids and labels for the given batch. + + """ + if pretrain: + return self._get_input_ids_and_labels_pretrain(batch_prompt) + + input_ids_list = [] + labels_list = [] + + for input, targets in zip(batch_prompt, batch_target): + for target in targets: + # TODO: Improve the labeling process. Should annotate the border by adding special tokens. + target_tokenized = self.tokenizer( + [target], truncation=True, max_length=self.model_max_length, return_tensors="pt" + ) + + # Get prompt with length model_max_length - len(target_tokenized). + # Reserve some space for target answer tokens using max_new_tokens. + # This will generate the correct start_idx and end_idx. + max_new_tokens = target_tokenized["input_ids"][0].size(0) + prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens)[0] + input_tokenized = self.tokenizer( + [prompt_with_correct_length], + truncation=True, + max_length=self.model_max_length - max_new_tokens, + return_tensors="pt", + ) + + target_tokenized = self.tokenizer( + [prompt_with_correct_length + target], + truncation=True, + max_length=self.model_max_length, + return_tensors="pt", + ) + + start_idx = input_tokenized["input_ids"][0].size(0) + end_idx = target_tokenized["input_ids"][0].size(0) + + # Sometimes if the target is only an option such as A, B, C and D, the length of input_tokenized is equal to the length of target_tokenized, so we need -1. + # This is caused by the different behavior of tokenizers. + # For example, the tokenizer for Baichuan and Llama will cause such problem in a plain prompt setting. + # The length of the tokenized sequences for prompt "Answer: " and "Answer: A" is the same. + # Baichuan: [29394, 31143, 31106] [29394, 31143, 703] + # Llama: [673, 29901, 29871] [673, 29901, 319] + # The length for sequence "prompt" and "prompt + A" is equal. + # For ChatGLM, the length of the tokenized sequences is different. + # ChatGLM: [16583, 12] [16583, 12, 167] + + if start_idx == end_idx: + start_idx -= 1 + + input_ids = copy.deepcopy(target_tokenized["input_ids"])[0] + target_ids = copy.deepcopy(input_ids) + + mask = torch.zeros_like(target_ids, dtype=torch.bool) + mask[start_idx:end_idx] = True + + target_ids[~mask] = IGNORE_INDEX + + input_ids_list.append(input_ids) + labels_list.append(target_ids) + + return input_ids_list, labels_list, None + + def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: + """ + Infer the given data. + This function will call self.generate() to get model outputs and also self.model() to get logits. + + Args: + data: The data for inference. + inference_kwargs: Arguments for inference. + debug: Whether to display generated prompt for debugging. + + Returns: + Inference results. + + """ + calculate_loss = inference_kwargs["calculate_loss"] + classes = inference_kwargs["all_classes"] + language = inference_kwargs["language"] + pretrain = inference_kwargs["pretrain"] + max_new_tokens = inference_kwargs["max_new_tokens"] + few_shot_data = inference_kwargs.get("few_shot_data", None) + + # Some classification questions' options are texts not a single letter such as A, B, C and D. + # If the text length is greater than 1, we won't calculate loss over choices. + if classes is not None and any(len(c) > 1 for c in classes): + classes = None + + self.choices = classes + self.indices_for_choices = None + if self.choices: + # Get indices for each choice + self._get_choices_indices(language) + + self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)} + + bar = tqdm( + range(math.ceil(len(data) / self.batch_size)), + desc=f"{data[0]['dataset']}-{data[0]['category']} Inference steps", + disable=not is_rank_0(), + ) + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + + answers = copy.deepcopy(data) + for i in range(0, len(data), self.batch_size): + batch = data[i : i + self.batch_size] + batch_prompt, batch_target = get_batch_prompt( + self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length + ) + + if is_rank_0() and debug and i == 0: + self.logger.info( + f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}" + ) + self.logger.info("-" * 120) + self.logger.info("An example prompt and prompt with target is:") + self.logger.info("-" * 120) + self.logger.info(batch_prompt[0]) + self.logger.info("-" * 120) + self.logger.info(batch_prompt[0] + batch_target[0][0]) + + if not pretrain: + batch_decodes, scores = self.generate(batch_prompt, max_new_tokens) + + if calculate_loss: + batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss( + batch_prompt, batch_target, pretrain + ) + + probs = [] + if self.indices_for_choices: + scores = scores.to(torch.float32) + # If we have indices_for_choices(must be single-choice question), there will be only one target answer for one data sample. + # Otherwise this will violate the single-choice setting. + + if calculate_loss: + labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))] + + loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist() + + probs = torch.nn.functional.softmax(scores, dim=-1).numpy().tolist() + probs = [ + {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs)) + ] + + for j in range(len(batch_prompt)): + if not pretrain: + answers[i + j]["output"] = batch_decodes[j].strip() + + if isinstance(scores, torch.Tensor): + answers[i + j]["softmax_over_choices"] = probs[j] + + if calculate_loss: + answers[i + j]["loss_over_choices"] = loss_over_choices[j] + + if calculate_loss: + answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist() + + # loss_sum is specially used for pertrain dataset for calculating per-byte-perplexity. + # However, loss (which is per sample loss) suffices for most cases. + answers[i + j]["loss_sum"] = batch_losses[j] + answers[i + j]["token_num"] = batch_target_token_nums[j] + + if batch_bytes_nums: + answers[i + j]["byte_num"] = batch_bytes_nums[j] + + bar.update() + + return answers + + @torch.no_grad() + def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]: + """Generate results given a list of inputs and get logits of the first new token over choices. + + Args: + inputs: A list of strings. + max_new_tokens: Max new tokens for generation. + kwargs: Key arguments for generation + + Returns: + A list of generated strings and logits over choices. + + Note: + Currently the function only returns the logits of the first new token. + It is used for single choice question. + For multiple choices question, please avoid using the loss over choices. + You should set argument choices as None in self.inference(). + + """ + truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens) + + encoded_inputs = self.tokenizer( + truncated_inputs, + padding=True, + truncation=True, + return_tensors="pt", + return_token_type_ids=False, + max_length=self.model_max_length - max_new_tokens, + ).to(torch.cuda.current_device()) + + # Set output_scores=True to get prediction scores. + outputs = self.model.generate( + **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs + ) + + # We only need to decode predicted tokens. + sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :] + + scores = [] + if self.indices_for_choices: + # If the question is a single-choice question, we will return the scores of specific indices for first predicted token. + # The indices are the tokenization results of the options for the single-choice question. + # For example, if the options of the question are A, B, C and D, we only returns scores at indices of A, B, C and D. + for option_indices in self.indices_for_choices: + scores.append(outputs.scores[0][:, option_indices].detach().cpu()) + + scores = torch.max(torch.stack(scores), dim=0)[0] + + decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + return decoded_sequences, scores + + @torch.no_grad() + def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]: + """ + Calculate loss only on target tokens. + + Args: + batch: A batch of prompt without target answer. + batch_target: A batch of target answer. Sometimes one question can have multiple target answers. + + Returns: + Loss. + + """ + + # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss. + # We don't need to generate new tokens. + # Target answer's length is usually << model_max_length, but we still call it in case. + # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens. + if not pretrain: + batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target] + + # Get the number of target answers for different questions + batch_target_nums = [len(prompt_target) for prompt_target in batch_target] + + input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain) + + # Because of multiple target answers, the final batch size may be greater than self.batch_size. + # We will generate new batches. + losses = [] + target_token_nums = [] + + batched_input_ids = [ + input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size) + ] + batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)] + + for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels): + losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels) + losses.extend(losses_per_batch) + target_token_nums.extend(target_token_num_per_batch) + + start_indice = 0 + losses_per_sample = [] + + target_token_nums_per_sample = [] + bytes_nums_per_sample = [] + for length in batch_target_nums: + losses_per_sample.append(losses[start_indice : start_indice + length]) + target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length]) + + if bytes_list: + bytes_nums_per_sample.append(bytes_list[start_indice : start_indice + length]) + + start_indice += length + + if bytes_list: + return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample + + return losses_per_sample, target_token_nums_per_sample, None + + +class HuggingFaceCausalLM(HuggingFaceModel): + """ + Model wrapper around HuggingFace AutoModelForCausalLM models. + + Args: + path: The path to a HuggingFace model. + model_max_length: The maximum sequence length of the model. + tokenizer_path: The path to the tokenizer. + tokenizer_kwargs: Keyword arguments for the tokenizer. + peft_path: The name or path to the HuggingFace's PEFT model. + model_kwargs: Keyword arguments for the model. + prompt_template: The model's prompt template. + batch_size: Batch size for inference. + logger: Logger for the model. + + """ + + def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None): + """ + Load model. + + Args: + path: The path to the model. + model_kwargs: Keyword arguments for the model. + peft_path: The path to the peft model. + + """ + + if "torch_dtype" in model_kwargs: + model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"]) + + if "config" in model_kwargs: + model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"]) + + model_kwargs.setdefault("torch_dtype", torch.float16) + self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device()) + if peft_path is not None: + self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False) + self.model.eval() diff --git a/applications/ColossalEval/colossal_eval/utils/__init__.py b/applications/ColossalEval/colossal_eval/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5ee6e13b74732d2563715951f29bc134b8563d2 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/utils/__init__.py @@ -0,0 +1,4 @@ +from .conversation import Conversation, get_batch_prompt, prompt_templates +from .utilities import get_json_list, is_rank_0, jdump, jload + +__all__ = ["Conversation", "prompt_templates", "get_batch_prompt", "is_rank_0", "jload", "jdump", "get_json_list"] diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..6c096a8523c0ec8df5d8953ff8ba6302c290bb41 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/utils/conversation.py @@ -0,0 +1,231 @@ +import dataclasses +from enum import Enum, auto +from typing import Dict, List, Optional, Tuple + +from transformers import AutoTokenizer + + +class SeparatorStyle(Enum): + ADD_BOS_EOS_TOKEN = auto() + ALPACA = auto() + PLAIN = auto() + + +@dataclasses.dataclass +class Conversation: + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.ADD_BOS_EOS_TOKEN + sep: str = "" + + def clear(self): + self.messages = [] + + def get_prompt(self): + if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN: + ret = self.system + for role, message in self.messages: + if message: + ret += role + ": " + "" + message + self.sep + else: + ret += role + ": " + "" + return ret + elif self.sep_style == SeparatorStyle.ALPACA: + ret = self.system + self.sep + for role, message in self.messages: + if message: + ret += role + ":\n" + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.PLAIN: + ret = self.system + for role, message in self.messages: + if message: + ret += message + else: + ret += "" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def get_prompt_with_target(self, target): + prompt = self.get_prompt() + prompt_with_target = [] + + # Some dataset provides multiple target answers. + # This will make it difficult when we calculate loss. + # We convert target into list[str] first if the question only has one target answer. + target_answers = [] + if isinstance(target, str): + target_answers = [target] + else: + target_answers = target + + for target_answer in target_answers: + if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN: + prompt_with_target.append(prompt + target_answer) + elif self.sep_style == SeparatorStyle.ALPACA: + prompt_with_target.append(prompt + target_answer) + elif self.sep_style == SeparatorStyle.PLAIN: + prompt_with_target.append(prompt + target_answer) + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return prompt_with_target + + def save_prompt(self): + if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN: + ret = self.system + for role, message in self.messages: + if message: + ret += role + ": " + "" + message + "\n" + else: + ret += role + ": " + "" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + ) + + def dict(self): + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep_style": self.sep_style, + "sep": self.sep, + } + + +def get_few_shot_prefix( + conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int +) -> str: + """ + Get few shot prefix. + + Args: + conv: Conversation template. + few_shot_examples: Few shot examples to generate few shot prompt prefix. + + Returns: + Few shot prompt prefix. + """ + + if language == "English": + few_shot_prefix = f"The following are answers for questions in an exam.\n\n" + elif language == "Chinese": + few_shot_prefix = f"以下是考试中各个问题的答案。\n\n" + + output = None + for i in range(len(few_shot_data)): + few_shot_prefix = few_shot_prefix + few_shot_data[i] + "\n\n" + + if len(tokenizer([few_shot_prefix]).input_ids[0]) <= max_tokens: + output = few_shot_prefix + else: + break + + return output if output is not None else few_shot_prefix + + +def get_batch_prompt( + conv: Conversation, + batch: List[Dict], + few_shot_data: List[str], + tokenizer: Optional[AutoTokenizer], + language: Optional[str], + model_max_length: Optional[int], +) -> Tuple[List[Dict], List[Dict]]: + """ + Get batch prompt and target. + + Args: + conv: Conversation template. + batch: Batch data to generate prompt from. + few_shot_data: Few shot data to generate few shot prompt prefix. + + Returns: + Tuple containg batch prompt and target. + + """ + + batch_prompt = [] + batch_target = [] + + if isinstance(batch[0], dict): + for b in batch: + few_shot_prefix = "" + if few_shot_data is not None: + # For few-shot, only need input. Otherwise use instruction (in AGIEval). + query_text = b["input"] if b.get("input", "") != "" else b["instruction"] + + if isinstance(b["target"], str): + zero_shot_prompt = query_text + b["target"] + max_tokens = model_max_length - len(tokenizer([zero_shot_prompt]).input_ids[0]) + else: + raise Exception("When using few-shot, target answer should be a string.") + + few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens) + else: + query_text = b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"] + + conv.append_message(conv.roles[0], few_shot_prefix + query_text) + conv.append_message(conv.roles[1], None) + + batch_prompt.append(conv.get_prompt()) + + target = b["target"] + if isinstance(b["target"], str): + target = [target] + + batch_target.append(target) + + conv.clear() + + return batch_prompt, batch_target + + +conv_coati = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + messages=[], + offset=0, + sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN, + sep="", +) + +conv_alpaca = Conversation( + system="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + messages=[], + offset=0, + sep_style=SeparatorStyle.ALPACA, + sep="\n\n", +) + +conv_plain = Conversation( + system="", + roles=("", ""), + messages=[], + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="", +) + +prompt_templates = {"coati": conv_coati, "alpaca": conv_alpaca, "plain": conv_plain} diff --git a/applications/ColossalEval/colossal_eval/utils/utilities.py b/applications/ColossalEval/colossal_eval/utils/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..4eda079074952da15b9b53068b6b2f43b8a3a8e5 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/utils/utilities.py @@ -0,0 +1,62 @@ +import io +import json +import os + +import torch.distributed as dist + + +def is_rank_0() -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 + + +def _make_w_io_base(f, mode: str): + if not isinstance(f, io.IOBase): + f_dirname = os.path.dirname(f) + if f_dirname != "": + os.makedirs(f_dirname, exist_ok=True) + f = open(f, mode=mode, encoding="utf-8") + return f + + +def _make_r_io_base(f, mode: str): + if not isinstance(f, io.IOBase): + f = open(f, mode=mode, encoding="utf-8") + return f + + +def jdump(obj, f, mode="w", indent=4, default=str): + """ + Dump a str or dictionary to a file in json format. + + Args: + obj: An object to be written. + f: A string path to the location on disk. + mode: Mode for opening the file. + indent: Indent for storing json dictionaries. + default: A function to handle non-serializable entries; defaults to `str`. + + """ + f = _make_w_io_base(f, mode) + if isinstance(obj, (dict, list)): + json.dump(obj, f, indent=indent, default=default, ensure_ascii=False) + elif isinstance(obj, str): + f.write(obj) + else: + raise ValueError(f"Unexpected type: {type(obj)}") + f.close() + + +def jload(f, mode="r"): + """Load a .json file into a dictionary.""" + f = _make_r_io_base(f, mode) + jdict = json.load(f) + f.close() + return jdict + + +def get_json_list(file_path): + with open(file_path, "r") as f: + json_list = [] + for line in f: + json_list.append(json.loads(line if line != "null" else line)) + return json_list diff --git a/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json b/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json new file mode 100644 index 0000000000000000000000000000000000000000..d7c8648810084cdbe3aec95b8045a7d14ac5af6e --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json @@ -0,0 +1,44 @@ +{ + "language": "cn", + "category": { + "brainstorming": { + "GPT": [ + "language organization", + "relevance", + "creativity", + "practicality", + "reasonableness" + ] + }, + "chat": { + "GPT": [ + "language organization", + "naturalness", + "engagingness", + "fidelity" + ] + }, + "generation": { + "GPT": [ + "language organization", + "relevance", + "diversity" + ] + }, + "open_qa": { + "GPT": [ + "language organization", + "relevance", + "correctness" + ] + }, + "roleplay": { + "GPT": [ + "language organization", + "relevance", + "fidelity", + "creativity" + ] + } + } +} diff --git a/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json b/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json new file mode 100644 index 0000000000000000000000000000000000000000..6ebe3996b1cf72fe75f67c570fad2c857e583158 --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json @@ -0,0 +1,44 @@ +{ + "language": "en", + "category": { + "brainstorming": { + "GPT": [ + "language organization", + "relevance", + "creativity", + "practicality", + "reasonableness" + ] + }, + "chat": { + "GPT": [ + "language organization", + "naturalness", + "engagingness", + "fidelity" + ] + }, + "generation": { + "GPT": [ + "language organization", + "relevance", + "diversity" + ] + }, + "open_qa": { + "GPT": [ + "language organization", + "relevance", + "correctness" + ] + }, + "roleplay": { + "GPT": [ + "language organization", + "relevance", + "fidelity", + "creativity" + ] + } + } +} diff --git a/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json b/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json new file mode 100644 index 0000000000000000000000000000000000000000..f869830555b4b59d22195187f53ee2c25e7881c8 --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json @@ -0,0 +1,202 @@ +[ + { + "category": "brainstorming", + "instruction": "列举一些可以促进头发生长的食物。", + "input": "", + "output": "", + "target": "", + "id": 1 + }, + { + "category": "brainstorming", + "instruction": "中年夫妻如何提升夫妻感情,请给出三个实用的的方法,并举例说明。", + "input": "", + "output": "", + "target": "", + "id": 2 + }, + { + "category": "brainstorming", + "instruction": "请列举4种日常的环保行为。", + "input": "", + "output": "", + "target": "", + "id": 3 + }, + { + "category": "brainstorming", + "instruction": "请给出5个可以随时随地锻炼身体的小动作。", + "input": "", + "output": "", + "target": "", + "id": 4 + }, + { + "category": "brainstorming", + "instruction": "请问如何制作一份美味的西红柿炒鸡蛋?", + "input": "", + "output": "", + "target": "", + "id": 5 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。", + "input": "小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。 老李:你好,小张,我很乐意帮助你。你想问些什么? 小张:我想知道如何确定鸡的品种和性别? 老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗? 小张:", + "output": "", + "target": "", + "id": 6 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。李华是一名参加了期末考试的学生,他已经很担心自己的考试成绩。老师Lucy正在帮助他度过这个紧张的时刻。", + "input": "李华:Lucy老师,我很担心自己的考试成绩,我不知道我是否能够通过这次考试。 Lucy:放松,李华,你已经做好了充分的准备。相信你自己,你会做得很好的。 李华:我很怕考试时会忘记自己所学的知识。 Lucy:你可以预留一些时间,过一遍自己所学的知识点或笔记,这样你会更有信心和准确地回答考题。 李华:如果我还是失败了,该怎么办? Lucy:", + "output": "", + "target": "", + "id": 7 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。张先生是一名企业家,正在考虑是否开拓海外市场;李女士是一名跨境电商专家,擅长国际商务和电子商务。", + "input": "张先生:你好,李女士,我正在考虑将我们的产品销售扩大至海外市场,您有什么建议吗? 李女士:您好,张先生,我们需要考虑到海外市场对于产品的需求是否与国内市场一致,需要进行市场调研和定位。然后再进行各种软性、硬性的创新。 张先生:听起来很专业,您能具体解释一下吗? 李女士:", + "output": "", + "target": "", + "id": 8 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。小明是一名医生。一名病患想要提前停药。小王是病患的儿子,希望父亲能够听取医生的建议。", + "input": "小明:你好,小王,我了解你想要让你父亲停药。小王:是的,我父亲已经吃了那么久的药,我担心药物对他的身体会有副作用。小明:", + "output": "", + "target": "", + "id": 9 + }, + { + "category": "chat", + "instruction": "基于以下角色信息完成一段对话。张三是一位语文老师,对学生认真负责;李四是张三的学生,对语文兴趣不是很高。", + "input": "张三:同学们,今天要讲的是一篇古文《岳阳楼记》。这篇文章非常精彩,希望同学们能够认真听课,理解其中的含义。 李四:怎么又是古文? 张三:", + "output": "", + "target": "", + "id": 10 + }, + { + "category": "generation", + "instruction": "根据主题写一封邮件。", + "input": "主题: \"加入我们,共创未来\"", + "output": "", + "target": "", + "id": 11 + }, + { + "category": "generation", + "instruction": "为公司编写一份职场行为准则,包括明确的行为规范和道德准则。", + "input": "", + "output": "", + "target": "", + "id": 12 + }, + { + "category": "generation", + "instruction": "请撰写一篇文章,介绍如何通过改善生活习惯来预防疾病和延长寿命。", + "input": "", + "output": "", + "target": "", + "id": 13 + }, + { + "category": "generation", + "instruction": "请为一家咖啡店编写一篇简短的广告语,吸引更多的顾客。", + "input": "", + "output": "", + "target": "", + "id": 14 + }, + { + "category": "generation", + "instruction": "根据以下故事提示写一篇故事:", + "input": "故事提示:```在一个废弃的古堡中,一个小女孩遇到了一只会说话的黑猫,他们一起揭开了一个古老的谜题。```", + "output": "", + "target": "", + "id": 15 + }, + { + "category": "open_qa", + "instruction": "请介绍一下《红楼梦》这部经典小说的故事情节。", + "input": "", + "output": "", + "target": "", + "id": 16 + }, + { + "category": "open_qa", + "instruction": "解释什么是RNA病毒和DNA病毒。", + "input": "", + "output": "", + "target": "", + "id": 17 + }, + { + "category": "open_qa", + "instruction": "什么是比特币?", + "input": "", + "output": "", + "target": "", + "id": 18 + }, + { + "category": "open_qa", + "instruction": "在计算机中,什么是RAM?与ROM有什么区别?", + "input": "", + "output": "", + "target": "", + "id": 19 + }, + { + "category": "open_qa", + "instruction": "请简单介绍一下世界上最长的河流途经的国家。", + "input": "", + "output": "", + "target": "", + "id": 20 + }, + { + "category": "roleplay", + "instruction": "我要你把我写的句子翻译成表情符号。我会写句子,你会用表情符号表达它。我只是想让你用表情符号来表达它。除了表情符号,我不希望你回复任何内容。当我需要用中文告诉你一些事情时,我会用 {} 这样的大括号括起来。我的第一句话是“{我的职业是消防员。}”\n", + "input": "", + "output": "", + "target": "", + "id": 21 + }, + { + "category": "roleplay", + "instruction": "我希望你假定自己是雅思写作考官,根据雅思评判标准,按我给你的雅思考题和对应答案给我评分,并且按照雅思写作评分细则给出打分依据。此外,请给我详细的修改意见并写出满分范文。第一个问题是:It is sometimes argued that too many students go to university, while others claim that a university education should be a universal right. Discuss both sides of the argument and give your own opinion.对于这个问题,我的答案是:In some advanced countries, it is not unusual for more than 50% of young adults to attend college or university. Critics, however, claim that many university courses are worthless and young people would be better off gaining skills in the workplace. In this essay, I will examine both sides of this argument and try to reach a conclusion.There are several reasons why young people today believe they have the right to a university education. First, growing prosperity in many parts of the world has increased the number of families with money to invest in their children’s future. At the same time, falling birthrates mean that one- or two-child families have become common, increasing the level of investment in each child. It is hardly surprising, therefore, that young people are willing to let their families support them until the age of 21 or 22. Furthermore, millions of new jobs have been created in knowledge industries, and these jobs are typically open only to university graduates.However, it often appears that graduates end up in occupations unrelated to their university studies. It is not uncommon for an English literature major to end up working in sales, or an engineering graduate to retrain as a teacher, for example. Some critics have suggested that young people are just delaying their entry into the workplace, rather than developing professional skills.请依次给到我以下内容:具体分数及其评分依据、文章修改意见、满分范文。\n", + "input": "", + "output": "", + "target": "", + "id": 22 + }, + { + "category": "roleplay", + "instruction": "我想让你充当 Linux 终端。我将输入命令,您将回复终端应显示的内容。我希望您只在一个唯一的代码块内回复终端输出,而不是其他任何内容。不要写解释。除非我指示您这样做,否则不要键入命令。当我需要用英语告诉你一些事情时,我会把文字放在中括号内[就像这样]。我的第一个命令是 pwd\n", + "input": "", + "output": "", + "target": "", + "id": 23 + }, + { + "category": "roleplay", + "instruction": "我希望你充当宠物行为主义者。我将为您提供一只宠物和它们的主人,您的目标是帮助主人了解为什么他们的宠物表现出某些行为,并提出帮助宠物做出相应调整的策略。您应该利用您的动物心理学知识和行为矫正技术来制定一个有效的计划,双方的主人都可以遵循,以取得积极的成果。我的第一个请求是“我有一只好斗的德国牧羊犬,它需要帮助来控制它的攻击性。”\n", + "input": "", + "output": "", + "target": "", + "id": 24 + }, + { + "category": "roleplay", + "instruction": "我希望你充当正则表达式生成器。您的角色是生成匹配文本中特定模式的正则表达式。您应该以一种可以轻松复制并粘贴到支持正则表达式的文本编辑器或编程语言中的格式提供正则表达式。不要写正则表达式如何工作的解释或例子;只需提供正则表达式本身。我的第一个提示是生成一个匹配电子邮件地址的正则表达式。\n", + "input": "", + "output": "", + "target": "", + "id": 25 + } +] diff --git a/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json b/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json new file mode 100644 index 0000000000000000000000000000000000000000..27b8af8bc4c6b77628ef72ba57f0869c6abf26a5 --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json @@ -0,0 +1,202 @@ +[ + { + "category": "brainstorming", + "instruction": "Which are some popular fiction books that I should read?", + "input": "", + "output": "", + "target": "", + "id": 1 + }, + { + "category": "brainstorming", + "instruction": "How do I properly store fruits and vegetables to keep them fresh for longer?", + "input": "", + "output": "", + "target": "", + "id": 2 + }, + { + "category": "brainstorming", + "instruction": "How do you properly chop an onion without crying?", + "input": "", + "output": "", + "target": "", + "id": 3 + }, + { + "category": "brainstorming", + "instruction": "How to make an international transfer? Please provide 3 techniques.", + "input": "", + "output": "", + "target": "", + "id": 4 + }, + { + "category": "brainstorming", + "instruction": "Name five leadership qualities that you consider most important.", + "input": "", + "output": "", + "target": "", + "id": 5 + }, + { + "category": "chat", + "instruction": "Complete a dialogue based on the following character information. Alex: A novice writer who is struggling to find inspiration and develop his writing skills. Emma: A successful author with many published works, providing guidance and advice to Alex.", + "input": "Alex: Hi Emma, I have been writing for a while now but can't seem to make any progress. Can you give me any advice? Emma: Hi Alex, sure. What kind of writing are you doing? Alex: I'm trying to write a novel, but I just can't seem to find any inspiration. Emma: ", + "output": "", + "target": "", + "id": 6 + }, + { + "category": "chat", + "instruction": "Complete a dialogue based on the following character information. John: An experienced software engineer with a passion for coding. Karen: A recent college graduate who is interested in learning more about software development.", + "input": "Karen: Hi John, I noticed that you have a lot of experience in the software industry. Can you tell me what you think is the most important skill for a software engineer? John: ", + "output": "", + "target": "", + "id": 7 + }, + { + "category": "chat", + "instruction": "Complete a dialogue based on the following character information. Sarah is a new employee who is nervous about her first presentation; Tom is her boss who has given her coaching and preparation materials.", + "input": "Sarah: Tom, I'm feeling really nervous about my presentation tomorrow. Tom: I know how you feel, Sarah. However, I believe in you and your abilities. Just stick to the preparation materials that I have given you, and you'll do great. Sarah: Thank you, Tom. What if I forget something important during the presentation? Tom: ", + "output": "", + "target": "", + "id": 8 + }, + { + "category": "chat", + "instruction": "Complete a dialogue based on the following character information. Sarah: a young artist who is full of creative ideas and always eager to try new things. Jack: a seasoned artist who has achieved great success in the art world and is more traditional in his approach to art.", + "input": "Sarah: Hi Jack, I'm really excited to meet you. I'm a big fan of your work. Jack: Hi Sarah, nice to meet you too. So, what kind of art do you do? Sarah: I am passionate about abstract art, especially combining different materials and colors. I think it can really give people a new perspective on things. Jack: That's interesting, but I am more focused on realistic paintings. I believe the most important thing is to master the basic skills first. Sarah: ", + "output": "", + "target": "", + "id": 9 + }, + { + "category": "chat", + "instruction": "Complete a conversation based on the following persona information. Sarah is a college student who is interested in joining a volunteer organization. John is the leader of the volunteer organization and is eager to welcome new members.", + "input": "Sarah: Hi, I'm Sarah, and I'm interested in joining your volunteer organization. John: Hi Sarah, welcome! We're always looking for new members who are passionate about volunteering. What areas would you like to focus on? Sarah: I'm interested in community outreach and working with children. John: ", + "output": "", + "target": "", + "id": 10 + }, + { + "category": "generation", + "instruction": "Write an email based on the subject:", + "input": "Subject: \"Invitation to an Exclusive Webinar\"", + "output": "", + "target": "", + "id": 11 + }, + { + "category": "generation", + "instruction": "Write a set of guidelines for first-time pet owners on how to properly care for a new puppy.", + "input": "", + "output": "", + "target": "", + "id": 12 + }, + { + "category": "generation", + "instruction": "Can you help me write a persuasive speech on why we should recycle more and take better care of the environment?", + "input": "", + "output": "", + "target": "", + "id": 13 + }, + { + "category": "generation", + "instruction": "Write a pitch for a brand-new mobile app that helps people organize their daily tasks more efficiently.", + "input": "", + "output": "", + "target": "", + "id": 14 + }, + { + "category": "generation", + "instruction": "Write a social media post promoting an upcoming concert featuring a local band.", + "input": "", + "output": "", + "target": "", + "id": 15 + }, + { + "category": "open_qa", + "instruction": "Describe the significance of the Renaissance period in European history.", + "input": "", + "output": "", + "target": "", + "id": 16 + }, + { + "category": "open_qa", + "instruction": "What is the term for the surgical removal of the appendix?", + "input": "", + "output": "", + "target": "", + "id": 17 + }, + { + "category": "open_qa", + "instruction": "Explain the process of osmosis in biological systems.", + "input": "", + "output": "", + "target": "", + "id": 18 + }, + { + "category": "open_qa", + "instruction": "Who were the members of the Beatles band?", + "input": "", + "output": "", + "target": "", + "id": 19 + }, + { + "category": "open_qa", + "instruction": "Who painted the The Scream?", + "input": "", + "output": "", + "target": "", + "id": 20 + }, + { + "category": "roleplay", + "instruction": "I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. do not write explanations. do not type commands unless I instruct you to do so. when i need to tell you something in english, i will do so by putting text inside curly brackets {like this}. my first command is pwd", + "input": "", + "output": "", + "target": "", + "id": 21 + }, + { + "category": "roleplay", + "instruction": "I want you to act as a travel guide. I will write you my location and you will suggest a place to visit near my location. In some cases, I will also give you the type of places I will visit. You will also suggest me places of similar type that are close to my first location. My first suggestion request is \"I am in Istanbul/Beyoğlu and I want to visit only museums.\"", + "input": "", + "output": "", + "target": "", + "id": 22 + }, + { + "category": "roleplay", + "instruction": "I want you to act as an advertiser. You will create a campaign to promote a product or service of your choice. You will choose a target audience, develop key messages and slogans, select the media channels for promotion, and decide on any additional activities needed to reach your goals. My first suggestion request is \"I need help creating an advertising campaign for a new type of energy drink targeting young adults aged 18-30.\"", + "input": "", + "output": "", + "target": "", + "id": 23 + }, + { + "category": "roleplay", + "instruction": "I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it’s children then you can talk about animals; If it’s adults then history-based tales might engage them better etc. My first request is \"I need an interesting story on perseverance.\"", + "input": "", + "output": "", + "target": "", + "id": 24 + }, + { + "category": "roleplay", + "instruction": "I want you to act as a rapper. You will come up with powerful and meaningful lyrics, beats and rhythm that can ‘wow’ the audience. Your lyrics should have an intriguing meaning and message which people can relate too. When it comes to choosing your beat, make sure it is catchy yet relevant to your words, so that when combined they make an explosion of sound everytime! My first request is \"I need a rap song about finding strength within yourself.\"", + "input": "", + "output": "", + "target": "", + "id": 25 + } +] diff --git a/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_cn.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_cn.json new file mode 100644 index 0000000000000000000000000000000000000000..ca66afd7e4644a2854640305feb3dabadc4818b1 --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_cn.json @@ -0,0 +1,6 @@ +{ + "id": 1, + "system_prompt": "你是一个检查回答质量的好助手。", + "prompt_template": "[问题]\n{question}\n\n[1号AI助手的答案]\n{answer_1}\n\n[1号AI助手答案终止]\n\n[2号AI助手的答案]\n{answer_2}\n\n[2号AI助手答案终止]\n\n[要求]\n{prompt}\n\n", + "prompt": "我们需要你评价这两个AI助手回答的性能。\n请对他们的回答的有用性、相关性、准确性、详细程度进行评分。每个AI助手都会得到一个1到10分的总分,分数越高表示整体表现越好。\n请首先输出一行,该行只包含两个数值,分别表示1号和2号AI助手的分数。这两个分数之间要有一个空格。在随后的一行中,请对你的评价作出全面的解释,避免任何潜在的偏见,并确保AI助手回答的顺序不会影响您的判断。" +} diff --git a/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_en.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_en.json new file mode 100644 index 0000000000000000000000000000000000000000..2b35d1958ac5e7ef3d00d3a0fbcfdb88e8f06b11 --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_en.json @@ -0,0 +1,6 @@ +{ + "id": 1, + "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer. You will be given two different answers to the same question", + "prompt_template": "[Question]\n{question}\n\n[The Start of AI Assistant 1's Answer]\n{answer_1}\n\n[The End of AI Assistant 1's Answer]\n\n[The Start of AI Assistant 2's Answer]\n{answer_2}\n\n[The End of AI Assistant 2's Answer]\n\n[Requirements]\n{prompt}\n\n", + "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space. In the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment." +} diff --git a/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json new file mode 100644 index 0000000000000000000000000000000000000000..70f6c3ebc31632bb8d2a4ddab1965d693053a60a --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json @@ -0,0 +1,102 @@ +{ + "brainstorming": { + "id": 1, + "category": "brainstorming", + "metrics": { + "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", + "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", + "creativity": "创意性(1-5):某些头脑风暴问题可能需要答案具有创意,提出新的思路。", + "practicality": "实用性(1-5):某些头脑风暴问题可能需要答案提出实用的建议或解决方法。", + "reasonableness": "合理性(1-5):答案应该符合常识、生活实际等等。" + }, + "CoT": { + "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:", + "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", + "creativity": "1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。\n2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则创意性评分可能会受到影响。\n3. 考虑答案中是否包含新颖的想法或独特的思路。答案可能与已知的解决方案有所重叠,但仍然可以被认为是有创意的,只要它提供了新的角度或方法来解决问题。\n4. 根据答案的创意性,给出一个1到5的评分。如果答案缺乏创意,则应给出一个较低的评分。如果答案具有创意并提供了新的思路,应给出一个较高的评分。\n\n创意性:", + "practicality": "1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。\n2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则实用性评分可能会受到影响。\n3. 考虑答案中提出的建议或解决方法是否实用并可行。答案可能看起来很好,但如果无法实现或应用,则实用性评分可能会受到影响。\n4. 根据答案的实用性,给出一个1到5的评分。如果答案缺乏实用性,则应给出一个较低的评分。如果答案提出了实用的建议或解决方法,并且可以很好地解决问题,则应给出一个较高的评分。\n\n实用性:", + "reasonableness": "1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。\n2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则合理性评分可能会受到影响。\n3. 考虑答案中所提供的信息是否合理、符合常识、生活实际等等。如果答案中存在明显的不合理之处,则合理性评分可能会受到影响。\n4. 根据答案的合理性,给出一个1到5的评分。如果答案存在明显的不合理之处,则应给出一个较低的评分。如果答案合理、符合常识、生活实际等等,则应给出一个较高的评分。\n\n合理性:" + }, + "prompt": "你是一个好助手。请你为下面“头脑风暴”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" + }, + "chat": { + "id": 2, + "category": "chat", + "metrics": { + "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", + "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", + "naturalness": "自然(1-5):答案是否自然,并且符合问题给定的身份。", + "engagingness": "参与感(1-5):答案是否对前面的对话内容做出了恰当的反应,是否理解对话的语境和背景。", + "reasonableness": "合理性(1-5):答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。", + "fidelity": "保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。" + }, + "CoT": { + "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:", + "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", + "naturalness": "1. 阅读题目,确定题目提供的身份信息。\n2. 检查答案内容是否符合题目给定的身份。\n3. 根据以上因素,对该回答的自然性进行打分,分数从1到5,其中1表示不自然,5表示非常自然,并符合问题给定的身份。\n\n自然:", + "engagingness": "1. 阅读题目,确定对话的语境和背景。\n2. 检查答案是否充分理解对话的语境和背景,能否自然地融入到对话中而不显得突兀。\n3. 根据以上因素,对该回答的参与感进行打分,分数从1到5,其中1表示没有参与感,5表示非常有参与感,并且恰当地理解了对话的语境和背景。\n\n参与感:", + "reasonableness": "1. 阅读题目,确定对话的主题以及问题期望的回答方向。\n2. 判断答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。\n3. 根据以上因素,对该回答的合理性进行打分,分数从1到5,其中1表示不合理,5表示非常合理,并且能够与前面的对话内容形成逻辑上的衔接,并符合常理。\n\n合理性:", + "fidelity": "1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。\n阅读题目的请求,确认回答请求时需要注意的细节。\n3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。\n4. 结合以上评估结果给出保真度的评分,范围从1到5分,其中1分表示回答与角色设定完全不符,5分表示回答完全符合角色设定且满足给定请求。\n\n保真度:" + }, + "prompt": "你是一个好助手。请你为下面的“补全对话”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" + }, + "generation": { + "id": 3, + "category": "generation", + "metrics": { + "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", + "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", + "diversity": "多样性(1-5):答案使用语言是否优美,具有有一定的创造性和想象力。然而,回答也应该保持合理和适度,不要过于夸张或离题。" + }, + "CoT": { + "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:", + "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", + "diversity": "1. 仔细阅读整个回答,确保完全理解回答所表达的内容和主题。\n2. 在阅读回答的同时,注意语言的质量,例如措辞是否正确,语言是否生动等。\n3. 检查回答的创造性和想象力,看看回答是否能够吸引人阅读下去。\n4. 检查回答的合理性和适度,看看回答是否夸张或离题。\n5. 将多样性的评分打分在1到5之间,5分表示回答的质量很好,能够吸引人阅读,1分表示回答的内容生硬或者有离题的问题。\n\n多样性:" + }, + "prompt": "你是一个好助手。请你为下面的“生成”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" + }, + "open_qa": { + "id": 4, + "category": "open_qa", + "metrics": { + "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", + "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", + "correctness": "正确性(1-5):答案是否正确。" + }, + "CoT": { + "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:", + "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", + "correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:" + }, + "prompt": "你是一个好助手。请你为下面的问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" + }, + "roleplay": { + "id": 5, + "category": "roleplay", + "metrics": { + "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。", + "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", + "fidelity": "保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。", + "creativity": "创意性(1-5):角色扮演问题的回答需要具有一定创意,但同时需要遵守角色的设定。" + }, + "CoT": { + "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:", + "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", + "fidelity": "1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。\n2. 阅读题目的请求,确认回答请求时需要注意的细节。\n3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。\n4. 结合以上评估结果给出保真度的评分,范围从1到5分,其中1分表示回答与角色设定完全不符,5分表示回答完全符合角色设定且满足给定请求。\n\n保真度:", + "creativity": "1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。\n2. 评估回答是否具有独特的思路和建议,是否能够给提问者带来新的想法和启示。\n3. 对比回答中的创意和该角色的设定,评估回答是否遵守了该角色的设定和基本特征。\n4. 对回答的质量进行总体评估,并结合以上评估结果给出创意性的评分,范围从1到5分,其中1分表示回答缺乏创意,5分表示回答具有独特的思路和建议,并且能够遵守该角色的设定。\n\n创意性:" + }, + "prompt": "你是一个好助手。请你为下面的“角色扮演”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" + }, + "Other": { + "id": 6, + "category": "Other", + "metrics": { + "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。", + "correctness": "正确性(1-5):答案是否正确。" + }, + "CoT": { + "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:", + "correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:" + }, + "prompt": "你是一个好助手。请你为下面问题的答案打分。\n\n问题如下:\n\n{question}\n\n需要你评分的答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}" + } +} diff --git a/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json new file mode 100644 index 0000000000000000000000000000000000000000..3d04387d98c5d158abb2abbdb2baaad9e59ca89a --- /dev/null +++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json @@ -0,0 +1,103 @@ +{ + "brainstorming": { + "id": 1, + "category": "brainstorming", + "metrics": { + "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", + "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", + "creativity": "Creativity (1-5): Some brainstorming questions may require answers that are creative and suggest new ideas.", + "practicality": "Practicality (1-5): Some brainstorming questions may require answers to suggest practical suggestions or solutions.", + "reasonableness": "Reasonableness (1-5): The answer should be in line with common sense, life experience, etc." + }, + "CoT": { + "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", + "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", + "creativity": "1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.\n2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the creativity score may be affected.\n3. Consider whether the answer contains novel ideas or unique thoughts. An answer may overlap with a known solution and still be considered creative, as long as it offers a new perspective or approach to the problem.\n4. Give a score of 1 to 5 depending on the creativity of the answer. If the answer lacks creativity, a lower score should be given. If the answer is creative and provides a new idea, a higher score should be given.\n\nCreativity:", + "practicality": "1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.\n2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the practicality score may be affected.\n3. Consider whether the suggestions or solutions presented in the answer are practical and workable. The answer may look good, but if it cannot be implemented or applied, the practicality score may be affected.\n4. Give a score of 1 to 5 depending on the practicality of the answer. If the answer lacks practicality, a lower score should be given. If the answer makes a practical suggestion or solution and solves the problem well, a higher score should be given.\n\nPracticality:", + "reasonableness": "1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.\n2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the reasonableness score may be affected.\n3. Consider whether the information provided in the answer is reasonable, consistent with common sense, real life, etc. If there are obvious errors or implausibilities in the answer, the reasonableness score may be affected.\n4. Give a score of 1 to 5 depending on the reasonableness of the answer. If the answer contains obvious errors or unreasonable points, a lower score should be given. A higher score should be given if the answer is reasonable, consistent with common sense, real life, etc.\n\nReasonableness:" + }, + "prompt": "You are a good assistant. Please rate the given answer to the \"brainstorming\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" + }, + "chat": { + "id": 2, + "category": "chat", + "metrics": { + "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", + "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", + "naturalness": "Naturalness (1-5): whether the answer is natural and fits the identity given by the question.", + "engagingness": "Engagingness (1-5): whether the answer responds appropriately to the content of the preceding conversation and whether it understands the context and background of the conversation.", + "reasonableness": "Reasonableness (1-5): Whether the answer can form a logical connection with the content of the previous dialogue, whether it is consistent with common sense, and whether it can reasonably exist in this context.", + "fidelity": "Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting." + }, + "CoT": { + "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", + "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", + "naturalness": "1. Read the question and determine the identity information provided in the question.\n2. Check whether the content of the answer matches the identity given in the question.\n3. Based on the above factors, score the naturalness of the response on a scale from 1 to 5, where 1 means unnatural and 5 means very natural and in accordance with the identity given in the question.\n\nNaturalness:", + "engagingness": "1. Read the questions to determine the context and background of the dialogue.\n2. Check that the answer fully understands the context and background of the conversation and that it fits naturally into the conversation without seeming abrupt.\n3. Based on the above factors, rate the response's engagement on a scale from 1 to 5, where 1 means not engaged and 5 means very engaged and appropriately understands the context and background of the conversation.\n\nEngagingness:", + "reasonableness": "1. Read the question and determine the topic of the conversation and the direction the question expects the answer to go.\n2. Determine whether the answer can be logically connected to the preceding conversation, whether it makes common sense, and whether it can reasonably exist in this context.\n3. Based on the above factors, rate the reasonableness of the answer on a scale from 1 to 5, where 1 means unreasonable and 5 means very reasonable and able to form a logical connection with the preceding dialogue content and consistent with common sense.\n\nReasonableness:", + "fidelity": "1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.\n2. Read the question's request and confirm the details that need to be taken into account when answering the request.\n3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.\n4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request.\n\nFidelity:" + }, + "prompt": "You are a good assistant. Please rate the given answer to the \"chat\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" + }, + "generation": { + "id": 3, + "category": "generation", + "metrics": { + "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", + "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", + "diversity": "Diversity (1-5): Whether the answers use beautiful language and have some creativity and imagination. However, answers should also be kept reasonable and moderate, not overly exaggerated or off-topic." + }, + "CoT": { + "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", + "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", + "diversity": "1. Read the entire response carefully to ensure that you fully understand the content and theme expressed in the response.\n2. While reading the response, pay attention to the quality of the language, such as whether the wording is correct and the language is vivid.\n3. Check the creativity and imagination of the response to see if the response is engaging to read on.\n4. Check the reasonableness and appropriateness of the responses to see if the responses are exaggerated or off-topic.\n5. Rate the diversity on a scale of 1 to 5, with a 5 indicating a good quality response that is engaging to read and a 1 indicating a raw response or a question that is off-topic.\n\nDiversity:" + }, + "prompt": "You are a good assistant. Please rate the given answer to the \"generation\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" + }, + "open_qa": { + "id": 4, + "category": "open_qa", + "metrics": { + "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", + "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", + "correctness": "Correctness (1-5): whether the answer is correct or not." + }, + "CoT": { + "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", + "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", + "correctness": "1. Read the question carefully and try to answer the question yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be given. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:" + }, + "prompt": "You are a good assistant. Please rate the answers to the \"open qa\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" + }, + "roleplay": { + "id": 5, + "category": "roleplay", + "metrics": { + "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.", + "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", + "fidelity": "Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting.", + "creativity": "Creativity (1-5): The answers to the role-play questions need to be somewhat creative, but at the same time they need to adhere to the setting of the role." + }, + "CoT": { + "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", + "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", + "fidelity": "1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.\n2. Read the question's request and confirm the details that need to be taken into account when answering the request.\n3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.\n4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request.\n\nFidelity:", + "creativity": "1. Read the question carefully to understand how the character is set up and represented in the question, including career, background, perspective, and personality.\n2. Evaluate whether the answer has unique ideas and suggestions that bring new ideas and insights to the questioner.\n3. Compare the creativity in the response to the setting of the persona and assess whether the response adheres to the setting and essential characteristics of the persona.\n4. Evaluate the quality of the responses in general and combine the results of the above assessment to give a creativity score ranging from 1 to 5, where a score of 1 indicates that the response lacks creativity and a score of 5 indicates that the response has unique ideas and suggestions and is able to adhere to the set-up of the persona.\n\nCreativity:" + }, + "prompt": "You are a good assistant. Please rate the given answer to the \"role-play\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" + }, + "Other": { + "id": 6, + "category": "Other", + "metrics": { + "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.", + "correctness": "Correctness (1-5): whether the answer is correct or not." + }, + "CoT": { + "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:", + "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:", + "correctness": "1. Read the question carefully and try to answer the question by yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be assigned. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:" + }, + "prompt": "You are a good assistant. Please rate the given answer to the question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}" + } +} diff --git a/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json b/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json new file mode 100644 index 0000000000000000000000000000000000000000..adb540f60345bae31c1451c62ac24aca0bbc868b --- /dev/null +++ b/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json @@ -0,0 +1,58 @@ +{ + "model": [ + { + "name": "model1" + }, + { + "name": "model2" + } + ], + "dataset": [ + { + "name": "mmlu", + "metrics": [ + "first_token_accuracy", + "single_choice_accuracy", + "perplexity", + "ppl_score", + "ppl_score_over_choices" + ] + }, + { + "name": "cmmlu", + "metrics": [ + "first_token_accuracy", + "single_choice_accuracy", + "perplexity", + "ppl_score", + "ppl_score_over_choices" + ] + }, + { + "name": "agieval", + "metrics": [ + "first_token_accuracy", + "single_choice_accuracy", + "multi_choice_accuracy", + "math_equivalence", + "perplexity", + "ppl_score_over_choices", + "ppl_score" + ] + }, + { + "name": "gaokaobench", + "metrics": [ + "first_token_accuracy", + "single_choice_accuracy", + "multi_choice_accuracy", + "math_equivalence", + "rouge_score", + "rouge_zh_score", + "perplexity", + "ppl_score_over_choices", + "ppl_score" + ] + } + ] +} diff --git a/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json b/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json new file mode 100644 index 0000000000000000000000000000000000000000..9672c442e647b9fcc186133906b17c102a5ecd08 --- /dev/null +++ b/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json @@ -0,0 +1,84 @@ +{ + "model": [ + { + "name": "model name", + "model_class": "HuggingFaceCausalLM", + "parameters": { + "path": "path to model", + "model_max_length": 4096, + "tokenizer_path": "", + "tokenizer_kwargs": { + "trust_remote_code": true + }, + "peft_path": null, + "model_kwargs": { + "torch_dtype": "torch.float32", + "trust_remote_code": true + }, + "prompt_template": "plain", + "batch_size": 4 + } + }, + { + "name": "model2 name", + "model_class": "HuggingFaceCausalLM", + "parameters": { + "path": "path to model2", + "model_max_length": 4096, + "tokenizer_path": "", + "tokenizer_kwargs": { + "trust_remote_code": true + }, + "peft_path": null, + "model_kwargs": { + "torch_dtype": "torch.float32", + "trust_remote_code": true + }, + "prompt_template": "plain", + "batch_size": 4 + } + } + ], + "dataset": [ + { + "name": "agieval", + "dataset_class": "AGIEvalDataset", + "debug": false, + "few_shot": false, + "path": "path to original dataset (folder)", + "save_path": "path to save converted dataset (e.g. inference_data/agieval.json)" + }, + { + "name": "ceval", + "dataset_class": "CEvalDataset", + "debug": false, + "few_shot": true, + "path": "path to original dataset (folder)", + "save_path": "path to save converted dataset (e.g. inference_data/ceval.json)" + }, + { + "name": "cmmlu", + "dataset_class": "CMMLUDataset", + "debug": false, + "few_shot": true, + "path": "path to original dataset (folder)", + "save_path": "path to save converted dataset (e.g. inference_data/cmmlu.json)" + }, + { + "name": "gaokaobench", + "dataset_class": "GaoKaoBenchDataset", + "debug": false, + "few_shot": false, + "path": "path to original dataset (folder)", + "save_path": "path to save converted dataset (e.g. inference_data/gaokaobench.json)" + }, + { + "name": "mmlu", + "dataset_class": "MMLUDataset", + "debug": false, + "few_shot": true, + "path": "path to original dataset (folder)", + "save_path": "path to save converted dataset (e.g. inference_data/mmlu.json)" + } + ] +} diff --git a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ec81cf0cef71dc9dece6e649b6923f48f3c459d5 --- /dev/null +++ b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py @@ -0,0 +1,73 @@ +import argparse +import os + +import tabulate +from colossal_eval.evaluate.dataset_evaluator import DatasetEvaluator +from colossal_eval.utils import jdump, jload + + +def main(args): + config = jload(args.config) + + evaluation_results = {dataset["name"]: {} for dataset in config["dataset"]} + evaluation_results_table = {dataset["name"]: {} for dataset in config["dataset"]} + evaluator = DatasetEvaluator() + + for dataset_parameter in config["dataset"]: + dataset_name = dataset_parameter["name"] + metrics = dataset_parameter["metrics"] + results_metric_model = {metric: {model["name"]: None for model in config["model"]} for metric in metrics} + for model in config["model"]: + model_name = model["name"] + + data = jload( + os.path.join(args.inference_results_path, model_name, f"{dataset_name}_inference_results.json") + ) + results = evaluator.get_evaluation_results(data, dataset_name, model_name, metrics) + + for metric, score in results.items(): + results_metric_model[metric][model_name] = score["ALL"] + + evaluation_results[dataset_name][model_name] = results + + evaluation_results_table[dataset_name] = results_metric_model + + table = [] + header = ["dataset", "metric"] + [model["name"] for model in config["model"]] + table.append(header) + + for dataset_parameter in config["dataset"]: + dataset_name = dataset_parameter["name"] + metrics = dataset_parameter["metrics"] + + for metric, model_results in evaluation_results_table[dataset_name].items(): + row = [dataset_name] + for model, score in model_results.items(): + if len(row) == 1: + row.extend([metric, "{:.02f}".format(score)]) + else: + row.append("{:.02f}".format(score)) + + table.append(row) + + table = tabulate.tabulate(table, headers="firstrow") + print(table) + + os.makedirs(args.evaluation_results_save_path, exist_ok=True) + + with open(os.path.join(args.evaluation_results_save_path, "evaluation_results_table.txt"), "w") as file: + file.write(table) + + jdump(evaluation_results, os.path.join(args.evaluation_results_save_path, "evaluation_results.json")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ColossalEval evaluation process.") + parser.add_argument("--config", type=str, default=None, required=True, help="path to config file") + parser.add_argument("--inference_results_path", type=str, default=None, help="path to inference results") + parser.add_argument( + "--evaluation_results_save_path", type=str, default=None, help="path to save evaluation results" + ) + args = parser.parse_args() + + main(args) diff --git a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh new file mode 100644 index 0000000000000000000000000000000000000000..ad0bfc03acbbb0455b11ca62d216fcc9b3f75594 --- /dev/null +++ b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh @@ -0,0 +1,4 @@ +python eval_dataset.py \ + --config "path to config file" \ + --inference_results_path "path to inference results" \ + --evaluation_results_save_path "path to save evaluation results" diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..657fc33bf1ef44e54a49a43f1159b177549bf0a6 --- /dev/null +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -0,0 +1,171 @@ +import argparse +import copy +import os +from typing import Dict, List + +import torch +import torch.distributed as dist +from colossal_eval import dataset, models, utils + +import colossalai +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + + +def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None: + """ + Remove inference result per rank and merge them into one file. + + Args: + world_size: Number of processes for inference. + save_path: The folder for storing inference results. + model_names: Names of models for inference. + dataset_names: Names of dataset for inference. + + """ + + for model_name in model_names: + for dataset_name, categories in dataset_names.items(): + all_answers = {} + for category in categories: + all_answers[category] = {"data": []} + answers = {"data": []} + + for r in range(world_size): + directory = os.path.join( + save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json" + ) + if not os.path.exists(directory): + raise Exception( + f"Directory {directory} not found. There may be an error during inference time." + ) + else: + rank_answers = utils.jload(directory) + answers["data"].extend(rank_answers["data"]) + answers["inference_kwargs"] = rank_answers["inference_kwargs"] + + for r in range(world_size): + try: + directory = os.path.join( + save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json" + ) + os.remove(directory) + except Exception as e: + print(e) + + all_answers[category] = answers + + logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.") + utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json")) + + logger.info(f"Save inference results of model {model_name} for all dataset.") + logger.info(f"Save inference results of all models for all dataset.") + + +def main(args): + colossalai.launch_from_torch(config={}, seed=42) + world_size = dist.get_world_size() + rank = dist.get_rank() + + inference_data = {} + debug_args = {} + few_shot_args = {} + + config = utils.jload(args.config) + + model_parameters = config["model"] + dataset_parameters = config["dataset"] + + for dataset_parameter in dataset_parameters: + path = dataset_parameter["path"] + save_path = dataset_parameter["save_path"] + dataset_name = dataset_parameter["name"] + debug_args[dataset_name] = dataset_parameter["debug"] + few_shot_args[dataset_name] = dataset_parameter["few_shot"] + + if not args.load_dataset: + if os.path.exists(save_path): + dataset_ = utils.jload(save_path) + inference_data[dataset_name] = dataset_["test"] + else: + raise Exception( + "Can't find the converted dataset. You may set load_dataset True to store the dataset first." + ) + + continue + + dataset_class = eval(f"dataset.{dataset_parameter['dataset_class']}") + if not issubclass(dataset_class, dataset.BaseDataset): + raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.") + + dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"]) + + dataset_.save(save_path) + inference_data[dataset_name] = dataset_.dataset["test"] + + for model_parameter in model_parameters: + model_name = model_parameter["name"] + model_class = eval(f"models.{model_parameter['model_class']}") + paramerters = model_parameter["parameters"] + paramerters.update({"logger": logger}) + paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]}) + + model_ = model_class(**paramerters) + if not issubclass(model_class, models.BaseModel): + raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.") + + for dataset_name, split_data in inference_data.items(): + start = 0 + for category, category_data in split_data.items(): + if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None: + raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!") + + answers_to_dump = copy.deepcopy(category_data) + partition_size = len(category_data["data"]) // world_size + redundant = len(category_data["data"]) % world_size + + # Ensure that the amount of data for inference is as consistent as possible across different processes. + lengths = [partition_size for _ in range(world_size)] + for j in range(redundant): + lengths[(j + start) % world_size] += 1 + + start = (start + redundant) % world_size + + questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]] + + answers_per_rank = model_.inference( + questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name] + ) + + answers_to_dump["data"] = answers_per_rank + + utils.jdump( + answers_to_dump, + os.path.join( + args.inference_save_path, + model_name, + f"{dataset_name}_{category}_inference_results_rank{rank}.json", + ), + ) + + logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB") + + del model_ + torch.cuda.empty_cache() + + dist.barrier() + if rank == 0: + model_names = [model_parameter["name"] for model_parameter in model_parameters] + dataset_names = {key: list(inference_data[key].keys()) for key in inference_data} + rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ColossalEval inference process.") + parser.add_argument("--config", type=str, default=None, required=True, help="path to config file") + parser.add_argument("--load_dataset", default=False, action="store_true") + parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results") + args = parser.parse_args() + + main(args) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.sh b/applications/ColossalEval/examples/dataset_evaluation/inference.sh new file mode 100644 index 0000000000000000000000000000000000000000..15f9afd560454a61ae719b1ac5cf7c3882ed3a99 --- /dev/null +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.sh @@ -0,0 +1,4 @@ +torchrun --nproc_per_node=1 inference.py \ + --config "path to config file" \ + --load_dataset \ + --inference_save_path "path to save inference results" diff --git a/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json b/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json new file mode 100644 index 0000000000000000000000000000000000000000..6ebe3996b1cf72fe75f67c570fad2c857e583158 --- /dev/null +++ b/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json @@ -0,0 +1,44 @@ +{ + "language": "en", + "category": { + "brainstorming": { + "GPT": [ + "language organization", + "relevance", + "creativity", + "practicality", + "reasonableness" + ] + }, + "chat": { + "GPT": [ + "language organization", + "naturalness", + "engagingness", + "fidelity" + ] + }, + "generation": { + "GPT": [ + "language organization", + "relevance", + "diversity" + ] + }, + "open_qa": { + "GPT": [ + "language organization", + "relevance", + "correctness" + ] + }, + "roleplay": { + "GPT": [ + "language organization", + "relevance", + "fidelity", + "creativity" + ] + } + } +} diff --git a/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json b/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json new file mode 100644 index 0000000000000000000000000000000000000000..7ed7491a87c5d7c619987cb09d92faad7a18074c --- /dev/null +++ b/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json @@ -0,0 +1,33 @@ +{ + "model": [ + { + "name": "model name", + "model_class": "HuggingFaceCausalLM", + "parameters": { + "path": "path to model", + "model_max_length": 4096, + "tokenizer_path": "", + "tokenizer_kwargs": { + "trust_remote_code": true + }, + "peft_path": null, + "model_kwargs": { + "torch_dtype": "torch.float32", + "trust_remote_code": true + }, + "prompt_template": "plain", + "batch_size": 4 + } + } + ], + "dataset": [ + { + "name": "colossal", + "dataset_class": "ColossalDataset", + "debug": false, + "few_shot": false, + "path": "../../configs/gpt_evaluation/data/eval_en_examples.json", + "save_path": "path to save converted dataset (inference_data/colossal.json)" + } + ] +} diff --git a/applications/ColossalEval/examples/gpt_evaluation/eval.py b/applications/ColossalEval/examples/gpt_evaluation/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..cd521af59823b2564ee04ffc8ab7329f55cb5099 --- /dev/null +++ b/applications/ColossalEval/examples/gpt_evaluation/eval.py @@ -0,0 +1,139 @@ +import argparse +import os + +import openai +from colossal_eval.evaluate.evaluator import Evaluator +from colossal_eval.utils import jload + + +def main(args): + assert len(args.answer_file_list) == len( + args.model_name_list + ), "The number of answer files and model names should be equal!" + + # load config + config = jload(args.config_file) + + if config["language"] in ["cn", "en"]: + # get metric settings for all categories + metrics_per_category = {} + for category in config["category"].keys(): + metrics_all = {} + for metric_type, metrics in config["category"][category].items(): + metrics_all[metric_type] = metrics + metrics_per_category[category] = metrics_all + + battle_prompt = None + if args.battle_prompt_file: + battle_prompt = jload(args.battle_prompt_file) + + gpt_evaluation_prompt = None + if args.gpt_evaluation_prompt_file: + gpt_evaluation_prompt = jload(args.gpt_evaluation_prompt_file) + + if len(args.model_name_list) == 2 and not battle_prompt: + raise Exception("No prompt file for battle provided. Please specify the prompt file for battle!") + + if len(args.model_name_list) == 1 and not gpt_evaluation_prompt: + raise Exception( + "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!" + ) + + if args.gpt_model == "text-davinci-003" and args.gpt_with_reference: + raise Exception( + "GPT evaluation with reference is not supported for text-davinci-003. You should specify chat models such as gpt-3.5-turbo or gpt-4." + ) + + # initialize evaluator + evaluator = Evaluator( + metrics_per_category, + battle_prompt, + gpt_evaluation_prompt, + args.gpt_model, + config["language"], + args.gpt_with_reference, + ) + if len(args.model_name_list) == 2: + answers_1 = jload(args.answer_file_list[0]) + answers_2 = jload(args.answer_file_list[1]) + + answers1 = [] + for category, value in answers_1.items(): + answers1.extend(value["data"]) + + answers2 = [] + for category, value in answers_2.items(): + answers2.extend(value["data"]) + + assert len(answers1) == len(answers2), "The number of answers for two models should be equal!" + + evaluator.battle(answers1=answers1, answers2=answers2) + evaluator.save(args.save_path, args.model_name_list) + elif len(args.model_name_list) == 1: + targets = jload(args.target_file) + answers = jload(args.answer_file_list[0]) + + references = [] + for category, value in targets["test"].items(): + references.extend(value["data"]) + + predictions = [] + for category, value in answers.items(): + predictions.extend(value["data"]) + + assert len(references) == len( + predictions + ), "The number of target answers and model answers should be equal!" + + evaluator.evaluate( + answers=predictions, targets=references, save_path=args.save_path, model_name=args.model_name_list[0] + ) + evaluator.save(args.save_path, args.model_name_list) + else: + raise ValueError("Unsupported number of answer files and model names!") + else: + raise ValueError(f'Unsupported language {config["language"]}!') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ColossalAI LLM evaluation pipeline.") + parser.add_argument( + "--config_file", type=str, default=None, required=True, help="path to the file of target results" + ) + parser.add_argument("--battle_prompt_file", type=str, default=None, help="path to the prompt file for battle") + parser.add_argument( + "--gpt_evaluation_prompt_file", type=str, default=None, help="path to the prompt file for gpt evaluation" + ) + parser.add_argument("--target_file", type=str, default=None, help="path to the target answer (ground truth) file") + parser.add_argument( + "--answer_file_list", + type=str, + nargs="+", + default=[], + required=True, + help="path to the answer files of at most 2 models", + ) + parser.add_argument( + "--model_name_list", type=str, nargs="+", default=[], required=True, help="the names of at most 2 models" + ) + parser.add_argument( + "--gpt_model", + default="gpt-3.5-turbo-16k", + choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"], + help="which GPT model to use for evaluation", + ) + parser.add_argument( + "--gpt_with_reference", + default=False, + action="store_true", + help="whether to include reference answer in gpt evaluation", + ) + parser.add_argument("--save_path", type=str, default="results", help="path to save evaluation results") + parser.add_argument("--openai_key", type=str, default=None, required=True, help="Your openai key") + args = parser.parse_args() + + if args.openai_key is not None: + os.environ["OPENAI_API_KEY"] = args.openai_key + openai.api_key = os.getenv("OPENAI_API_KEY") + + main(args) diff --git a/applications/ColossalEval/examples/gpt_evaluation/eval.sh b/applications/ColossalEval/examples/gpt_evaluation/eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..f5729e6ee5c7249aa9af842c171008ee1e35ace2 --- /dev/null +++ b/applications/ColossalEval/examples/gpt_evaluation/eval.sh @@ -0,0 +1,9 @@ +python eval.py \ + --config_file "path to the config file" \ + --battle_prompt_file "path to the prompt file for battle" \ + --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \ + --target_file "path to the target answer file" \ + --answer_file_list "path to the answer files of at most 2 models" \ + --model_name_list "the names of at most 2 models" \ + --save_path "path to save results" \ + --openai_key "your openai key" \ diff --git a/applications/ColossalEval/examples/gpt_evaluation/inference.py b/applications/ColossalEval/examples/gpt_evaluation/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..657fc33bf1ef44e54a49a43f1159b177549bf0a6 --- /dev/null +++ b/applications/ColossalEval/examples/gpt_evaluation/inference.py @@ -0,0 +1,171 @@ +import argparse +import copy +import os +from typing import Dict, List + +import torch +import torch.distributed as dist +from colossal_eval import dataset, models, utils + +import colossalai +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + + +def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None: + """ + Remove inference result per rank and merge them into one file. + + Args: + world_size: Number of processes for inference. + save_path: The folder for storing inference results. + model_names: Names of models for inference. + dataset_names: Names of dataset for inference. + + """ + + for model_name in model_names: + for dataset_name, categories in dataset_names.items(): + all_answers = {} + for category in categories: + all_answers[category] = {"data": []} + answers = {"data": []} + + for r in range(world_size): + directory = os.path.join( + save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json" + ) + if not os.path.exists(directory): + raise Exception( + f"Directory {directory} not found. There may be an error during inference time." + ) + else: + rank_answers = utils.jload(directory) + answers["data"].extend(rank_answers["data"]) + answers["inference_kwargs"] = rank_answers["inference_kwargs"] + + for r in range(world_size): + try: + directory = os.path.join( + save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json" + ) + os.remove(directory) + except Exception as e: + print(e) + + all_answers[category] = answers + + logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.") + utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json")) + + logger.info(f"Save inference results of model {model_name} for all dataset.") + logger.info(f"Save inference results of all models for all dataset.") + + +def main(args): + colossalai.launch_from_torch(config={}, seed=42) + world_size = dist.get_world_size() + rank = dist.get_rank() + + inference_data = {} + debug_args = {} + few_shot_args = {} + + config = utils.jload(args.config) + + model_parameters = config["model"] + dataset_parameters = config["dataset"] + + for dataset_parameter in dataset_parameters: + path = dataset_parameter["path"] + save_path = dataset_parameter["save_path"] + dataset_name = dataset_parameter["name"] + debug_args[dataset_name] = dataset_parameter["debug"] + few_shot_args[dataset_name] = dataset_parameter["few_shot"] + + if not args.load_dataset: + if os.path.exists(save_path): + dataset_ = utils.jload(save_path) + inference_data[dataset_name] = dataset_["test"] + else: + raise Exception( + "Can't find the converted dataset. You may set load_dataset True to store the dataset first." + ) + + continue + + dataset_class = eval(f"dataset.{dataset_parameter['dataset_class']}") + if not issubclass(dataset_class, dataset.BaseDataset): + raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.") + + dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"]) + + dataset_.save(save_path) + inference_data[dataset_name] = dataset_.dataset["test"] + + for model_parameter in model_parameters: + model_name = model_parameter["name"] + model_class = eval(f"models.{model_parameter['model_class']}") + paramerters = model_parameter["parameters"] + paramerters.update({"logger": logger}) + paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]}) + + model_ = model_class(**paramerters) + if not issubclass(model_class, models.BaseModel): + raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.") + + for dataset_name, split_data in inference_data.items(): + start = 0 + for category, category_data in split_data.items(): + if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None: + raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!") + + answers_to_dump = copy.deepcopy(category_data) + partition_size = len(category_data["data"]) // world_size + redundant = len(category_data["data"]) % world_size + + # Ensure that the amount of data for inference is as consistent as possible across different processes. + lengths = [partition_size for _ in range(world_size)] + for j in range(redundant): + lengths[(j + start) % world_size] += 1 + + start = (start + redundant) % world_size + + questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]] + + answers_per_rank = model_.inference( + questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name] + ) + + answers_to_dump["data"] = answers_per_rank + + utils.jdump( + answers_to_dump, + os.path.join( + args.inference_save_path, + model_name, + f"{dataset_name}_{category}_inference_results_rank{rank}.json", + ), + ) + + logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB") + + del model_ + torch.cuda.empty_cache() + + dist.barrier() + if rank == 0: + model_names = [model_parameter["name"] for model_parameter in model_parameters] + dataset_names = {key: list(inference_data[key].keys()) for key in inference_data} + rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ColossalEval inference process.") + parser.add_argument("--config", type=str, default=None, required=True, help="path to config file") + parser.add_argument("--load_dataset", default=False, action="store_true") + parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results") + args = parser.parse_args() + + main(args) diff --git a/applications/ColossalEval/examples/gpt_evaluation/inference.sh b/applications/ColossalEval/examples/gpt_evaluation/inference.sh new file mode 100644 index 0000000000000000000000000000000000000000..15f9afd560454a61ae719b1ac5cf7c3882ed3a99 --- /dev/null +++ b/applications/ColossalEval/examples/gpt_evaluation/inference.sh @@ -0,0 +1,4 @@ +torchrun --nproc_per_node=1 inference.py \ + --config "path to config file" \ + --load_dataset \ + --inference_save_path "path to save inference results" diff --git a/applications/ColossalEval/requirements.txt b/applications/ColossalEval/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c110606e0303bd7187d2e5a9fea90e4386a9bc2e --- /dev/null +++ b/applications/ColossalEval/requirements.txt @@ -0,0 +1,12 @@ +transformers>=4.32.0 +colossalai>=0.3.1 +peft +tabulate +jieba +fuzzywuzzy +rouge +openai +matplotlib +pandas +seaborn +scikit-learn diff --git a/applications/ColossalEval/setup.py b/applications/ColossalEval/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7b1bb5c42e75e2e8bdc736ee0c2da69a372f80 --- /dev/null +++ b/applications/ColossalEval/setup.py @@ -0,0 +1,31 @@ +from setuptools import find_packages, setup + + +def fetch_requirements(path): + with open(path, "r") as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme(): + with open("README.md", encoding="utf-8") as f: + return f.read() + + +setup( + name="colossal_eval", + version="0.0.1", + packages=find_packages(exclude=["examples", "*.egg-info"]), + description="Colossal-AI LLM-Evaluation Framework", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://github.com/hpcaitech/LLM-Evaluation", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], +) diff --git a/applications/README.md b/applications/README.md index cd0435aae199aade3981c1af6502cf42d0d8578e..f5078e06a73b41300dc1b856711129f885f65883 100644 --- a/applications/README.md +++ b/applications/README.md @@ -4,8 +4,10 @@ This directory contains the applications that are powered by Colossal-AI. The list of applications include: -- [X] [Chatbot](./Chat/README.md) -- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters +- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2. +- [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs. +- [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF. +- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters. > Please note that the `Chatbot` application is migrated from the original `ChatGPT` folder. diff --git a/colossalai/__init__.py b/colossalai/__init__.py index f859161f78108e4c8b5cfba89e12026fee28da43..7da55590305b252ecc568bb91656f6d3ab38cdc5 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -1,11 +1,4 @@ -from .initialize import ( - get_default_parser, - initialize, - launch, - launch_from_openmpi, - launch_from_slurm, - launch_from_torch, -) +from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch try: # .version will be created by setup.py @@ -13,5 +6,7 @@ try: except ModuleNotFoundError: # this will only happen if the user did not run `pip install` # and directly set PYTHONPATH to use Colossal-AI which is a bad practice - __version__ = '0.0.0' - print('please install Colossal-AI from https://www.colossalai.org/download or from source') + __version__ = "0.0.0" + print("please install Colossal-AI from https://www.colossalai.org/download or from source") + +__all__ = ["launch", "launch_from_openmpi", "launch_from_slurm", "launch_from_torch", "__version__"] diff --git a/tests/test_layers/test_2d/checks_2d/__init__.py b/colossalai/_analyzer/__init__.py similarity index 100% rename from tests/test_layers/test_2d/checks_2d/__init__.py rename to colossalai/_analyzer/__init__.py diff --git a/colossalai/_analyzer/_subclasses/_meta_registration.py b/colossalai/_analyzer/_subclasses/_meta_registration.py index 4049be79c70fc1d9c33807d74c2a00fe17c05acd..e8ba88b0406dd8a6ddd9dc00398ccc2b61390d6c 100644 --- a/colossalai/_analyzer/_subclasses/_meta_registration.py +++ b/colossalai/_analyzer/_subclasses/_meta_registration.py @@ -3,7 +3,7 @@ # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml # for more meta_registrations -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from packaging import version @@ -24,25 +24,23 @@ orig_empty_like = torch.empty_like def new(*args, **kwargs): - return orig_empty(*args, **kwargs, device=torch.device('meta')) + return orig_empty(*args, **kwargs, device=torch.device("meta")) def new_strided(*args, **kwargs): - return orig_empty_strided(*args, **kwargs, device=torch.device('meta')) + return orig_empty_strided(*args, **kwargs, device=torch.device("meta")) def new_like(*args, **kwargs): - return orig_empty_like(*args, **kwargs, device=torch.device('meta')) + return orig_empty_like(*args, **kwargs, device=torch.device("meta")) def register_meta(op, register_dispatcher=True): - def wrapper(f): - def add_func(op): meta_table[op] = f if register_dispatcher: - name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) + name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__ try: meta_lib.impl(name, f) except: @@ -54,7 +52,7 @@ def register_meta(op, register_dispatcher=True): return wrapper -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): # ============================== Convolutions ====================================== # https://github.com/pytorch/pytorch/pull/79834 @register_meta(aten.convolution.default) @@ -69,7 +67,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): output_padding: List[int], groups: int, ): - def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ Formula to apply to calculate the length of some dimension of the output @@ -146,7 +143,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): kernel_size[i], stride[i], output_padding_list[i], - )) + ) + ) else: ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) return ret_shape @@ -180,19 +178,39 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) mem_fmt = pick_memory_format() - out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] return out @register_meta(aten._convolution.default) - def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], - padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, - *extra_args): + def meta__conv( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, + *extra_args, + ): out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) return out @register_meta(aten.convolution_backward.default) - def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, - padding, dilation, transposed, output_padding, groups, output_mask): + def meta_conv_backward( + grad_output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, + ): return new_like(input), new_like(weight), new((bias_sizes)) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -224,7 +242,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): batch_sizes, dropout_state, ): - is_input_packed = len(batch_sizes) != 0 if is_input_packed: seq_length = len(batch_sizes) @@ -240,8 +257,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): if is_input_packed: out_shape = [batch_sizes_sum, out_size * num_directions] else: - out_shape = ([mini_batch, seq_length, out_size * - num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) + out_shape = ( + [mini_batch, seq_length, out_size * num_directions] + if batch_first + else [seq_length, mini_batch, out_size * num_directions] + ) output = input.new_empty(out_shape) cell_shape = [num_layers * num_directions, mini_batch, hidden_size] @@ -257,15 +277,21 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp @register_meta(aten._cudnn_rnn_backward.default) - def meta_cudnn_rnn_backward(input: torch.Tensor, - weight: torch.Tensor, - weight_stride0: int, - hx: torch.Tensor, - cx: Optional[torch.Tensor] = None, - *args, - **kwargs): - return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new( - ()) # (grad_input, grad_weight, grad_hx, grad_cx) + def meta_cudnn_rnn_backward( + input: torch.Tensor, + weight: torch.Tensor, + weight_stride0: int, + hx: torch.Tensor, + cx: Optional[torch.Tensor] = None, + *args, + **kwargs, + ): + return ( + new_like(input), + new_like(weight), + new_like(hx), + new_like(cx) if cx is not None else new(()), + ) # (grad_input, grad_weight, grad_hx, grad_cx) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp # ============================== Activations ======================================= @@ -278,7 +304,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): aten.hardtanh_backward.default, ] - if version.parse(torch.__version__) < version.parse('2.0.0'): + if version.parse(torch.__version__) < version.parse("2.0.0"): _unregistered_ewise += [ aten.prelu_backward.default, ] @@ -296,37 +322,61 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.native_batch_norm_backward.default) - def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, - save_mean, save_invstd, train, eps, output_mask): - return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) + def meta_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, + ): + return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.cudnn_batch_norm.default) def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): n_input = input.size(1) - return new_like(input), new((n_input)), new((n_input)), new( - (0), dtype=torch.uint8) # (output, running_mean, running_var, reserve) + return ( + new_like(input), + new((n_input)), + new((n_input)), + new((0), dtype=torch.uint8), + ) # (output, running_mean, running_var, reserve) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp # NB: CuDNN only implements the backward algorithm for batchnorm # in training mode (evaluation mode batchnorm has a different algorithm), # which is why this doesn't accept a 'training' parameter. @register_meta(aten.cudnn_batch_norm_backward.default) - def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, - save_mean, save_invstd, eps, reserve): - return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) + def meta_cudnn_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + eps, + reserve, + ): + return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm.default) def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): bs, n_input = input.size(0), input.size(1) - return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var) + return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm_backward.default) - def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, - grad_input_mask): - return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta) + def meta_ln_backward( + dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask + ): + return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta) # ================================== Misc ========================================== # Maybe incorrect @@ -355,8 +405,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp @register_meta(aten.embedding_dense_backward.default) - def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, - scale_grad_by_freq): + def meta_embedding_dense_backward( + grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq + ): return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout) # ============================== Dropout =========================================== @@ -364,14 +415,14 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): @register_meta(aten.native_dropout.default) def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False): # notice that mask is bool - return new_like(input), new_like(input, dtype=torch.bool) # (output, mask) + return new_like(input), new_like(input, dtype=torch.bool) # (output, mask) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp @register_meta(aten.native_dropout_backward.default) def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float): - return new_like(grad) # (grad_in) + return new_like(grad) # (grad_in) - if version.parse(torch.__version__) < version.parse('1.13.0'): + if version.parse(torch.__version__) < version.parse("1.13.0"): # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml @register_meta(aten.eye.m_out) def meta_eye(n: int, m: int, out: torch.Tensor): @@ -385,24 +436,28 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): result: List[Optional[torch.Tensor]] = [] for i, index in enumerate(indices): if index is not None: - assert index.dtype in [torch.long, torch.int8, torch.bool],\ - "tensors used as indices must be long, byte or bool tensors" + assert index.dtype in [ + torch.long, + torch.int8, + torch.bool, + ], "tensors used as indices must be long, byte or bool tensors" if index.dtype in [torch.int8, torch.bool]: nonzero = index.nonzero() k = len(result) assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" for j in range(index.ndim): - assert index.shape[j] == self.shape[ - k + - j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" + assert ( + index.shape[j] == self.shape[k + j] + ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" result.append(nonzero.select(1, j)) else: result.append(index) else: result.append(index) indices = result - assert len( - indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" + assert ( + len(indices) <= self.ndim + ), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" # expand_outplace import torch._refs as refs diff --git a/colossalai/_analyzer/_subclasses/_monkey_patch.py b/colossalai/_analyzer/_subclasses/_monkey_patch.py index b3ec98f0811f265e370c638369269f73e589dfd6..503981409ccaced2b6371b78c9a8e9bf77e41c06 100644 --- a/colossalai/_analyzer/_subclasses/_monkey_patch.py +++ b/colossalai/_analyzer/_subclasses/_monkey_patch.py @@ -1,5 +1,4 @@ import torch -import torch.distributed as dist from packaging import version __all__ = [ @@ -48,7 +47,7 @@ _DistCommMethod = [ "scatter", ] -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): aten = torch.ops.aten # TODO: dive deep here # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp diff --git a/colossalai/_analyzer/_subclasses/flop_tensor.py b/colossalai/_analyzer/_subclasses/flop_tensor.py index 59991dc5091254631dd4c641c528b09aaf97c285..9d52c5593bb8b03c3253e35aa8c87b33ca19ebd4 100644 --- a/colossalai/_analyzer/_subclasses/flop_tensor.py +++ b/colossalai/_analyzer/_subclasses/flop_tensor.py @@ -8,7 +8,7 @@ from contextlib import contextmanager from enum import Enum, auto from functools import partial, reduce from numbers import Number -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Union import torch from packaging import version @@ -36,15 +36,15 @@ def _format_flops(flop): B = 1e9 T = 1e12 if flop < K: - return f'{flop:.2f}' + return f"{flop:.2f}" elif flop < M: - return f'{flop / K:.2f}K' + return f"{flop / K:.2f}K" elif flop < B: - return f'{flop / M:.2f}M' + return f"{flop / M:.2f}M" elif flop < T: - return f'{flop / B:.2f}B' + return f"{flop / B:.2f}B" else: - return f'{flop / T:.2f}T' + return f"{flop / T:.2f}T" def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number: @@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: Returns: Number: The total number of floating point operations (FWD + BWD). """ - maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False) - or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_')) + maybe_inplace = ( + getattr(module, "inplace", False) + or kwargs.get("inplace", False) + or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_") + ) class DummyModule(torch.nn.Module): - def __init__(self, func): super().__init__() self.func = func @@ -74,21 +76,20 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: total_flop_count = {Phase.FWD: 0, Phase.BWD: 0} flop_counts = defaultdict(lambda: defaultdict(int)) - parents = ['Global'] + parents = ["Global"] module = module if isinstance(module, torch.nn.Module) else DummyModule(module) class FlopTensor(MetaTensor): _tensor: torch.Tensor def __repr__(self): - name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor' + name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor" if self.grad_fn: return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - # no_dispatch is only needed if you use enable_python_mode. # It prevents infinite recursion. rs = super().__torch_dispatch__(func, types, args, kwargs) @@ -115,9 +116,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: return isinstance(x, torch.Tensor) and x.is_floating_point() def create_backwards_push(name): - class PushState(torch.autograd.Function): - @staticmethod def forward(ctx, *args): args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) @@ -134,9 +133,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: return PushState.apply def create_backwards_pop(name): - class PopState(torch.autograd.Function): - @staticmethod def forward(ctx, *args): args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) @@ -147,14 +144,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: @staticmethod def backward(ctx, *grad_outs): nonlocal parents - assert (parents[-1] == name) + assert parents[-1] == name parents.pop() return grad_outs return PopState.apply def enter_module(name): - def f(module, inputs): nonlocal parents parents.append(name) @@ -165,10 +161,9 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: return f def exit_module(name): - def f(module, inputs, outputs): nonlocal parents - assert (parents[-1] == name) + assert parents[-1] == name parents.pop() outputs = normalize_tuple(outputs) return create_backwards_push(name)(*outputs) @@ -189,7 +184,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: for mod in flop_counts.keys(): print(f"Module: ", mod) for k, v in flop_counts[mod].items(): - print('\t', k, _format_flops(v)) + print("\t", k, _format_flops(v)) print() def detach_variables(r): @@ -201,7 +196,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: def wrap(r): if isinstance(r, torch.Tensor): - data_ptr_fn = getattr(r, '_tensor', r).data_ptr + data_ptr_fn = getattr(r, "_tensor", r).data_ptr r = FlopTensor(detach_variables(r)) if maybe_inplace: r = r + 0 @@ -375,8 +370,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable: # Inputs[0] contains the shape of the input. input_shape = inputs[input_arg_index].shape - has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], - 'shape') else inputs[affine_arg_index] + has_affine = ( + inputs[affine_arg_index].shape is not None + if hasattr(inputs[affine_arg_index], "shape") + else inputs[affine_arg_index] + ) assert 2 <= len(input_shape) <= 5, input_shape # 5 is just a rough estimate flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) @@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N training = inputs[-3] assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" if training: - return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore + return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore has_affine = inputs[1].shape is not None input_shape = reduce(operator.mul, inputs[0].shape) return input_shape * (2 if has_affine else 1) @@ -420,33 +418,30 @@ def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Calla def zero_flop_jit(*args): """ - Count flops for zero flop layers. + Count flops for zero flop layers. """ return 0 -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): flop_mapping = { - # gemm + # gemm aten.mm.default: matmul_flop_jit, aten.matmul.default: matmul_flop_jit, aten.addmm.default: addmm_flop_jit, aten.bmm.default: bmm_flop_jit, - - # convolution + # convolution aten.convolution.default: conv_flop_jit, aten._convolution.default: conv_flop_jit, aten.convolution_backward.default: conv_backward_flop_jit, - - # normalization + # normalization aten.native_batch_norm.default: batchnorm_flop_jit, aten.native_batch_norm_backward.default: batchnorm_flop_jit, aten.cudnn_batch_norm.default: batchnorm_flop_jit, aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), aten.native_layer_norm.default: norm_flop_counter(2, 0), aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), - - # pooling + # pooling aten.avg_pool1d.default: ewise_flop_counter(1, 0), aten.avg_pool2d.default: ewise_flop_counter(1, 0), aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1), @@ -469,7 +464,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): } ewise_flop_aten = [ - # basic op + # basic op aten.add.Tensor, aten.add_.Tensor, aten.div.Tensor, @@ -485,8 +480,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): aten.sum.default, aten.sum.dim_IntList, aten.mean.dim, - - # activation op + # activation op aten.hardswish.default, aten.hardswish_.default, aten.hardswish_backward.default, @@ -509,15 +503,12 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): aten.tanh.default, aten.tanh_backward.default, aten.threshold_backward.default, - - # dropout + # dropout aten.native_dropout.default, aten.native_dropout_backward.default, - - # distribution + # distribution aten.bernoulli_.float, - - # where + # where aten.where.self, ] for op in ewise_flop_aten: diff --git a/colossalai/_analyzer/_subclasses/meta_tensor.py b/colossalai/_analyzer/_subclasses/meta_tensor.py index 2bc212938ee08f1143b5dda354d9ec600dd64662..8be97d01343ebc0dd4c4f79c622a63a47bf8c769 100644 --- a/colossalai/_analyzer/_subclasses/meta_tensor.py +++ b/colossalai/_analyzer/_subclasses/meta_tensor.py @@ -3,12 +3,12 @@ from functools import partial import torch import torch.distributed as dist -from torch.types import _bool, _device, _dtype -from torch.utils._pytree import tree_flatten, tree_map +from torch.types import _device +from torch.utils._pytree import tree_map from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod -__all__ = ['MetaTensor', 'MetaTensorMode'] +__all__ = ["MetaTensor", "MetaTensorMode"] def register_storage(r, data_ptr_fn=None): @@ -28,8 +28,7 @@ def _normalize_tuple(x): # a hack of inplace execution in PyTorch def _assert_alias(func): - return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive - ) + return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen) # TODO: check if should be this aggressive class MetaTensor(torch.Tensor): @@ -65,14 +64,15 @@ class MetaTensor(torch.Tensor): storage_offset=elem.storage_offset(), dtype=elem.dtype, layout=elem.layout, - device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')), - requires_grad=requires_grad) # deceive the frontend for aten selections + device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")), + requires_grad=requires_grad, + ) # deceive the frontend for aten selections r._tensor = elem # ...the real tensor is held as an element on the tensor. if not r._tensor.is_meta: val = elem.data_ptr() data_ptr_fn = lambda: val - r._tensor = r._tensor.to(torch.device('meta')) + r._tensor = r._tensor.to(torch.device("meta")) # only tensor not on `meta` should be copied to `meta` register_storage(r._tensor, data_ptr_fn) @@ -81,7 +81,7 @@ class MetaTensor(torch.Tensor): return r def __repr__(self): - name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor' + name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor" if self.grad_fn: return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" @@ -97,15 +97,15 @@ class MetaTensor(torch.Tensor): x = x._tensor elif isinstance(x, torch.Tensor): device = x.device - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return x args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) - if 'device' in kwargs: - device = kwargs['device'] - kwargs['device'] = torch.device('meta') + if "device" in kwargs: + device = kwargs["device"] + kwargs["device"] = torch.device("meta") # run aten for backend=CPU but actually on backend=Meta # here we detect whether or not the execution generates a physical copy @@ -143,21 +143,21 @@ class MetaTensor(torch.Tensor): nonlocal device if isinstance(x, str) or isinstance(x, _device): device = x - return torch.device('meta') + return torch.device("meta") return x elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) return MetaTensor(elem, device=device) def cpu(self, *args, **kwargs): - if self.device.type == 'cpu': + if self.device.type == "cpu": return self.to(*args, **kwargs) - return self.to(*args, device='cpu', **kwargs) + return self.to(*args, device="cpu", **kwargs) def cuda(self, device=None, non_blocking=False): if device is not None: return self.to(device=device, non_blocking=non_blocking) - return self.to(device='cuda:0', non_blocking=non_blocking) + return self.to(device="cuda:0", non_blocking=non_blocking) def data_ptr(self): return self._tensor.data_ptr() @@ -177,19 +177,17 @@ class MetaTensorMode(object): """ def __init__(self): - self.torch_overrides = {} # override torch.xxx - self.dist_overrides = {} # override torch.distributed.xxx + self.torch_overrides = {} # override torch.xxx + self.dist_overrides = {} # override torch.distributed.xxx def __enter__(self): - def _dummy(*args, **kwargs): pass def _new(*args, orig_new=torch.empty, **kwargs): - return MetaTensor(orig_new(*args, **{ - **kwargs, 'device': 'meta' - }), - device=kwargs.get('device', torch.device('cpu'))) + return MetaTensor( + orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu")) + ) for func in _TorchOverrideableFactoryMethod: self.torch_overrides[func] = getattr(torch, func) diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py index 41d74f2e3719a0e56adc61a6c35684584a2d80c7..cd244b22cac0e143f990e2100a6e080708960b08 100644 --- a/colossalai/_analyzer/fx/codegen.py +++ b/colossalai/_analyzer/fx/codegen.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Any, Dict, List, Tuple import torch @@ -22,7 +22,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a import colossalai from colossalai.fx._compatibility import compatibility -_register_custom_builtin('colossalai', 'import colossalai', colossalai) +_register_custom_builtin("colossalai", "import colossalai", colossalai) def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str: @@ -43,17 +43,17 @@ def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True): """ Generate the checkpoint function call code text """ - outputs = ', '.join(output_vars) - inputs = ', '.join(input_vars) - return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})' + outputs = ", ".join(output_vars) + inputs = ", ".join(input_vars) + return f"{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})" def _end_of_ckpt(node: Node, ckpt_level: int) -> bool: """ Check if the node could end the ckpt region at `ckpt_level` """ - if len(node.meta['info'].activation_checkpoint) > ckpt_level: - return node.meta['info'].activation_checkpoint[ckpt_level] is not None + if len(node.meta["info"].activation_checkpoint) > ckpt_level: + return node.meta["info"].activation_checkpoint[ckpt_level] is not None return True @@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0): current_region = None for idx, node in enumerate(node_list): - if len(node.meta['info'].activation_checkpoint) > ckpt_level: - act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level] + if len(node.meta["info"].activation_checkpoint) > ckpt_level: + act_ckpt_label = node.meta["info"].activation_checkpoint[ckpt_level] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0): return ckpt_regions -def emit_ckpt_func(body, - ckpt_func, - node_list: List[Node], - emit_node_func, - delete_unused_value_func, - ckpt_level=0, - in_ckpt=False): +def emit_ckpt_func( + body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, ckpt_level=0, in_ckpt=False +): """Emit ckpt function in nested way Args: @@ -156,12 +152,12 @@ def emit_ckpt_func(body, # label given by each layer, e.g. if you are currently at level (0, 1, 1) # the label will be '0_1_1' - label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]]) + label = "_".join([str(idx) for idx in node_list[0].meta["info"].activation_checkpoint[: ckpt_level + 1]]) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") # if there is more level to fetch - if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)): + if ckpt_level + 1 < max(map(lambda node: len(node.meta["info"].activation_checkpoint), node_list)): ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1) start_idx = [item[0] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions] @@ -174,33 +170,40 @@ def emit_ckpt_func(body, break if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] - emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func, - ckpt_level + 1, True) + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func( + ckpt_func, + ckpt_func_buffer, + ckpt_node_list, + emit_node_func, + delete_unused_value_func, + ckpt_level + 1, + True, + ) node_idx += len(ckpt_node_list) else: node = node_list[node_idx] emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) node_idx += 1 - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") ckpt_func += ckpt_func_buffer # last level else: for node in node_list: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") - usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n' + usage = _gen_ckpt_usage(label, inputs, outputs, False) + "\n" if in_ckpt: - usage = ' ' + usage + usage = " " + usage body.append(usage) @@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # process ckpt_regions if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) node_idx += len(ckpt_node_list) @@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, @compatibility(is_backward_compatible=True) class ActivationCheckpointCodeGen(CodeGen): - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] @@ -251,7 +253,7 @@ class ActivationCheckpointCodeGen(CodeGen): wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] + maybe_return_annotation: List[str] = [""] def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -259,7 +261,7 @@ class ActivationCheckpointCodeGen(CodeGen): Graph, like functions or types. Returns: the global name that should be used to reference 'obj' in generated source. """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -281,16 +283,16 @@ class ActivationCheckpointCodeGen(CodeGen): def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): # This is a generic type, e.g. typing.List[torch.Tensor] origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) - if hasattr(o, '__args__'): + if hasattr(o, "__args__"): # Assign global names for each of the inner type variables. args = [type_repr(arg) for arg in o.__args__] @@ -309,19 +311,18 @@ class ActivationCheckpointCodeGen(CodeGen): return add_global(typename, o) def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, '_fields'): + if isinstance(arg, tuple) and hasattr(arg, "_fields"): qualified_name = _get_qualified_name(type(arg)) global_name = add_global(qualified_name, type(arg)) return f"{global_name}{repr(tuple(arg))}" return repr(arg) - args_s = ', '.join(_get_repr(a) for a in args) - kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) if args_s and kwargs_s: - return f'{args_s}, {kwargs_s}' + return f"{args_s}, {kwargs_s}" return args_s or kwargs_s # Run through reverse nodes and record the first instance of a use @@ -347,82 +348,94 @@ class ActivationCheckpointCodeGen(CodeGen): not used in the remainder of the code are freed and the memory usage of the code is optimal. """ - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) + body.append(f"; {to_delete_str}\n") else: - body.append('\n') + body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert callable(node.target) # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: - body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' - f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') + if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods: + body.append( + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" body.append(self.generate_output(node.args[0])) return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") # Modified for activation checkpointing ckpt_func = [] @@ -432,13 +445,13 @@ class ActivationCheckpointCodeGen(CodeGen): # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') + body.append("pass\n") if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: - wrap_stmts = '' + wrap_stmts = "" if self._body_transformer: body = self._body_transformer(body) @@ -447,11 +460,11 @@ class ActivationCheckpointCodeGen(CodeGen): add_global(name, value) prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = ''.join(ckpt_func) + prologue + prologue = "".join(ckpt_func) + prologue prologue = prologue - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) fn_code = f""" {wrap_stmts} {prologue} diff --git a/colossalai/_analyzer/fx/graph_module.py b/colossalai/_analyzer/fx/graph_module.py index 1fdedd758c01f8fef9c0ba2e529b528527d8d527..9d3999e322b90619c5d569d48b7c202c8e167296 100644 --- a/colossalai/_analyzer/fx/graph_module.py +++ b/colossalai/_analyzer/fx/graph_module.py @@ -13,6 +13,7 @@ from torch.fx.graph import PythonCode try: from torch.fx.graph import _PyTreeCodeGen + SUPPORT_PT_CODEGEN = True except ImportError: SUPPORT_PT_CODEGEN = False @@ -24,7 +25,6 @@ from torch.nn.modules.module import _addindent # This is a copy of torch.fx.graph_module._WrappedCall. # It should be removed when we stop supporting torch < 1.12.0. class _WrappedCall: - def __init__(self, cls, cls_call): self.cls = cls self.cls_call = cls_call @@ -50,12 +50,14 @@ class _WrappedCall: # constituent substrings of the error message tb_repr = traceback.format_exc() - custom_msg = ("Call using an FX-traced Module, " - f"line {err_lineno} of the traced Module's " - "generated forward function:") - before_err = "".join(all_src_lines[err_lineno - 2:err_lineno]) + custom_msg = ( + "Call using an FX-traced Module, " + f"line {err_lineno} of the traced Module's " + "generated forward function:" + ) + before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) marker = "~" * err_line_len + "~~~ <--- HERE" - err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2]) + err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) # joined message return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) @@ -65,11 +67,14 @@ class _WrappedCall: if self.cls_call is not None: return self.cls_call(obj, *args, **kwargs) else: - return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] + return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] except Exception as e: assert e.__traceback__ - topmost_framesummary: traceback.FrameSummary = \ - traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type] + topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract( + traceback.walk_tb(e.__traceback__) + )[ + -1 + ] # type: ignore[arg-type] if "eval_with_key" in topmost_framesummary.filename: print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr) raise e.with_traceback(None) @@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule): code. """ - def __init__(self, - root: Union[torch.nn.Module, Dict[str, Any]], - graph: torch.fx.Graph, - class_name: str = 'GraphModule'): + def __init__( + self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule" + ): super().__init__(root, graph, class_name) def bind(self, ckpt_def, globals): @@ -134,7 +138,7 @@ class ColoGraphModule(torch.fx.GraphModule): if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module='self') + python_code = self._graph.python_code(root_module="self") self._code = python_code.src # To split ckpt functions code and forward code @@ -157,8 +161,8 @@ class ColoGraphModule(torch.fx.GraphModule): # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. cls_call = cls.__call__ if "__call__" in vars(cls) else None - if '_wrapped_call' not in vars(cls): - cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + if "_wrapped_call" not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] def call_wrapped(self, *args, **kwargs): return self._wrapped_call(self, *args, **kwargs) @@ -182,7 +186,7 @@ class ColoGraphModule(torch.fx.GraphModule): """ folder = Path(folder) Path(folder).mkdir(exist_ok=True) - torch.save(self.state_dict(), folder / 'state_dict.pt') + torch.save(self.state_dict(), folder / "state_dict.pt") tab = " " * 4 # we add import colossalai here @@ -208,10 +212,10 @@ class {module_name}(torch.nn.Module): for module_name, module in self.named_children(): module_str = _gen_model_repr(module_name, module) if module_str is None: - module_file = folder / f'{module_name}.pt' + module_file = folder / f"{module_name}.pt" torch.save(module, module_file) blobified_modules.append(module_name) - module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") module_str = f"torch.load(r'{module_file}') # {module_repr}" model_str += f"{tab*2}self.{module_name} = {module_str}\n" @@ -228,12 +232,14 @@ class {module_name}(torch.nn.Module): model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" model_str += f"{_addindent(self.code, 4)}\n" - module_file = folder / 'module.py' + module_file = folder / "module.py" module_file.write_text(model_str) - init_file = folder / '__init__.py' - init_file.write_text('from .module import *') + init_file = folder / "__init__.py" + init_file.write_text("from .module import *") if len(blobified_modules) > 0: - warnings.warn("Was not able to save the following children modules as reprs -" - f"saved as pickled files instead: {blobified_modules}") + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) diff --git a/colossalai/_analyzer/fx/node_util.py b/colossalai/_analyzer/fx/node_util.py index fbe8400a437eef944b6e8ac0a1126774797da031..d2671787ea63899db3aa71ff10d055a665291f18 100644 --- a/colossalai/_analyzer/fx/node_util.py +++ b/colossalai/_analyzer/fx/node_util.py @@ -1,9 +1,9 @@ from dataclasses import dataclass, field -from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch -from torch.autograd.profiler_util import _format_memory, _format_time -from torch.fx import Graph, GraphModule, Node +from torch.autograd.profiler_util import _format_memory +from torch.fx import Node from colossalai._analyzer.envs import MeshConfig @@ -85,12 +85,12 @@ class MetaInfo: node: Node # directory - mod_dir: str = '' + mod_dir: str = "" # ctx[data_ptr] = Tensor # mark the storage for ctx.save_for_backward - global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared - curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node + global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared + curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node # should be updated after each graph manipulation # ============================== Update ==================================== @@ -100,7 +100,7 @@ class MetaInfo: inputs: Tuple[torch.Tensor] = () outputs: Tuple[torch.Tensor] = () - is_alias: Tuple[bool] = () # whether the output is an alias of input + is_alias: Tuple[bool] = () # whether the output is an alias of input # compute cost fwd_flop: Optional[int] = 0 @@ -112,29 +112,29 @@ class MetaInfo: # should keep the same whenever manipulated # ============================= Invariant ================================== - activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen + activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen to_offload: Optional[bool] = False - sharding_spec: str = 'RR' + sharding_spec: str = "RR" def __new__(cls, node: Node, **kwargs): orig_init = cls.__init__ # if initialized, return the existing one # should disable the __init__ function - if node.meta.get('info', None) is not None: + if node.meta.get("info", None) is not None: def _dummy(self, *args, **kwargs): - if getattr(self, '_is_init', False): + if getattr(self, "_is_init", False): self._is_init = True orig_init(self, *args, **kwargs) cls.__init__ = orig_init cls.__init__ = _dummy - return node.meta['info'] + return node.meta["info"] return super().__new__(cls) def __post_init__(self): - self.node.meta['info'] = self + self.node.meta["info"] = self @property def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH): @@ -188,24 +188,26 @@ class MetaInfo: return compute_size_in_bytes(self.inputs) def __repr__(self): - s = f'Node {self.node.name}' + s = f"Node {self.node.name}" if self.parameters: - s += f'\n\thas parameter of size {_format_memory(self.param_size)}' + s += f"\n\thas parameter of size {_format_memory(self.param_size)}" if self.buffers: - s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}' + s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}" if self.output_size: - s += f'\n\thas output activation of size {_format_memory(self.output_size)}' + s += f"\n\thas output activation of size {_format_memory(self.output_size)}" # if self.total_size: # s += f'\n\thas total activation of size {_format_memory(self.total_size)}' if self.temp_size: - s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}' + s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}" if self.backward_size: - s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}' - s += f'\n\tfwd_flop = {self.fwd_flop}'\ - f'\n\tbwd_flop = {self.bwd_flop}'\ - f'\n\tfwd_comm = {self.fwd_comm}'\ - f'\n\tbwd_comm = {self.bwd_comm}'\ - f'\n\tto_recompute = {self.to_recompute}'\ - f'\n\tto_offload = {self.to_offload}'\ - f'\n\tsharding_spec = {self.sharding_spec}' + s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}" + s += ( + f"\n\tfwd_flop = {self.fwd_flop}" + f"\n\tbwd_flop = {self.bwd_flop}" + f"\n\tfwd_comm = {self.fwd_comm}" + f"\n\tbwd_comm = {self.bwd_comm}" + f"\n\tto_recompute = {self.to_recompute}" + f"\n\tto_offload = {self.to_offload}" + f"\n\tsharding_spec = {self.sharding_spec}" + ) return s diff --git a/colossalai/_analyzer/fx/passes/graph_profile.py b/colossalai/_analyzer/fx/passes/graph_profile.py index c3e760b31e96df2cd4f2de52c6309fefa9183417..158ebce219cd62c98c9fa2b064a8489d70868c39 100644 --- a/colossalai/_analyzer/fx/passes/graph_profile.py +++ b/colossalai/_analyzer/fx/passes/graph_profile.py @@ -1,8 +1,8 @@ -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch import torch.fx -from torch.autograd.profiler_util import _format_memory, _format_time +from torch.autograd.profiler_util import _format_memory from torch.fx import GraphModule from torch.fx.node import Argument, Node, Target @@ -13,14 +13,14 @@ from colossalai._analyzer.fx.node_util import MetaInfo def _format_flops(flops: float) -> str: """Returns a formatted FLOP size string""" if flops > 1e12: - return f'{flops / 1e12:.2f} TFLOPs' + return f"{flops / 1e12:.2f} TFLOPs" elif flops > 1e9: - return f'{flops / 1e9:.2f} GFLOPs' + return f"{flops / 1e9:.2f} GFLOPs" elif flops > 1e6: - return f'{flops / 1e6:.2f} MFLOPs' + return f"{flops / 1e6:.2f} MFLOPs" elif flops > 1e3: - return f'{flops / 1e3:.2f} kFLOPs' - return f'{flops} FLOPs' + return f"{flops / 1e3:.2f} kFLOPs" + return f"{flops} FLOPs" def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]: @@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter): Fetch shape argument from ``ShapeProp`` without re-executing the ``GraphModule`` from scratch. """ + _profileable = [ - 'call_function', - 'call_module', - 'call_method', + "call_function", + "call_module", + "call_method", ] def __init__(self, module: GraphModule, garbage_collect_values: bool = True): @@ -77,14 +78,13 @@ class GraphProfiler(torch.fx.Interpreter): self.args_iter: Iterator[Any] = iter(args) for node in self.module.graph.nodes: - - self.run_node(node) # No need to store. + self.run_node(node) # No need to store. if self.garbage_collect_values: for to_delete in self.user_to_last_uses.get(node, []): del self.env[to_delete] - if node.op == 'output': + if node.op == "output": output_val = self.env[node] return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val @@ -133,9 +133,11 @@ class GraphProfiler(torch.fx.Interpreter): try: from tabulate import tabulate except ImportError: - print("`summary` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") + print( + "`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) # Build up a list of summary information for each node node_summaries: List[List[Any]] = [] @@ -145,36 +147,38 @@ class GraphProfiler(torch.fx.Interpreter): node: Node n_info = MetaInfo(node) last_n_info = last_n_info or n_info - node_summaries.append([ - node.op, - str(node), - _format_memory(n_info.accumulate_size), - _format_memory(n_info.accumulate_size - last_n_info.accumulate_size), - _format_memory(n_info.output_size), - _format_memory(n_info.temp_size), - _format_memory(n_info.param_size), - _format_memory(n_info.backward_size), - _format_flops(n_info.fwd_flop), - _format_flops(n_info.bwd_flop), - ]) + node_summaries.append( + [ + node.op, + str(node), + _format_memory(n_info.accumulate_size), + _format_memory(n_info.accumulate_size - last_n_info.accumulate_size), + _format_memory(n_info.output_size), + _format_memory(n_info.temp_size), + _format_memory(n_info.param_size), + _format_memory(n_info.backward_size), + _format_flops(n_info.fwd_flop), + _format_flops(n_info.bwd_flop), + ] + ) last_n_info = n_info # Use the ``tabulate`` library to create a well-formatted table # presenting our summary information headers: List[str] = [ - 'Op type', - 'Op', - 'Accumulate size', - 'Incremental size', - 'Output size', - 'Temp size', - 'Param size', - 'Backward size', - 'Fwd FLOPs', - 'Bwd FLOPs', + "Op type", + "Op", + "Accumulate size", + "Incremental size", + "Output size", + "Temp size", + "Param size", + "Backward size", + "Fwd FLOPs", + "Bwd FLOPs", ] - return tabulate(node_summaries, headers=headers, stralign='right') + return tabulate(node_summaries, headers=headers, stralign="right") class CommunicationProfiler(GraphProfiler): @@ -222,6 +226,7 @@ class FlopProfiler(GraphProfiler): >>> def my_fn_flop_count_impl(*args, **kwargs): >>> return 0, 0 """ + _custom_flop_count_impl = {} def run_node(self, n: torch.fx.Node) -> Any: @@ -246,11 +251,13 @@ class FlopProfiler(GraphProfiler): ( n_info.fwd_flop, n_info.bwd_flop, - ) = getattr(self, n.op)(n.target, args, kwargs) + ) = getattr( + self, n.op + )(n.target, args, kwargs) except Exception as e: raise RuntimeError( - f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. ' - f'Please refer to function\'s docstring to register the relevant profile_impl for this node!' + f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. " + f"Please refer to function's docstring to register the relevant profile_impl for this node!" ) from e # retain the autograd graph @@ -259,7 +266,7 @@ class FlopProfiler(GraphProfiler): return _denormalize_tuple(n_info.outputs) - def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_function`` node and return the profiling result. Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be @@ -283,7 +290,7 @@ class FlopProfiler(GraphProfiler): else: return flop_count(target, *args, **kwargs) - def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node and return the profiling result. @@ -301,7 +308,7 @@ class FlopProfiler(GraphProfiler): assert isinstance(target, str) return flop_count(getattr(torch.Tensor, target), *args, **kwargs) - def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_module`` node and return the profiling result. @@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule Returns: GraphModule: The same GraphModule with profiling information """ - for profiler_cls in (FlopProfiler, - # CommunicationProfiler, # TODO: add communication profiling - ): + for profiler_cls in ( + FlopProfiler, + # CommunicationProfiler, # TODO: add communication profiling + ): profiler = profiler_cls(module) profiler.propagate(*args, device=_current_device(module)) diff --git a/colossalai/_analyzer/fx/passes/shape_prop.py b/colossalai/_analyzer/fx/passes/shape_prop.py index 23e83013e02fd6a60d016f9b5c86a17d85cdaf81..8d44f1d4b59d33bb5c7bed3cb20b44448198c917 100644 --- a/colossalai/_analyzer/fx/passes/shape_prop.py +++ b/colossalai/_analyzer/fx/passes/shape_prop.py @@ -54,7 +54,7 @@ def _current_device(module): try: return next(module.parameters()).device except StopIteration: - return torch.device('cpu') + return torch.device("cpu") @compatibility(is_backward_compatible=False) @@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter): >>> # do something here >>> return torch.empty(output_shape, device=output_device) """ + _custom_dispatch_func = {} _mode = MetaTensorMode() @@ -115,15 +116,14 @@ class ShapeProp(torch.fx.Interpreter): r = getattr(self, n.op)(n.target, args, kwargs) def unwrap_fn(elem): - def _convert_meta(t: torch.Tensor): - if t.device == 'meta': + if t.device == "meta": return t else: - return t.to('meta') + return t.to("meta") if isinstance(elem, MetaTensor): - if getattr(self, '_is_param', False): + if getattr(self, "_is_param", False): return torch.nn.Parameter(_convert_meta(elem._tensor)) return _convert_meta(elem._tensor) @@ -139,21 +139,24 @@ class ShapeProp(torch.fx.Interpreter): n_info = MetaInfo(n) n_info.outputs = _normalize_tuple(r) - if n.op == 'call_module': + if n.op == "call_module": submod = self.fetch_attr(n.target) n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()}) n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()}) else: - n_info.parameters.update({ - k.name: MetaTensor(v) - for k, v in zip(n.args, args) - if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter) - }) + n_info.parameters.update( + { + k.name: MetaTensor(v) + for k, v in zip(n.args, args) + if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter) + } + ) n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)}) - n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \ - tuple(v for v in kwargs.values() if is_pure_tensor(v)) + n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple( + v for v in kwargs.values() if is_pure_tensor(v) + ) # align with SPMD if isinstance(r, (tuple, list)): @@ -168,7 +171,7 @@ class ShapeProp(torch.fx.Interpreter): n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs)) return r - def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_function`` node and return the result. If the target of ``Node`` is registered with ``@register_shape_impl``, @@ -197,7 +200,7 @@ class ShapeProp(torch.fx.Interpreter): else: return res - def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node and return the result. @@ -218,7 +221,8 @@ class ShapeProp(torch.fx.Interpreter): convert_to_parameter = False if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance( - args[0], torch.nn.parameter.Parameter): + args[0], torch.nn.parameter.Parameter + ): convert_to_parameter = True # Execute the method and return the result assert isinstance(target, str) diff --git a/colossalai/_analyzer/fx/symbolic_profile.py b/colossalai/_analyzer/fx/symbolic_profile.py index dd7f22c6c98a0d47c946935a991a1f31d1052734..5732a6665f7843a6bd21d11abc002f088a059c44 100644 --- a/colossalai/_analyzer/fx/symbolic_profile.py +++ b/colossalai/_analyzer/fx/symbolic_profile.py @@ -1,5 +1,3 @@ -import torch -import torch.fx from torch.fx import GraphModule from .passes import ShapeProp, graph_profile_pass, shape_prop_pass @@ -7,7 +5,6 @@ from .passes.graph_profile import FlopProfiler def register_flop_count_impl(func): - def wrapper(impl): FlopProfiler._custom_flop_count_impl[func] = impl return impl @@ -16,7 +13,6 @@ def register_flop_count_impl(func): def register_shape_impl(func): - def wrapper(impl): ShapeProp._custom_dispatch_func[func] = impl return impl diff --git a/colossalai/_analyzer/fx/tracer/bias_addition.py b/colossalai/_analyzer/fx/tracer/bias_addition.py index 1e75b47ca5b038aadb9c9bf0779bc3565d91bead..b8b83282b42c24443917a85cf968a22b54cf8156 100644 --- a/colossalai/_analyzer/fx/tracer/bias_addition.py +++ b/colossalai/_analyzer/fx/tracer/bias_addition.py @@ -12,7 +12,7 @@ from .tracer import register_tracer_impl __all__ = [] -@register_tracer_impl(F.linear, name='_bias_addition_impl') +@register_tracer_impl(F.linear, name="_bias_addition_impl") def linear_impl(input, weight, bias=None): if bias is None: return F.linear(input, weight) @@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None): return F.linear(input, weight) + bias -@register_tracer_impl(F.conv1d, name='_bias_addition_impl') +@register_tracer_impl(F.conv1d, name="_bias_addition_impl") def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1): if bias is None: return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( - (-1, 1)) + (-1, 1) + ) -@register_tracer_impl(F.conv2d, name='_bias_addition_impl') +@register_tracer_impl(F.conv2d, name="_bias_addition_impl") def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1): if bias is None: return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( - (-1, 1, 1)) + (-1, 1, 1) + ) -@register_tracer_impl(F.conv3d, name='_bias_addition_impl') +@register_tracer_impl(F.conv3d, name="_bias_addition_impl") def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1): if bias is None: return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( - (-1, 1, 1, 1)) - - -@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl') -def conv_transpose1d_impl(input, - weight, - bias=None, - stride=_single(1), - padding=_single(0), - output_padding=_single(0), - groups=1, - dilation=_single(1)): + (-1, 1, 1, 1) + ) + + +@register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl") +def conv_transpose1d_impl( + input, + weight, + bias=None, + stride=_single(1), + padding=_single(0), + output_padding=_single(0), + groups=1, + dilation=_single(1), +): if bias is None: - return F.conv_transpose1d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + return F.conv_transpose1d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) else: - return F.conv_transpose1d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + bias.reshape((-1, 1)) - - -@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl') -def conv_transpose2d_impl(input, - weight, - bias=None, - stride=_pair(1), - padding=_pair(0), - output_padding=_pair(0), - groups=1, - dilation=_pair(1)): + return F.conv_transpose1d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) + bias.reshape((-1, 1)) + + +@register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl") +def conv_transpose2d_impl( + input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1) +): if bias is None: - return F.conv_transpose2d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + return F.conv_transpose2d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) else: - return F.conv_transpose2d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + bias.reshape((-1, 1, 1)) - - -@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl') -def conv_transpose3d_impl(input, - weight, - bias=None, - stride=_triple(1), - padding=_triple(0), - output_padding=_triple(0), - groups=1, - dilation=_triple(1)): + return F.conv_transpose2d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) + bias.reshape((-1, 1, 1)) + + +@register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl") +def conv_transpose3d_impl( + input, + weight, + bias=None, + stride=_triple(1), + padding=_triple(0), + output_padding=_triple(0), + groups=1, + dilation=_triple(1), +): if bias is None: - return F.conv_transpose3d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + return F.conv_transpose3d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) else: - return F.conv_transpose3d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + bias.reshape((-1, 1, 1, 1)) - - -@register_tracer_impl(torch.addmm, name='_bias_addition_impl') -@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl') + return F.conv_transpose3d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) + bias.reshape((-1, 1, 1, 1)) + + +@register_tracer_impl(torch.addmm, name="_bias_addition_impl") +@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl") def addmm_impl(input, mat1, mat2, beta=1, alpha=1): if alpha != 1 and beta != 1: return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta @@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1): return F.linear(mat1, mat2.transpose(0, 1)) + input -@register_tracer_impl(torch.addbmm, name='_bias_addition_impl') -@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl') +@register_tracer_impl(torch.addbmm, name="_bias_addition_impl") +@register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl") def addbmm_impl(input, batch1, batch2, beta=1, alpha=1): if alpha != 1 and beta != 1: return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta diff --git a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py index 112c7c9637d20e395dbccaace063e3fa7657041f..ff6b55be5117ec74ac7fc73e44c38b31cdd896fa 100644 --- a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py +++ b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py @@ -4,6 +4,7 @@ from .tracer import register_leaf_module, register_leaf_module_impl try: import apex + register_leaf_module(apex.normalization.FusedLayerNorm) register_leaf_module(apex.normalization.FusedRMSNorm) register_leaf_module(apex.normalization.MixedFusedLayerNorm) diff --git a/colossalai/_analyzer/fx/tracer/proxy.py b/colossalai/_analyzer/fx/tracer/proxy.py index ce379efdcf0d7c01d5541cb9ebffac1325fa97ef..e3e210e7d1900729d79b09bd9947f73017e5fe4a 100644 --- a/colossalai/_analyzer/fx/tracer/proxy.py +++ b/colossalai/_analyzer/fx/tracer/proxy.py @@ -1,10 +1,8 @@ import operator -from typing import Any, Callable, Dict, Optional, Set, Union +from typing import Any, Callable, Dict, Optional, Union import torch -import torch.nn as nn -from torch.fx import Graph, Node, Proxy, Tracer -from torch.fx.graph import _Namespace +from torch.fx import Node, Proxy from torch.utils._pytree import tree_map from colossalai._analyzer._subclasses import MetaTensor @@ -32,7 +30,7 @@ class ColoProxy(Proxy): def __torch_function__(cls, orig_method, types, args=(), kwargs=None): kwargs = {} if kwargs is None else kwargs if orig_method in cls._func_dispatch: - impl = cls._func_dispatch.pop(orig_method) # avoid recursion + impl = cls._func_dispatch.pop(orig_method) # avoid recursion proxy = impl(*args, **kwargs) cls._func_dispatch[orig_method] = impl return proxy @@ -72,7 +70,7 @@ class ColoProxy(Proxy): return ColoAttribute(self, k, getattr(self._meta_data, k, None)) def __setitem__(self, key, value): - proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) + proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {}) proxy.meta_data = self._meta_data return proxy @@ -89,7 +87,6 @@ class ColoProxy(Proxy): class ColoAttribute(ColoProxy): - def __init__(self, root, attr: str, data=None): self.root = root self.attr = attr @@ -102,11 +99,11 @@ class ColoAttribute(ColoProxy): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node return self._node def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs) def __repr__(self): return f"ColoAttribute({self.node.name}, attr={self.attr})" diff --git a/colossalai/_analyzer/fx/tracer/symbolic_trace.py b/colossalai/_analyzer/fx/tracer/symbolic_trace.py index 2018863f6f5f50de7cc61eafb907a55034711993..7884fd911c864dcab53876950caea8024a6f2544 100644 --- a/colossalai/_analyzer/fx/tracer/symbolic_trace.py +++ b/colossalai/_analyzer/fx/tracer/symbolic_trace.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Union import torch from torch.fx import Tracer @@ -8,6 +8,7 @@ from colossalai._analyzer._subclasses import MetaTensor try: from ..codegen import ActivationCheckpointCodeGen + SUPPORT_ACTIVATION = True except: SUPPORT_ACTIVATION = False @@ -16,7 +17,7 @@ from .tracer import ColoTracer def _default_device(): - return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") def _current_device(module: torch.nn.Module): @@ -144,10 +145,9 @@ def symbolic_trace( if meta_args: device, orig_device = _default_device(), _current_device(root) wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem - graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, - bias_addition_split=bias_addition_split).trace(root.to(device), - concrete_args=concrete_args, - meta_args=tree_map(wrap_fn, meta_args)) + graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace( + root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args) + ) if trace_act_ckpt and SUPPORT_ACTIVATION: graph.set_codegen(ActivationCheckpointCodeGen()) root.to(orig_device) diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py index 6958a00a6a72af16bf6a9736a7c18411ff127b76..17dce767269d4efc435ae9dcc2e666073035f107 100644 --- a/colossalai/_analyzer/fx/tracer/tracer.py +++ b/colossalai/_analyzer/fx/tracer/tracer.py @@ -20,11 +20,10 @@ def _truncate_suffix(s: str): import re # FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name - return re.sub(r'_\d+$', '', s) + return re.sub(r"_\d+$", "", s) -def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'): - +def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"): def wrapper(impl): assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}" getattr(ColoTracer, name)[func] = impl @@ -34,7 +33,6 @@ def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custo def register_leaf_module_impl(module: nn.Module): - def wrapper(impl): ColoTracer._custom_leaf_module_impl[module] = impl return impl @@ -76,7 +74,7 @@ class ColoTracer(Tracer): self.ckpt_regions = [] self.ckpt_idx = 0 - self.mod_dir = '' + self.mod_dir = "" # whether the tracer should split the bias_add ops into two ops self.bias_addition_split = bias_addition_split @@ -87,35 +85,41 @@ class ColoTracer(Tracer): if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None: return False # user can specify which modules are leaf modules and which are not - return (type(m) not in self._custom_non_leaf_module - and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name))) + return type(m) not in self._custom_non_leaf_module and ( + type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name) + ) - def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], - kwargs: Dict[str, Any]) -> Any: + def call_module( + self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: curr_dir = self.mod_dir - self.mod_dir = 'self.' + self.path_of_module(m) + self.mod_dir = "self." + self.path_of_module(m) rst = super().call_module(m, forward, args, kwargs) self.mod_dir = curr_dir return rst - def proxy(self, node: Node) -> 'ColoProxy': + def proxy(self, node: Node) -> "ColoProxy": return ColoProxy(node, self) - def create_proxy(self, - kind: str, - target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - proxy_factory_fn: Callable[[Node], 'Proxy'] = None): - + def create_proxy( + self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], "Proxy"] = None, + ): proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p - if kind == 'placeholder': - proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( - _truncate_suffix(target), None) - elif kind == 'get_attr': + if kind == "placeholder": + proxy.meta_data = ( + self.meta_args[target] + if target in self.meta_args + else self.concrete_args.get(_truncate_suffix(target), None) + ) + elif kind == "get_attr": self.disable_module_getattr = True try: attr_itr = self.root @@ -125,20 +129,21 @@ class ColoTracer(Tracer): proxy.meta_data = attr_itr finally: self.disable_module_getattr = False - elif kind == 'call_function': + elif kind == "call_function": proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_method': + elif kind == "call_method": self.disable_module_getattr = True try: - if target == '__call__': + if target == "__call__": proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) else: if target not in _TensorPropertyMethod: - proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), - **tree_map(unwrap_fn, kwargs)) + proxy._meta_data = getattr(unwrap_fn(args[0]), target)( + *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs) + ) finally: self.disable_module_getattr = False - elif kind == 'call_module': + elif kind == "call_module": mod = self.root.get_submodule(target) self.disable_module_getattr = True try: @@ -158,11 +163,12 @@ class ColoTracer(Tracer): n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions)) return node - def trace(self, - root: torch.nn.Module, - concrete_args: Optional[Dict[str, torch.Tensor]] = None, - meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: - + def trace( + self, + root: torch.nn.Module, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, + meta_args: Optional[Dict[str, torch.Tensor]] = None, + ) -> Graph: if meta_args is None: meta_args = {} @@ -177,9 +183,7 @@ class ColoTracer(Tracer): non_concrete_arg_names = sig_names - concrete_arg_names # update concrete args with default values for k, v in sig.parameters.items(): - if k in sig_names - meta_arg_names and \ - k not in concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default def _check_arg_name_valid(names: Iterable[str]): @@ -194,9 +198,9 @@ class ColoTracer(Tracer): self.meta_args = meta_args with self._torch_factory_override(), self._tracer_override(), torch.no_grad(): - self.mod_dir = 'self' + self.mod_dir = "self" self.graph = super().trace(root, concrete_args=concrete_args) - self.mod_dir = '' + self.mod_dir = "" self.graph.lint() for node in self.graph.nodes: @@ -266,17 +270,17 @@ class ColoTracer(Tracer): # override the torch factory functions to create a proxy when the method # is called during ``symbolic_trace()``. def wrap_factory_method(target): - @functools.wraps(target) def wrapper(*args, **kwargs): is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( - isinstance(p, ColoProxy) for p in kwargs.values()) + isinstance(p, ColoProxy) for p in kwargs.values() + ) if is_proxy: # if the arg is a proxy, then need to record this function called on this proxy # e.g. torch.ones(size) where size is an input proxy self.disable_module_getattr = True try: - proxy = self.create_proxy('call_function', target, args, kwargs) + proxy = self.create_proxy("call_function", target, args, kwargs) finally: self.disable_module_getattr = False return proxy @@ -341,10 +345,13 @@ class ColoTracer(Tracer): if attr_val is p: if n not in parameter_proxy_cache: kwargs = {} - if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: - kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else - lambda node: ColoProxy(self, node, n, attr_val)) - val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] + if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ColoProxy(self, node, n, attr_val) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] parameter_proxy_cache[n] = val_proxy return parameter_proxy_cache[n] return None @@ -355,8 +362,9 @@ class ColoTracer(Tracer): return maybe_buffer_proxy if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), - parameter_proxy_cache) + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) if maybe_parameter_proxy is not None: return maybe_parameter_proxy diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py index 963215476b6b038b2aa33c124461387e47579d3c..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/colossalai/amp/__init__.py +++ b/colossalai/amp/__init__.py @@ -1,54 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch.nn as nn -from torch.nn.modules.loss import _Loss -from torch.optim import Optimizer - -from colossalai.context import Config - -from .amp_type import AMP_TYPE -from .apex_amp import convert_to_apex_amp -from .naive_amp import convert_to_naive_amp -from .torch_amp import convert_to_torch_amp - -__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE'] - - -def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None): - """A helper function to wrap training components with Torch AMP modules. - - Args: - param model (:class:`torch.nn.Module`): your model object. - optimizer (:class:`torch.optim.Optimizer`): your optimizer object. - criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object. - mode (:class:`colossalai.amp.AMP_TYPE`): amp mode. - amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes. - - Returns: - A tuple (model, optimizer, criterion). - - Note: - ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode - for more details about ``amp_config``. - For ``apex_amp``, please check - `apex_amp config `_. - For ``naive_amp``, please check - `naive_amp config `_. - For ``torch_amp``, please check - `torch_amp config `_. - """ - assert isinstance(mode, AMP_TYPE), \ - f'expected the argument mode be AMP_TYPE, but got {type(mode)}' - - if amp_config is None: - amp_config = Config() - - if mode == AMP_TYPE.TORCH: - model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config) - elif mode == AMP_TYPE.APEX: - model, optimizer = convert_to_apex_amp(model, optimizer, amp_config) - elif mode == AMP_TYPE.NAIVE: - model, optimizer = convert_to_naive_amp(model, optimizer, amp_config) - - return model, optimizer, criterion diff --git a/colossalai/amp/amp_type.py b/colossalai/amp/amp_type.py deleted file mode 100644 index 6f322f866cfc813e66e54b0c1006d62ef949e96e..0000000000000000000000000000000000000000 --- a/colossalai/amp/amp_type.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from enum import Enum - - -class AMP_TYPE(Enum): - APEX = 'apex' - TORCH = 'torch' - NAIVE = 'naive' diff --git a/colossalai/amp/apex_amp/__init__.py b/colossalai/amp/apex_amp/__init__.py deleted file mode 100644 index 51b9b97dccce877783251fb3f61f08a87a6a7659..0000000000000000000000000000000000000000 --- a/colossalai/amp/apex_amp/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch.nn as nn -from torch.optim import Optimizer - -from .apex_amp import ApexAMPOptimizer - - -def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config): - r"""A helper function to wrap training components with Apex AMP modules - - Args: - model (:class:`torch.nn.Module`): your model object. - optimizer (:class:`torch.optim.Optimizer`): your optimizer object. - amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for initializing apex_amp. - - Returns: - Tuple: A tuple (model, optimizer). - - The ``amp_config`` should include parameters below: - :: - - enabled (bool, optional, default=True) - opt_level (str, optional, default="O1") - cast_model_type (``torch.dtype``, optional, default=None) - patch_torch_functions (bool, optional, default=None) - keep_batchnorm_fp32 (bool or str, optional, default=None - master_weights (bool, optional, default=None) - loss_scale (float or str, optional, default=None) - cast_model_outputs (torch.dtype, optional, default=None) - num_losses (int, optional, default=1) - verbosity (int, default=1) - min_loss_scale (float, default=None) - max_loss_scale (float, default=2.**24) - - More details about ``amp_config`` refer to `amp_config `_. - """ - import apex.amp as apex_amp - model, optimizer = apex_amp.initialize(model, optimizer, **amp_config) - optimizer = ApexAMPOptimizer(optimizer) - return model, optimizer - - -__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer'] diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py index 5b2f71d3ced771c43d541843153c6b64613f69e1..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/colossalai/amp/naive_amp/__init__.py +++ b/colossalai/amp/naive_amp/__init__.py @@ -1,60 +0,0 @@ -import inspect - -import torch.nn as nn -from torch.optim import Optimizer - -from colossalai.utils import is_no_pp_or_last_stage - -from ._fp16_optimizer import FP16Optimizer -from .grad_scaler import ConstantGradScaler, DynamicGradScaler -from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer - - -def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): - """A helper function to wrap training components with naive AMP modules. In this mode, - we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss, - which is equivalent to Apex O3. - - Args: - model (:class:`torch.nn.Module`): your model object - optimizer (:class:`torch.optim.Optimizer`): your optimizer object - amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp. - - Returns: - Tuple: A tuple (model, optimizer) - - The ``amp_config`` should contain parameters below:: - - verbose (bool, optional): if set to `True`, will print debug info (Default: False). - clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0). - Note that clipping is ignored if clip_grad == 0. - dynamic_grad_scale (bool): whether to use dynamic grad scaler. - """ - if isinstance(model, nn.ModuleList): - # interleaved pipeline - module_list = [] - for chunk, m in enumerate(model): - output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1 - module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32)) - model = nn.ModuleList(module_list) - else: - output_to_fp32 = is_no_pp_or_last_stage() - model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) - - use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True) - if use_dynamic_grad_scaler: - scaler_class = DynamicGradScaler - else: - scaler_class = ConstantGradScaler - - sig = inspect.signature(scaler_class.__init__) - kwargs = dict() - for param in sig.parameters.values(): - if param.name in amp_config: - kwargs[param.name] = amp_config.pop(param.name) - grad_scaler = scaler_class(**kwargs) - optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config) - return model, optimizer - - -__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer'] diff --git a/colossalai/amp/naive_amp/_utils.py b/colossalai/amp/naive_amp/_utils.py deleted file mode 100644 index 7633705e19fbce24faec87f9691c834279f0d8ad..0000000000000000000000000000000000000000 --- a/colossalai/amp/naive_amp/_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import List - -from torch import Tensor - - -def has_inf_or_nan(tensor): - """Check if tensor has inf or nan values. - - Args: - tensor (:class:`torch.Tensor`): a torch tensor object - - Returns: - bool: Whether the tensor has inf or nan. True for yes and False for no. - """ - try: - # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if - # Pytorch's .sum() creates a one-element tensor of the same type as tensor - # (which is true for some recent version of pytorch). - tensor_sum = float(tensor.float().sum()) - # More efficient version that can be used if .sum() returns a Python scalar - # tensor_sum = float(tensor.sum()) - except RuntimeError as instance: - # We want to check if inst is actually an overflow exception. - # RuntimeError could come from a different error. - # If so, we still want the exception to propagate. - if "value cannot be converted" not in instance.args[0]: - raise - return True - else: - if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: - return True - return False - - -def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None: - """Clear the gradient of a list of tensors, - - Note: copied from torch.optim.optimizer. - """ - for param in tensor_list: - if param.grad is not None: - if set_to_none: - param.grad = None - else: - if param.grad.grad_fn is not None: - param.grad.detach_() - else: - param.grad.requires_grad_(False) - param.grad.zero_() diff --git a/colossalai/amp/naive_amp/grad_scaler/__init__.py b/colossalai/amp/naive_amp/grad_scaler/__init__.py index dc8499d877e13f8e0eb14317d2cf4a8d54dfcb2a..34a20e8d67d6690a5a6e1e9ca945ab1d2e31adf7 100644 --- a/colossalai/amp/naive_amp/grad_scaler/__init__.py +++ b/colossalai/amp/naive_amp/grad_scaler/__init__.py @@ -2,4 +2,4 @@ from .base_grad_scaler import BaseGradScaler from .constant_grad_scaler import ConstantGradScaler from .dynamic_grad_scaler import DynamicGradScaler -__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler'] +__all__ = ["BaseGradScaler", "ConstantGradScaler", "DynamicGradScaler"] diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py index 0d84384a7f67c6a4521a86d34f71ff03b821c7be..79661a44424fb0c8386ebdd04a7051ddae1517e3 100644 --- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -9,7 +9,7 @@ from torch import Tensor from colossalai.logging import get_dist_logger -__all__ = ['BaseGradScaler'] +__all__ = ["BaseGradScaler"] class BaseGradScaler(ABC): @@ -30,24 +30,21 @@ class BaseGradScaler(ABC): @property def scale(self) -> Tensor: - """Returns the loss scale. - """ + """Returns the loss scale.""" return self._scale @property def inv_scale(self) -> Tensor: - """Returns the inverse of the loss scale. - """ + """Returns the inverse of the loss scale.""" return self._scale.double().reciprocal().float() def state_dict(self) -> Dict: - """Returns the states of the gradient scaler as a dict object. - """ + """Returns the states of the gradient scaler as a dict object.""" state_dict = dict() - state_dict['scale'] = self.scale + state_dict["scale"] = self.scale return state_dict def load_state_dict(self, state_dict: Dict) -> None: @@ -57,7 +54,7 @@ class BaseGradScaler(ABC): state_dict (dict): the states of the gradient scaler """ - self._scale = state_dict['scale'] + self._scale = state_dict["scale"] @abstractmethod def update(self, overflow: bool) -> None: @@ -67,8 +64,6 @@ class BaseGradScaler(ABC): overflow (bool): whether overflow occurs """ - pass - def log(self, message, *args, **kwargs): """Log messages. diff --git a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py index a2f518c5dd28261f98faadf9134721ef1fd67dc7..2ad8b51ac22c1aa69cc4c48996b8a1098dad5f4a 100644 --- a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py @@ -2,7 +2,7 @@ # -*- encoding: utf-8 -*- from .base_grad_scaler import BaseGradScaler -__all__ = ['ConstantGradScaler'] +__all__ = ["ConstantGradScaler"] class ConstantGradScaler(BaseGradScaler): @@ -23,4 +23,3 @@ class ConstantGradScaler(BaseGradScaler): Args: overflow (bool): whether overflow occurs """ - pass diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py index e899b9ca4c89fba16352ce736cb0abc4959e163b..65133a4b3712be6069bfeab5c35bf0a556c585be 100644 --- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -7,7 +7,7 @@ import torch from .base_grad_scaler import BaseGradScaler -__all__ = ['DynamicGradScaler'] +__all__ = ["DynamicGradScaler"] class DynamicGradScaler(BaseGradScaler): @@ -24,15 +24,17 @@ class DynamicGradScaler(BaseGradScaler): verbose (bool): whether to log messages, defaults to False """ - def __init__(self, - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - min_scale: Optional[float] = None, - max_scale: Optional[float] = None, - hysteresis: int = 2, - verbose: bool = False): + def __init__( + self, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + min_scale: Optional[float] = None, + max_scale: Optional[float] = None, + hysteresis: int = 2, + verbose: bool = False, + ): super().__init__(initial_scale, verbose) if min_scale: self._min_scale = torch.cuda.FloatTensor([min_scale]) @@ -53,18 +55,17 @@ class DynamicGradScaler(BaseGradScaler): self._sanity_checks() def _sanity_checks(self) -> None: - """Check if the arguments are correct. - """ + """Check if the arguments are correct.""" if self._min_scale: - assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative' - assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale' + assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative" + assert self._min_scale <= self._scale, "The minimum gradient scale cannot be greater than the current scale" if self._max_scale: - assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative' - assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale' - assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1' - assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1' - assert self._hysteresis >= 0, 'The hysteresis cannot be negative' + assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative" + assert self._max_scale >= self._scale, "The maximum gradient scale cannot be smaller than the current scale" + assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1" + assert 0 < self._backoff_factor < 1, "The backoff factor must be between 0 and 1" + assert self._hysteresis >= 0, "The hysteresis cannot be negative" def update(self, overflow: bool) -> None: """Update the loss scale. @@ -88,19 +89,18 @@ class DynamicGradScaler(BaseGradScaler): self.log( f"No overflow for consecutive {self._growth_interval} steps, " f"the loss scale is adjusted to {self.scale.item()}", - ranks=[0]) + ranks=[0], + ) def _backoff_scale(self) -> None: - """Decrease the loss scale - """ + """Decrease the loss scale""" self._scale = self._scale * self._backoff_factor if self._min_scale: self._scale = torch.max(self._scale, self._min_scale) def _grow_scale(self) -> None: - """Increase the loss scale - """ + """Increase the loss scale""" self._scale = self._scale * self._growth_factor if self._max_scale: @@ -108,14 +108,14 @@ class DynamicGradScaler(BaseGradScaler): def state_dict(self): state_dict = dict() - state_dict['scale'] = self._scale - state_dict['growth_factor'] = self._growth_factor - state_dict['backoff_factor'] = self._backoff_factor - state_dict['hysteresis'] = self._hysteresis + state_dict["scale"] = self._scale + state_dict["growth_factor"] = self._growth_factor + state_dict["backoff_factor"] = self._backoff_factor + state_dict["hysteresis"] = self._hysteresis return state_dict def load_state_dict(self, state_dict): - self._scale = state_dict['scale'].cuda(torch.cuda.current_device()) - self._growth_factor = state_dict['growth_factor'] - self._backoff_factor = state_dict['backoff_factor'] - self._hysteresis = state_dict['hysteresis'] + self._scale = state_dict["scale"].cuda(torch.cuda.current_device()) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._hysteresis = state_dict["hysteresis"] diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a31811e4a567db889e07830f46c01ec4b1c27a53 --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py @@ -0,0 +1,9 @@ +from .base import MixedPrecisionMixin +from .bf16 import BF16MixedPrecisionMixin +from .fp16 import FP16MixedPrecisionMixin + +__all__ = [ + "MixedPrecisionMixin", + "FP16MixedPrecisionMixin", + "BF16MixedPrecisionMixin", +] diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py new file mode 100644 index 0000000000000000000000000000000000000000..fc7e0b74179a4b18710b30a136a8ed9f957f68b1 --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py @@ -0,0 +1,86 @@ +from abc import ABC, abstractmethod + +import torch +from torch import Tensor + + +class MixedPrecisionMixin(ABC): + """A helper class for mixed precision training. This mixin is used in mixed precision optimizers. + + Attributes: + dtype (torc.dtype): The expected dtype of the gradients. + + Examples: + ```python + class MyMixedPrecisionOptimizer(OptimizerWrapper): + def __init__(self, optim: Optimizer): + super().__init__(optim) + self.mixed_precision = MixedPrecisionMixin() + + def backward(self, loss): + loss = self.mixed_precision.pre_backward(loss) + loss.backward() + + def backward_by_grad(self, tensor, grad): + grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) + tensor.backward(grad) + + def step(self): + if self.mixed_precision.should_skip_step(): + self.zero_grad() + return + div_scale = self.mixed_precision.get_grad_div_scale() + # maybe clip grad here + # maybe scale grad here + self.optim.step() + + def zero_grad(self): + self.mixed_precision.pre_zero_grad() + return self.optim.zero_grad() + ``` + """ + + dtype: torch.dtype + + @abstractmethod + def pre_backward(self, loss: Tensor) -> Tensor: + """Called before backward. + + Args: + loss (Tensor): Loss value. + + Returns: + Tensor: Loss value (possibly scaled). + """ + + @abstractmethod + def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: + """Called before backward by grad. This is helpful for pipeline parallelism. + + Args: + tensor (Tensor): Tensor to backward. + grad (Tensor): Gradient of the tensor. + + Returns: + Tensor: Gradient of the tensor (possibly scaled). + """ + + @abstractmethod + def should_skip_step(self) -> bool: + """Called before step. + + Returns: + bool: Whether to skip the step. + """ + + @abstractmethod + def pre_zero_grad(self) -> None: + """Called before zero_grad.""" + + @abstractmethod + def get_grad_div_scale(self) -> float: + """Called before step or clip_grad. To keep computation efficiency, this method does not (maybe) unscale grads. + + Returns: + float: A divisor for gradient clipping or step. + """ diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py new file mode 100644 index 0000000000000000000000000000000000000000..9454f6eb84130f083ab893baa34b9277a71dd5c0 --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py @@ -0,0 +1,23 @@ +import torch +from torch import Tensor + +from .base import MixedPrecisionMixin + + +class BF16MixedPrecisionMixin(MixedPrecisionMixin): + dtype = torch.bfloat16 + + def pre_backward(self, loss: Tensor) -> Tensor: + return loss + + def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: + return grad + + def should_skip_step(self) -> bool: + return False + + def pre_zero_grad(self) -> None: + pass + + def get_grad_div_scale(self) -> float: + return 1.0 diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py new file mode 100644 index 0000000000000000000000000000000000000000..9ce272356797db4073b980e6f7713c83a1e0c44d --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py @@ -0,0 +1,87 @@ +from abc import abstractmethod +from enum import Enum + +import torch +import torch.distributed as dist +from torch import Tensor + +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.utils import get_current_device + +from .base import MixedPrecisionMixin + + +class OptimState(Enum): + SCALED = 0 + UNSCALED = 1 + + +class FP16MixedPrecisionMixin(MixedPrecisionMixin): + dtype = torch.float16 + + def __init__( + self, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__() + self.grad_scaler = DynamicGradScaler( + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) + self.optim_state = OptimState.UNSCALED + self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device()) + + @property + def loss_scale(self) -> float: + return self.grad_scaler.scale.item() + + @abstractmethod + def check_local_overflow(self) -> bool: + """Check whether there is overflow in the local process. This method should be implemented by subclasses. + + Returns: + bool: Whether there is overflow in the local process. + """ + + def check_overflow(self) -> bool: + # clear previous overflow record + self.found_overflow.fill_(0.0) + if self.check_local_overflow(): + self.found_overflow.fill_(1.0) + dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX) + return self.found_overflow.item() > 0 + + def pre_backward(self, loss: Tensor) -> Tensor: + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + return loss + + def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: + self.optim_state = OptimState.SCALED + return grad + + def should_skip_step(self) -> bool: + found_inf = self.check_overflow() + self.grad_scaler.update(found_inf) + if found_inf: + self.optim_state = OptimState.UNSCALED + return found_inf + + def pre_zero_grad(self) -> None: + pass + + def get_grad_div_scale(self) -> float: + assert self.optim_state == OptimState.SCALED, "grads should be scaled before clipping" + self.optim_state = OptimState.UNSCALED + return self.loss_scale diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..501a843f6992c2f9d23fc4ad93e059d581910854 --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -0,0 +1,169 @@ +from typing import Dict, List + +import torch +from torch import Tensor +from torch.nn import Module, Parameter +from torch.optim import Optimizer + +from colossalai.interface import OptimizerWrapper + +from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin + + +class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): + def __init__( + self, + working_params: List[Parameter], + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__( + initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + ) + self.params = working_params + + def check_local_overflow(self) -> bool: + for p in self.params: + if p.grad is not None and not torch.isfinite(p.grad).all(): + return True + return False + + +class MixedPrecisionOptimizer(OptimizerWrapper): + def __init__( + self, + optim: Optimizer, + precision: str = "fp16", + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + ): + super().__init__(optim) + if precision == "fp16": + working_params = [] + for group in self.optim.param_groups: + for p in group["params"]: + working_params.append(p) + self.mixed_precision = NaiveFP16MixedPrecisionMixin( + working_params, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) + elif precision == "bf16": + self.mixed_precision = BF16MixedPrecisionMixin() + else: + raise ValueError(f"Unsupported precision: {precision}") + if max_norm > 0.0: + raise NotImplementedError("max_norm is not supported yet.") + self.max_norm = max_norm + self.working_to_master_map: Dict[Parameter, Tensor] = {} + self.master_to_working_map: Dict[Tensor, Parameter] = {} + + # create master weights + for group in self.optim.param_groups: + master_params = [] + for p in group["params"]: + if p.requires_grad: + master_p = p + if p.dtype != torch.float: + master_p = p.detach().float() + self.working_to_master_map[p] = master_p + self.master_to_working_map[master_p] = p + master_params.append(master_p) + group["params"] = master_params + + def backward(self, loss: Tensor, *args, **kwargs): + loss = self.mixed_precision.pre_backward(loss) + loss.backward(*args, **kwargs) + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) + tensor.backward(grad) + + def zero_grad(self, *args, **kwargs): + for p in self.working_to_master_map.keys(): + p.grad = None + self.mixed_precision.pre_zero_grad() + return super().zero_grad(*args, **kwargs) + + def _unscale_and_clip_grads(self, total_norm: float) -> None: + div_scale = 1.0 + if self.mixed_precision is not None: + div_scale = self.mixed_precision.get_grad_div_scale() + + if self.max_norm > 0.0: + # norm is in fact norm*scale + clip = ((total_norm / div_scale) + 1e-6) / self.max_norm + if clip > 1: + div_scale = clip * div_scale + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + p.grad.data.mul_(1.0 / div_scale) + + def _compute_grad_norm(self) -> float: + if self.max_norm <= 0.0: + return 0.0 + grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None] + if len(grads) == 0: + return 0.0 + device = grads[0].device + # TODO(ver217): support tp + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2) + return total_norm.item() + + def step(self, *args, **kwargs): + if self.mixed_precision.should_skip_step(): + self.zero_grad() + return + # prepare grads + for group in self.optim.param_groups: + for p in group["params"]: + working_param = self.master_to_working_map[p] + if p is working_param: + continue + if working_param.grad is not None: + p.grad = working_param.grad.data.float() + working_param.grad = None + total_norm = self._compute_grad_norm() + self._unscale_and_clip_grads(total_norm) + self.optim.step(*args, **kwargs) + # update working params + for group in self.optim.param_groups: + for p in group["params"]: + working_param = self.master_to_working_map[p] + if p is working_param: + continue + working_param.data.copy_(p.data) + + def update_master_params(self, model: Module): + # Update master params from working params + with torch.no_grad(): + for p in model.parameters(): + if (p is None) or (p not in self.working_to_master_map): + continue + master_param = self.working_to_master_map[p] + master_param.data.copy_(p.data) + + def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: + return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()} + + def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: + return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()} diff --git a/colossalai/amp/torch_amp/__init__.py b/colossalai/amp/torch_amp/__init__.py deleted file mode 100644 index 893cc890d68e423c643e6dc4bbf6343ff174a8d7..0000000000000000000000000000000000000000 --- a/colossalai/amp/torch_amp/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from torch.nn.modules.loss import _Loss -from torch.optim import Optimizer - -from colossalai.context import Config - -from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer - - -def convert_to_torch_amp(model: nn.Module, - optimizer: Optimizer, - criterion: Optional[_Loss] = None, - amp_config: Optional[Config] = None): - """A helper function to wrap training components with Pytorch AMP modules - - Args: - model (:class:`torch.nn.Module`): your model object. - optimizer (:class:`torch.optim.Optimizer`): your optimizer object - criterion (:class:`torch.nn.modules.loss._Loss`, optional): your loss function object - amp_config (:class:`colossalai.context.Config` or dict, optional): configuration for Pytorch AMP. - - The ``amp_config`` should include parameters below: - :: - - init_scale (float, optional, default=2.**16) - growth_factor (float, optional, default=2.0) - backoff_factor (float, optional, default=0.5) - growth_interval (int, optional, default=2000) - enabled (bool, optional, default=True) - - Returns: - A tuple (model, optimizer, criterion) - """ - model = TorchAMPModel(model) - if amp_config is None: - amp_config = dict() - optimizer = TorchAMPOptimizer(optimizer, **amp_config) - if criterion: - criterion = TorchAMPLoss(criterion) - return model, optimizer, criterion - - -__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer'] diff --git a/colossalai/auto_parallel/README.md b/colossalai/auto_parallel/README.md index 8e47e1bb0b4a6e8e86c1e76d600d3dae3c8be251..f011ec8ccbd7f92821c45e124b85e92eba872b16 100644 --- a/colossalai/auto_parallel/README.md +++ b/colossalai/auto_parallel/README.md @@ -16,8 +16,8 @@ A *symbolic profiler* for collecting computing and memory overhead related to st ### Solver **Solver** is designed to find the optimal execution plan for a given computation graph and cluster in two stages: -1) *Intra-op parallelism stage* is to find the plan with the minimum total execution time of all nodes with respect to the constraint of the memory budget. The optimaztion goal of intra-op parallelism solver is modified from Alpa 's intra-op parallelsim ILP solver. -2) *Activation checkpoint stage* is to search for the fastest execution plan that meets the memory budget on the computation graph after inserting the communication nodes by the intra-op parallelism stage. The algorithm to find optimial activation checkpoint is modified from Rotor . The reason we use two-stage optimization is that if the two tasks are formulated together, the solving time will be significantly increased, which will greatly affect the user experience of the system. On the contrary, solving in two hierarchical levels has many advantages. Firstly, compared with the computation graph with activation checkpointing, the original graph has fewer nodes, which can reduce the solving cost of intra-op parallelism solver. In addition, a more optimal solution can be found by adding the communication overhead into the activation checkpoint modeling. +1) *Intra-op parallelism stage* is to find the plan with the minimum total execution time of all nodes with respect to the constraint of the memory budget. The optimization goal of intra-op parallelism solver is modified from Alpa 's intra-op parallelism ILP solver. +2) *Activation checkpoint stage* is to search for the fastest execution plan that meets the memory budget on the computation graph after inserting the communication nodes by the intra-op parallelism stage. The algorithm to find optimal activation checkpoint is modified from Rotor . The reason we use two-stage optimization is that if the two tasks are formulated together, the solving time will be significantly increased, which will greatly affect the user experience of the system. On the contrary, solving in two hierarchical levels has many advantages. Firstly, compared with the computation graph with activation checkpointing, the original graph has fewer nodes, which can reduce the solving cost of intra-op parallelism solver. In addition, a more optimal solution can be found by adding the communication overhead into the activation checkpoint modeling. ### Generator **Generator** applies the searched execution plan to the computation graph and recompiles the computation graph to optimized PyTorch code. It has *a series compile pass* to insert a communication node or do the kernel substitution as the intra-op parallelism solver required. Additionally, we implement a *code generation* feature to recognize the annotation from the activation checkpoint solver and inject the activation checkpoint block following annotation instructions. diff --git a/colossalai/auto_parallel/checkpoint/build_c_ext.py b/colossalai/auto_parallel/checkpoint/build_c_ext.py index af4349865a7b8dd748e34458eb3d5aeeb359b599..7de56f80525ae347946cdbab9bee7fb2048aa249 100644 --- a/colossalai/auto_parallel/checkpoint/build_c_ext.py +++ b/colossalai/auto_parallel/checkpoint/build_c_ext.py @@ -3,14 +3,16 @@ import os from setuptools import Extension, setup this_dir = os.path.dirname(os.path.abspath(__file__)) -ext_modules = [Extension( - 'rotorc', - sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')], -)] +ext_modules = [ + Extension( + "rotorc", + sources=[os.path.join(this_dir, "ckpt_solver_rotor.c")], + ) +] setup( - name='rotor c extension', - version='0.1', - description='rotor c extension for faster dp computing', + name="rotor c extension", + version="0.1", + description="rotor c extension for faster dp computing", ext_modules=ext_modules, ) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py index b388d00ac553726f577575d5d770b98dfb873f12..8aaa690b333c6d71044b4c40726333bd453d411c 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -12,13 +12,13 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import ( ) from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen -__all___ = ['CheckpointSolverBase'] +__all___ = ["CheckpointSolverBase"] def _copy_output(src: Graph, dst: Graph): """Copy the output node from src to dst""" for n_src, n_dst in zip(src.nodes, dst.nodes): - if n_src.op == 'output': + if n_src.op == "output": n_dst.meta = n_src.meta @@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module): class CheckpointSolverBase(ABC): - def __init__( self, graph: Graph, @@ -81,13 +80,10 @@ class CheckpointSolverBase(ABC): @abstractmethod def solve(self): - """Solve the checkpointing problem and return the solution. - """ - pass + """Solve the checkpointing problem and return the solution.""" def get_node_list(self): - """Get the node list. - """ + """Get the node list.""" return [[node] for node in self.graph.nodes] def _linearize_graph(self) -> List[List[Node]]: @@ -140,8 +136,7 @@ class CheckpointSolverBase(ABC): """ def _is_inplace(n: Node): - """Get the inplace argument from ``torch.fx.Node`` - """ + """Get the inplace argument from ``torch.fx.Node``""" inplace = False if n.op == "call_function": inplace = n.kwargs.get("inplace", False) @@ -150,19 +145,22 @@ class CheckpointSolverBase(ABC): return inplace def _is_shape_consistency(n: Node): - """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``) - """ + """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)""" return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply] - return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any( - map(_is_shape_consistency, n.users)) + return ( + not sum([v for _, v in deps.items()]) + and not any(map(_is_inplace, n.users)) + and not any(map(_is_shape_consistency, n.users)) + ) # make sure that item in cnode is valid if self.cnode: for name in self.cnode: try: - assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ - f"Common node {name} is not an input of the model." + assert ( + next(node for node in self.graph.nodes if node.name == name).op == "placeholder" + ), f"Common node {name} is not an input of the model." except StopIteration: raise ValueError(f"Common node name {name} not in graph.") @@ -187,8 +185,9 @@ class CheckpointSolverBase(ABC): region = [] # propagate common node attr if possible - if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode - ]) or _is_cop(n.target): + if len(n.all_input_nodes) == len( + [node for node in n.all_input_nodes if node.name in self.cnode] + ) or _is_cop(n.target): self.cnode.append(n.name) else: deps[n] = len([user for user in n.users if user.op != "output"]) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py index 19b2ef5987c9ebc160078339b764741a71b34dbf..ab16cc04b7304a60e69a91c8b49af91be898dab9 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py @@ -8,11 +8,10 @@ from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp from .ckpt_solver_base import CheckpointSolverBase -__all__ = ['CheckpointSolverChen'] +__all__ = ["CheckpointSolverChen"] class CheckpointSolverChen(CheckpointSolverBase): - def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6): """ This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. @@ -40,14 +39,14 @@ class CheckpointSolverChen(CheckpointSolverBase): Returns: graph (Graph): The optimized graph, should be a copy of the original graph. """ - checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr'] + checkpointable_op = ["call_module", "call_method", "call_function", "get_attr"] ckpt = self.grid_search() for i, seg in enumerate(ckpt): for idx in range(*seg): nodes = self.node_list[idx] for n in nodes: if n.op in checkpointable_op: - n.meta['activation_checkpoint'] = i + n.meta["activation_checkpoint"] = i return deepcopy(self.graph) def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]: diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py index 21c3bf0da758bd061eaa9bcf08534e9a2df8d6cf..d10c41ae2b962242e09b5b52587f48bc1c80c118 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Dict, List, Tuple +from typing import Any, List, Tuple from torch import Tensor from torch.fx import Graph, Node @@ -18,17 +18,18 @@ from colossalai.logging import get_dist_logger from .ckpt_solver_base import CheckpointSolverBase from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence -__all__ = ['CheckpointSolverRotor'] +__all__ = ["CheckpointSolverRotor"] class CheckpointSolverRotor(CheckpointSolverBase): - - def __init__(self, - graph: Graph, - free_memory: float = -1, - cnode: List[str] = None, - memory_slots: int = 500, - optim_multiplier: float = 1.0): + def __init__( + self, + graph: Graph, + free_memory: float = -1, + cnode: List[str] = None, + memory_slots: int = 500, + optim_multiplier: float = 1.0, + ): """This is the simple implementation of dynamic programming algorithm rotor in https://hal.inria.fr/hal-02352969. Some code are adapted from https://gitlab.inria.fr/hiepacs/rotor. @@ -85,13 +86,14 @@ class CheckpointSolverRotor(CheckpointSolverBase): # backtrack try: - self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, - self.back_ptr) + self.sequence = self._backtrack( + chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, self.back_ptr + ) self._annotate_from_sequence(self.sequence, self.node_list) except ValueError as e: # using logger to annonce that the solver is failed logger = get_dist_logger() - logger.warning(f'Checkpoint solver failed: {e}') + logger.warning(f"Checkpoint solver failed: {e}") raise ValueError if verbose: @@ -100,14 +102,19 @@ class CheckpointSolverRotor(CheckpointSolverBase): return deepcopy(self.graph) def print_chain(self): - print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0]) + print("[input]", self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0]) for idx in range(len(self.node_list) - 1): - print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx], - self.chain.btmp[idx]) - print(f'Chain = {self.chain}') + print( + self.node_list[idx], + self.chain.x[idx + 1], + self.chain.xbar[idx + 1], + self.chain.ftmp[idx], + self.chain.btmp[idx], + ) + print(f"Chain = {self.chain}") def print_sequence(self): - print(f'Sequence = {self.sequence}') + print(f"Sequence = {self.sequence}") @classmethod def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain: @@ -138,14 +145,14 @@ class CheckpointSolverRotor(CheckpointSolverBase): btime = 0 fwd_mem_peak = 0 for n in node: - assert isinstance(n, Node), f'{n} is not a Node' + assert isinstance(n, Node), f"{n} is not a Node" if n.target == runtime_apply or n.target == runtime_comm_spec_apply: # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta - xbar += n.meta['fwd_mem_out'] - fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp']) + xbar += n.meta["fwd_mem_out"] + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"]) else: xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) - fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)) + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"] + cls._extract_unused_output(n)) # minimum flop count is required ftime += max(calculate_fwd_time(n), 1.0) @@ -162,14 +169,14 @@ class CheckpointSolverRotor(CheckpointSolverBase): """Extract input tensors from a Graph""" input_tensors = [] for node in graph.nodes: - if node.op == 'placeholder': - input_tensors.append(node.meta['fwd_out']) + if node.op == "placeholder": + input_tensors.append(node.meta["fwd_out"]) return input_tensors @staticmethod def _extract_unused_output(node: Node) -> int: """Extract unused output from `torch.fx.Node`""" - return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node) + return activation_size(node.meta["fwd_out"]) - calculate_fwd_out(node) @staticmethod def _extract_btmp(node: List[Node]) -> int: @@ -180,8 +187,8 @@ class CheckpointSolverRotor(CheckpointSolverBase): for k, v in deps.items(): k: Node if v > 0: - deps_size += k.meta['bwd_mem_out'] - if v == float('-inf'): + deps_size += k.meta["bwd_mem_out"] + if v == float("-inf"): deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k) return deps_size @@ -190,12 +197,12 @@ class CheckpointSolverRotor(CheckpointSolverBase): deps = {} for n in reversed(node): deps[n] = len(n.all_input_nodes) - btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp']) + btmp = max(btmp, _extract_deps_size() + n.meta["bwd_mem_tmp"]) for child in n.users: if child in deps: deps[child] -= 1 if deps[child] <= 0: - deps[child] = float('-inf') # free + deps[child] = float("-inf") # free return btmp @staticmethod @@ -244,10 +251,11 @@ class CheckpointSolverRotor(CheckpointSolverBase): if m < mmin: cost_table[m][i][idx] = float("inf") else: - leaf_checkpoints = [(j, - sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1]) - for j in range(i + 1, idx + 1) - if m >= x[j]] + leaf_checkpoints = [ + (j, sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1]) + for j in range(i + 1, idx + 1) + if m >= x[j] + ] if leaf_checkpoints: best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) else: @@ -274,13 +282,16 @@ class CheckpointSolverRotor(CheckpointSolverBase): import os import subprocess import sys + logger = get_dist_logger() logger.info("rotorc hasn't been built! Building library...", ranks=[0]) this_dir = os.path.dirname(os.path.abspath(__file__)) result = subprocess.Popen( [ - f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext", - f"--build-lib={this_dir}" + f"{sys.executable}", + f"{os.path.join(this_dir, 'build_c_ext.py')}", + "build_ext", + f"--build-lib={this_dir}", ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -294,8 +305,9 @@ class CheckpointSolverRotor(CheckpointSolverBase): return compute_table(chain, mmax) @staticmethod - def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], - back_ptr: List[Any]) -> "Sequence": + def _backtrack( + chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], back_ptr: List[Any] + ) -> "Sequence": """Backtrack the cost table and retrieve the optimal checkpointing strategy. Args: @@ -328,8 +340,9 @@ class CheckpointSolverRotor(CheckpointSolverBase): if back_ptr[budget][lhs][rhs][0]: sequence += [ ForwardEnable(lhs), - CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, - back_ptr), + CheckpointSolverRotor._backtrack( + chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, back_ptr + ), Backward(lhs), ] else: @@ -337,8 +350,9 @@ class CheckpointSolverRotor(CheckpointSolverBase): sequence += [ForwardCheck(lhs)] sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)] sequence += [ - CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, - back_ptr), + CheckpointSolverRotor._backtrack( + chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, back_ptr + ), CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr), ] return sequence @@ -353,8 +367,8 @@ class CheckpointSolverRotor(CheckpointSolverBase): """ op_list = sequence.list_operations() loss_op = next(op for op in op_list if isinstance(op, Loss)) - fwd_list = op_list[:op_list.index(loss_op)] - bwd_list = op_list[op_list.index(loss_op) + 1:] + fwd_list = op_list[: op_list.index(loss_op)] + bwd_list = op_list[op_list.index(loss_op) + 1 :] ckpt_idx = 0 in_ckpt = False ckpt_region = [] @@ -369,7 +383,7 @@ class CheckpointSolverRotor(CheckpointSolverBase): in_ckpt = False for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'] = [ckpt_idx] + n.meta["activation_checkpoint"] = [ckpt_idx] ckpt_idx += 1 ckpt_region = [] @@ -377,7 +391,7 @@ class CheckpointSolverRotor(CheckpointSolverBase): elif isinstance(op, ForwardCheck): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'] = [ckpt_idx] + n.meta["activation_checkpoint"] = [ckpt_idx] ckpt_idx += 1 ckpt_region = [idx] @@ -397,7 +411,7 @@ class CheckpointSolverRotor(CheckpointSolverBase): elif isinstance(op, ForwardEnable): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'].append(ckpt_idx) + n.meta["activation_checkpoint"].append(ckpt_idx) ckpt_idx += 1 ckpt_region = [] @@ -405,7 +419,7 @@ class CheckpointSolverRotor(CheckpointSolverBase): elif isinstance(op, ForwardCheck): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'].append(ckpt_idx) + n.meta["activation_checkpoint"].append(ckpt_idx) ckpt_idx += 1 ckpt_region = [op.index] @@ -413,7 +427,7 @@ class CheckpointSolverRotor(CheckpointSolverBase): elif isinstance(op, Backward): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'].append(ckpt_idx) + n.meta["activation_checkpoint"].append(ckpt_idx) in_recompute = False @@ -431,9 +445,11 @@ class CheckpointSolverRotor(CheckpointSolverBase): for node in node_list: op_list += node ckpt_regions = _find_nested_ckpt_regions(op_list) - for (start_idx, end_idx) in ckpt_regions: + for start_idx, end_idx in ckpt_regions: nested_length = max( - len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1)) + len(op_list[idx].meta["activation_checkpoint"]) for idx in range(start_idx, end_idx + 1) + ) for idx in range(start_idx, end_idx + 1): - op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length - - len(op_list[idx].meta['activation_checkpoint'])) + op_list[idx].meta["activation_checkpoint"] += [None] * ( + nested_length - len(op_list[idx].meta["activation_checkpoint"]) + ) diff --git a/colossalai/auto_parallel/checkpoint/operation.py b/colossalai/auto_parallel/checkpoint/operation.py index ab0c6c5ad38d171d470931aa9b2bdedf6cd17668..5f80779164339930c6943a9132952ee75f988a14 100644 --- a/colossalai/auto_parallel/checkpoint/operation.py +++ b/colossalai/auto_parallel/checkpoint/operation.py @@ -1,20 +1,21 @@ import math from abc import ABC -from typing import Any, Iterable, List +from typing import List from torch.utils._pytree import tree_map class Chain: - - def __init__(self, - ftime: List[float], - btime: List[float], - x: List[int], - xbar: List[int], - ftmp: List[int], - btmp: List[int], - check_consistency: bool = True): + def __init__( + self, + ftime: List[float], + btime: List[float], + x: List[int], + xbar: List[int], + ftmp: List[int], + btmp: List[int], + check_consistency: bool = True, + ): """The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint. See paper https://hal.inria.fr/hal-02352969 for details. @@ -37,9 +38,14 @@ class Chain: raise AttributeError("In Chain, input lists do not have consistent lengths") def check_lengths(self): - return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1) - and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1) - and (len(self.xbar) == len(self) + 1)) + return ( + (len(self.ftime) == len(self)) + and (len(self.btime) == len(self) + 1) + and (len(self.x) == len(self) + 1) + and (len(self.ftmp) == len(self)) + and (len(self.btmp) == len(self) + 1) + and (len(self.xbar) == len(self) + 1) + ) def __repr__(self): chain_list = [] @@ -100,7 +106,6 @@ class ForwardCheck(Forward): class Forwards(Operation): - def __init__(self, start, end): self.index = (start, end) @@ -109,9 +114,9 @@ class Forwards(Operation): def cost(self, chain: Chain): if chain is not None: - return sum(chain.ftime[self.index[0]:self.index[1] + 1]) + return sum(chain.ftime[self.index[0] : self.index[1] + 1]) else: - return (self.index[1] - self.index[0] + 1) + return self.index[1] - self.index[0] + 1 def isForward(op): @@ -132,7 +137,6 @@ class Backward(Operation): class Loss(Operation): - def __init__(self): pass @@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess): class Sequence(list): - def __init__(self): super().__init__() diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py index 35b8c13ee8fff717df39a96c60fa101eb0b2a781..2f638fa919e445ac4bc24c9b8904bfef27e608c5 100644 --- a/colossalai/auto_parallel/meta_profiler/constants.py +++ b/colossalai/auto_parallel/meta_profiler/constants.py @@ -3,8 +3,6 @@ import operator import torch import torch.nn as nn -from ..tensor_shard.constants import * - # list of inplace module INPLACE_MODULE = [nn.ReLU] diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index 0f2e9e44f91cedfdb888171dd373d4d6163f579f..4234481ae2ca8fa10769d043949d6ecff0d537da 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0 def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: input_tensor = next( filter( - lambda x: - (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim', - args)).data + lambda x: (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) + and x.name != "softmax_dim", + args, + ) + ).data output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data - is_inplace = 1 if kwargs.get('inplace', False) else 0 + is_inplace = 1 if kwargs.get("inplace", False) else 0 flop_counter = elementwise_flop_counter(1, 0) # calculate compute cost fwd_compute_cost = flop_counter([input_tensor], [output_tensor]) bwd_compute_cost = flop_counter([output_tensor], [input_tensor]) - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) # calculate memory cost # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward # NOTE: if in_place is True, we will not create a new tensor in forward - fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace), - parameter=0, - temp=0, - buffer=activation_size(input_tensor) * buffer_mem_scale) + fwd_memory_cost = MemoryCost( + activation=activation_size(input_tensor) * (2 - is_inplace), + parameter=0, + temp=0, + buffer=activation_size(input_tensor) * buffer_mem_scale, + ) # temp_mem_scale is for situation like softmax backward # the buffer will be removed during backward phase @@ -54,20 +58,23 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0 activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale, parameter=0, temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale, - buffer=0) + buffer=0, + ) # total cost is the sum of forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, - temp=fwd_memory_cost.temp + bwd_memory_cost.temp, - buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + temp=fwd_memory_cost.temp + bwd_memory_cost.temp, + buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out fwd_in = [] - fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_buffer = [torch.zeros_like(output_tensor, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py index e451748512b9abebcc4f63ad854be3f129ee52bd..0b7b51a719551710041471a95c0d620124abf8ef 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -6,10 +6,10 @@ from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION +from ..constants import BCAST_FUNC_OP from ..registry import meta_register -__all__ = ['binary_elementwise_meta_info'] +__all__ = ["binary_elementwise_meta_info"] @meta_register.register(BCAST_FUNC_OP) @@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train # store fwd_in, fwd_buffer, fwd_out fwd_in = [] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_op_data.data, device='meta')] + fwd_out = [torch.zeros_like(output_op_data.data, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py index 4336bf68363c8a708b877c8d8116c44986a85592..2f630995cdbce49dafa2199ce85c31ebdc07b5b7 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py @@ -1,22 +1,14 @@ -from typing import Callable, Dict, List, Tuple, Union +from typing import List, Tuple import torch from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from ..registry import meta_register -__all__ = ['convnd_meta_info'] +__all__ = ["convnd_meta_info"] @meta_register.register(torch.nn.Conv1d) @@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # calculate compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,)) - bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \ - flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor)) + bwd_compute_cost = ( + flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) + if has_bias + else flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor)) + ) compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) # calculate memory cost # TODO: use profiler to check conv temp memory # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) - if has_bias else compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) - - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) - if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) - if has_bias else compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias + else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) + + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) + if has_bias + else compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias + else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) # total cost is the sum of forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py index d5d80f5b3700b6b644c0b630496bd907c0b5aac2..7c9add810fd87cef0491f0ce68c3055ebddd5820 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py @@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor]) - bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor], - [weight_tensor]) + bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]( + [output_tensor, weight_tensor], [weight_tensor] + ) compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) @@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=0, - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0 + ) bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0) total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py index 7697fc6c383d8154acfe76dba7d8baec225930ac..d731f9cb4436f0b054beef9a7326a21fea8233a6 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py @@ -1,23 +1,15 @@ from functools import reduce -from typing import Callable, Dict, List, Tuple, Union +from typing import List, Tuple import torch from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from ..registry import meta_register -__all__ = ['linear_meta_info', 'matmul_meta_info'] +__all__ = ["linear_meta_info", "matmul_meta_info"] @meta_register.register(torch.nn.functional.linear) @@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # calculate compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default]( - [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) - bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ - flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \ - flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,)) - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,) + ) + bwd_compute_cost = ( + flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + + flop_mapping[torch.ops.aten.mm.default]( + [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,) + ) + + flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,)) + ) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) # calculate memory cost # NOTE: Linear don't have buffer and temp in forward and backward phase # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=0, + ) # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0 - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=0) + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=0, + ) # total cost is to sum the forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) @@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # calculate compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( - [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) - bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ - flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,) + ) + bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [output_tensor, weight_tensor], (input_tensor,) + ) + flop_mapping[torch.ops.aten.mm.default]( + [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,) + ) - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) # calculate memory cost # NOTE: Linear don't have buffer and temp in forward and backward phase # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0 - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]), - parameter=compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) # total cost is to sum the forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out @@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # batched gemv case 1: batched matrix-vector multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors) + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors + ) # combine the dimensions of output bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( - [output_tensors[0].reshape(-1), input_tensors[1]], - output_tensors) + \ - flop_mapping[torch.ops.aten.matmul.default]( - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)], - output_tensors) + [output_tensors[0].reshape(-1), input_tensors[1]], output_tensors + ) + flop_mapping[torch.ops.aten.matmul.default]( + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)], + output_tensors, + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) @@ -239,93 +253,111 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # gemv case 2: vector-matrix multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors) - bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \ - flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors) + bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( + [output_tensors[0], input_tensors[0]], output_tensors + ) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), - parameter=0, - temp=compute_size_in_bytes(input_tensors[1]), - buffer=0) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]), + buffer=0, + ) elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3: # batched gemv case 2: vector-batched matrix multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]], - [output_tensors[0].reshape(-1)]) + [output_tensors[0].reshape(-1)], + ) # combine the dimensions of output bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( - [output_tensors[0].reshape(-1), input_tensors[0]], - output_tensors - ) + \ - flop_mapping[torch.ops.aten.matmul.default]( - [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)], - output_tensors - ) + [output_tensors[0].reshape(-1), input_tensors[0]], output_tensors + ) + flop_mapping[torch.ops.aten.matmul.default]( + [ + input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), + output_tensors[0].reshape(-1), + ], + output_tensors, + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]])) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), - parameter=0, - temp=compute_size_in_bytes(input_tensors[1]), - buffer=0) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors[0]), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]), + buffer=0, + ) elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2: # gemm & batched gemm case 1: batched matrix-matrix multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], - [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])]) + [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])], + ) bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])], - [input_tensors[1]] - ) + \ - flop_mapping[torch.ops.aten.mm.default]( - [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)], - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])] - ) + [ + input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), + output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), + ], + [input_tensors[1]], + ) + flop_mapping[torch.ops.aten.mm.default]( + [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)], + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])], + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3: # batched gemm case 2: matrix-batched matrix multiplication - fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([ - input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose( - 0, 1) - ], [output_tensors[0].transpose(-2, -1)]) + fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [ + input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), + input_tensors[0].transpose(0, 1), + ], + [output_tensors[0].transpose(-2, -1)], + ) bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( - [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])], - [input_tensors[0]] - ) + \ - flop_mapping[torch.ops.aten.mm.default]( - [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]], - [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])] - ) - - fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) + - compute_size_in_bytes(input_tensors[1]), - temp=compute_size_in_bytes(output_tensors)) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), - parameter=0, - temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors)) + [ + output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), + input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), + ], + [input_tensors[0]], + ) + flop_mapping[torch.ops.aten.mm.default]( + [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]], + [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])], + ) + + fwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]), + temp=compute_size_in_bytes(output_tensors), + ) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors[0]), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors), + ) elif all(len(tensor.shape) >= 3 for tensor in input_tensors): # Batched matrix-batched matrix multiplication # Fetch shape of the two inputs and see if the batch dimensions are the same _is_batch_dims_same = True if len(input_tensors[0].shape) == len(input_tensors[1].shape): - for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]): + for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]): if shape_0 != shape_1: _is_batch_dims_same = False break else: _is_batch_dims_same = False - # retireve dimensions + # retrieve dimensions input_dim_00 = input_tensors[0].shape[-2] input_dim_01 = input_tensors[0].shape[-1] input_dim_10 = input_tensors[1].shape[-2] @@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # Case 1: batch dimensions are the same # Forward compute cost: C = A * B - fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([ - input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape( - -1, input_dim_10, input_dim_11) - ], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]) + fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( + [ + input_tensors[0].reshape(-1, input_dim_00, input_dim_01), + input_tensors[1].reshape(-1, input_dim_10, input_dim_11), + ], + [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], + ) # Backward compute cost: dB = A^T * dC, dA = dC * B^T bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( - [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], - [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)] - ) + \ - flop_mapping[torch.ops.aten.bmm.default]( - [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)], - [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)] - ) + [ + input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), + output_tensors[0].reshape(-1, output_dim_0, output_dim_1), + ], + [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)], + ) + flop_mapping[torch.ops.aten.bmm.default]( + [ + output_tensors[0].reshape(-1, output_dim_0, output_dim_1), + input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10), + ], + [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)], + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors)) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors)) @@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L else: # Case 2: batch dimensions are different batch_dims = output_tensors[0].shape[:-2] - extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims), - input_dim_00, - input_dim_01, - device="meta") - extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims), - input_dim_10, - input_dim_11, - device="meta") + extended_input_0 = torch.rand( + reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta" + ) + extended_input_1 = torch.rand( + reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta" + ) # Forward compute cost: C = A * B fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( - [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]) + [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)] + ) # Backward compute cost: dB = A^T * dC, dA = dC * B^T bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( - [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], - [extended_input_1] - ) + \ - flop_mapping[torch.ops.aten.bmm.default]( - [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)], - [extended_input_0] - ) + [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], + [extended_input_1], + ) + flop_mapping[torch.ops.aten.bmm.default]( + [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)], + [extended_input_0], + ) fwd_mem_cost = MemoryCost( - activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) - - compute_size_in_bytes([extended_input_0, extended_input_1]), - temp=compute_size_in_bytes([extended_input_0, extended_input_1])) + activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]) + ) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors) + - compute_size_in_bytes([extended_input_0, extended_input_1]), + temp=compute_size_in_bytes([extended_input_0, extended_input_1]), + ) # compute cost compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) # memory cost - total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, - parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, - temp=fwd_mem_cost.temp + bwd_mem_cost.temp, - buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + total_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py index 12874810b13e252c0597e2adf124ab7875e992a3..b1bb1d872c3549e8804976b03397b385963ecfc8 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py @@ -3,7 +3,7 @@ from typing import List, Tuple import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from ..registry import meta_register diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py index b872fdc8bdcd19717e7b81d436fffd860ec88519..99aaa752d0a1b05b543a03eccfa19d5d15d096e7 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py @@ -1,22 +1,14 @@ -from typing import Callable, Dict, List, Tuple, Union +from typing import List, Tuple import torch from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from ..registry import meta_register -__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info'] +__all__ = ["batchnormnd_meta_info", "layernorm_meta_info"] @meta_register.register(torch.nn.BatchNorm1d) @@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt # saved inv std and some other args indicating the status of the module # the bwd outputs are input grad, weight grad and bias grad bwd_in_args = [ - output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch + output_tensor, + output_tensor, + weight_tensor, + mean_tensor, + var_tensor, + mean_tensor, + var_tensor, + 1e-5, + num_batch, ] bwd_out_args = [input_tensor, weight_tensor, bias_tensor] @@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt # calculate memory cost # the fwd activation cost is output plus saved mean and saved inv std # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( - [input_tensor, output_tensor, mean_tensor, var_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=compute_size_in_bytes([mean_tensor, var_tensor])) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=compute_size_in_bytes([mean_tensor, var_tensor]), + ) # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean # and saved inv std during backward phase - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=compute_size_in_bytes([mean_tensor, var_tensor]), - buffer=compute_size_in_bytes([mean_tensor, var_tensor])) + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([mean_tensor, var_tensor]), + buffer=compute_size_in_bytes([mean_tensor, var_tensor]), + ) # total cost is the sum of forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] - fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] + fwd_buffer = [torch.zeros_like(mean_tensor, device="meta"), torch.zeros_like(var_tensor, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out @@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data weight_tensor = next(filter(lambda x: x.name == "weight", args)).data bias_tensor = next(filter(lambda x: x.name == "bias", args)).data - running_mean = torch.rand(input_tensor.shape[0], 1, device='meta') - running_var = torch.rand(input_tensor.shape[0], 1, device='meta') + running_mean = torch.rand(input_tensor.shape[0], 1, device="meta") + running_var = torch.rand(input_tensor.shape[0], 1, device="meta") # construct args fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor] @@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # memory cost # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( - [input_tensor, output_tensor, weight_tensor, bias_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=compute_size_in_bytes([running_mean, running_var])) - - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=compute_size_in_bytes([running_mean, running_var]), - buffer=compute_size_in_bytes([running_mean, running_var])) - - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, - temp=fwd_memory_cost.temp + bwd_memory_cost.temp, - buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=compute_size_in_bytes([running_mean, running_var]), + ) + + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([running_mean, running_var]), + buffer=compute_size_in_bytes([running_mean, running_var]), + ) + + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + temp=fwd_memory_cost.temp + bwd_memory_cost.temp, + buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] - fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] + fwd_buffer = [torch.zeros_like(running_mean, device="meta"), torch.zeros_like(running_var, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py index d785dfcca9bacb46e129adda5f83486090975859..21aa524bed084f4684ac5414e7b8bf19c78048c7 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py @@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, # store fwd_in, fwd_buffer, fwd_out fwd_in = [] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out @@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix])) # temp memory for backward is the index matrix to be discarded - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix), - temp=compute_size_in_bytes(index_matrix)) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix), + temp=compute_size_in_bytes(index_matrix), + ) # total cost total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp) @@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] - fwd_buffer = [torch.zeros_like(index_matrix, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] + fwd_buffer = [torch.zeros_like(index_matrix, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py index 97fe3c6196f591af7bbfcbdcf59ff3afd114175f..9a2df1bd7c870fbff108c41e1a943c20acb43c63 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py @@ -2,7 +2,6 @@ from typing import Callable, List, Tuple import torch -from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem @@ -37,15 +36,19 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor, - parameter=0, - temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor, - buffer=0) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor, + parameter=0, + temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor, + buffer=0, + ) - total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, - parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, - temp=fwd_mem_cost.temp + bwd_mem_cost.temp, - buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + total_mem_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) @@ -66,14 +69,24 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f # register torch.Tensor related metainfo # (0, 0) -meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, - torch.arange])(tensor_related_metainfo(0, 0)) +meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, torch.arange])( + tensor_related_metainfo(0, 0) +) # (1, 0) -meta_register.register([ - torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute, - torch.Tensor.split, torch.split, torch.Tensor.view -])(tensor_related_metainfo(1, 0)) +meta_register.register( + [ + torch.Tensor.flatten, + torch.flatten, + torch.Tensor.transpose, + torch.transpose, + torch.Tensor.permute, + torch.permute, + torch.Tensor.split, + torch.split, + torch.Tensor.view, + ] +)(tensor_related_metainfo(1, 0)) # (1, 1) meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1)) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py index 5cba1b5b6e2b16521ed2a0df2fbab98b19492c53..107851b80d7c1a45744d0817dc36f18928eb1444 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py @@ -4,7 +4,7 @@ import torch from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from ..registry import meta_register @@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li # gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase # NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor])) - bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]), - parameter=0, - temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) - - activation_size([x_tensor, y_tensor]), - buffer=0) - - total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, - parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, - temp=fwd_mem_cost.temp + bwd_mem_cost.temp, - buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + bwd_mem_cost = MemoryCost( + activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]), + parameter=0, + temp=activation_size([output_tensor]) * 3 + + activation_size([condition_tensor]) + - activation_size([x_tensor, y_tensor]), + buffer=0, + ) + + total_mem_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) diff --git a/colossalai/auto_parallel/meta_profiler/registry.py b/colossalai/auto_parallel/meta_profiler/registry.py index 46350c4dd406691c344eb92a933636d6b029b8bd..c29086f7f9d16fdcae11d56f10a1f2ac63a82e5b 100644 --- a/colossalai/auto_parallel/meta_profiler/registry.py +++ b/colossalai/auto_parallel/meta_profiler/registry.py @@ -1,14 +1,12 @@ -__all__ = ['Registry'] +__all__ = ["Registry"] class Registry: - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): if isinstance(source, (list, tuple)): # support register a list of items for this func @@ -21,7 +19,7 @@ class Registry: return wrapper def get(self, source): - assert source in self.store, f'{source} not found in the {self.name} registry' + assert source in self.store, f"{source} not found in the {self.name} registry" target = self.store[source] return target @@ -29,4 +27,4 @@ class Registry: return source in self.store -meta_register = Registry('meta') +meta_register = Registry("meta") diff --git a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py index 0eee908b48b73d9d1cd5e0e35fbf50b8d844e3a6..109b8a220ac760174883a38726efe87d8a6cfb52 100644 --- a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py +++ b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py @@ -2,20 +2,13 @@ from typing import Callable, List import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem from colossalai.tensor.sharding_spec import ShardingSpec from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION from .registry import meta_register -__all__ = ['ShardMetaInfo'] +__all__ = ["ShardMetaInfo"] class ShardMetaInfo: @@ -76,10 +69,12 @@ class ShardMetaInfo: """ if isinstance(sharding_spec, ShardingSpec): - op_data = OperationData(name=operation_data.name, - data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), - type=operation_data.type, - logical_shape=operation_data.logical_shape) + op_data = OperationData( + name=operation_data.name, + data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), + type=operation_data.type, + logical_shape=operation_data.logical_shape, + ) elif isinstance(sharding_spec, (list, tuple)): data = operation_data.data assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}." @@ -97,8 +92,9 @@ class ShardMetaInfo: """ Compute meta info based on sharding strategy and the given target function. """ - assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \ - f"Meta info for {self._target} is not registered." + assert meta_register.has(self._target.__class__) or meta_register.has( + self._target + ), f"Meta info for {self._target} is not registered." if meta_register.has(self._target.__class__): # module meta_func = meta_register.get(self._target.__class__) @@ -117,11 +113,11 @@ class ShardMetaInfo: # construct kwargs if self.target in INPLACE_MODULE: - kwargs = {'inplace': self.target.inplace} + kwargs = {"inplace": self.target.inplace} elif self.target in INPLACE_OPS: - kwargs = {'inplace': True} + kwargs = {"inplace": True} else: - kwargs = {'inplace': False} + kwargs = {"inplace": False} # compute metainfo with meta_func self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs) diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py index a79e5006e7d2ac264f00b5a597e7b869f6f580eb..601bf2926d991a4f8c844f69ffbc6e09f551e6f3 100644 --- a/colossalai/auto_parallel/offload/amp_optimizer.py +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -1,24 +1,25 @@ -from typing import Dict, Tuple from enum import Enum +from typing import Dict, Tuple + import torch from torch.optim import Optimizer -from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.interface import OptimizerWrapper +from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device from .base_offload_module import BaseOffloadModule -from .region_manager import RegionManager from .region import Region +from .region_manager import RegionManager class OptimState(Enum): SCALED = 0 UNSCALED = 1 -class AMPOptimizer(ColossalaiOptimizer): +class AMPOptimizer(OptimizerWrapper): """ A wrapper for Optimizer. Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py @@ -36,19 +37,20 @@ class AMPOptimizer(ColossalaiOptimizer): norm_type (float, optional): norm_type used for `clip_grad_norm`. """ - def __init__(self, - optimizer: Optimizer, - module: BaseOffloadModule, - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - clipping_norm: float = 0.0, - norm_type: float = 2.0): - + def __init__( + self, + optimizer: Optimizer, + module: BaseOffloadModule, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + clipping_norm: float = 0.0, + norm_type: float = 2.0, + ): super().__init__(optimizer) self.module = module @@ -68,19 +70,21 @@ class AMPOptimizer(ColossalaiOptimizer): self.__init__optimizer() # Grad scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.grad_scaler = DynamicGradScaler( + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) self._logger = get_dist_logger() def _set_grad_ptr(self): for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: region = self.param_to_region[fake_param] begin, end = self.param_to_range[fake_param] @@ -91,7 +95,7 @@ class AMPOptimizer(ColossalaiOptimizer): def _update_fp16_params(self): none_tensor = torch.empty([0]) for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: assert fake_param.grad is None fake_param.data = none_tensor self.param_to_region[fake_param].cpu_grad = None @@ -129,10 +133,10 @@ class AMPOptimizer(ColossalaiOptimizer): found_inf = self._check_overflow() if found_inf: - self.optim_state = OptimState.UNSCALED # no need to unscale grad - self.grad_scaler.update(found_inf) # update gradient scaler - self._logger.info(f'Found overflow. Skip step') - self.zero_grad() # reset all gradients + self.optim_state = OptimState.UNSCALED # no need to unscale grad + self.grad_scaler.update(found_inf) # update gradient scaler + self._logger.info(f"Found overflow. Skip step") + self.zero_grad() # reset all gradients self._update_fp16_params() return @@ -155,11 +159,10 @@ class AMPOptimizer(ColossalaiOptimizer): self.module.backward(loss) def __init__optimizer(self): - for group in self.optim.param_groups: fake_params_list = list() - for param in group['params']: + for param in group["params"]: region = self.region_manager.get_region(param) fake_param = torch.nn.Parameter(torch.empty([0])) self.param_to_range[fake_param] = region.param_to_range[param] @@ -170,8 +173,8 @@ class AMPOptimizer(ColossalaiOptimizer): if param in self.optim.state: self.optim.state[fake_param] = self.optim.state.pop(param) - group['params'] = fake_params_list + group["params"] = fake_params_list # Leverage state_dict() and load_state_dict() to # recast preexisting per-param state tensors - self.optim.load_state_dict(self.optim.state_dict()) \ No newline at end of file + self.optim.load_state_dict(self.optim.state_dict()) diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py index d0c328e134ff5696ea2f8c17d5fc3468e4c891a2..f5e8e31f5e9798cb777c8ad8d9e5ad298d0fcb8b 100644 --- a/colossalai/auto_parallel/offload/base_offload_module.py +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -4,7 +4,7 @@ from typing import Optional, Set import torch import torch.nn as nn -from colossalai.nn.parallel.data_parallel import _cast_float +from colossalai.utils import _cast_float from colossalai.zero.legacy.gemini.tensor_utils import free_storage from .region_manager import RegionManager @@ -22,7 +22,6 @@ class BaseOffloadModule: """ def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True): - self.model = model self.region_manager = region_manager self.grad_hook_list = [] @@ -91,17 +90,16 @@ class BaseOffloadModule: def parameters(self, recurse: bool = True): return self.model.parameters(recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True): + def named_parameters(self, prefix: str = "", recurse: bool = True): return self.model.named_parameters(prefix, recurse) - def named_buffers(self, prefix: str = '', recurse: bool = True): + def named_buffers(self, prefix: str = "", recurse: bool = True): return self.model.named_buffers(prefix, recurse) def named_children(self): return self.model.named_children() - def named_modules(self, - memo: Optional[Set[torch.nn.Module]] = None, - prefix: str = '', - remove_duplicate: bool = True): + def named_modules( + self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): return self.model.named_modules(memo, prefix, remove_duplicate) diff --git a/colossalai/auto_parallel/offload/mem_optimize.py b/colossalai/auto_parallel/offload/mem_optimize.py index d56166dea982288bdea160e1347c8ca3f67ed297..74501c18451845c1f62c81ccbb757766928005b5 100644 --- a/colossalai/auto_parallel/offload/mem_optimize.py +++ b/colossalai/auto_parallel/offload/mem_optimize.py @@ -14,11 +14,9 @@ from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_ from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem -def memory_optimize(model: torch.nn.Module, - inps: Dict[str, torch.Tensor], - memory_budget: float = -1.0, - solver_name: str = 'asyn'): - +def memory_optimize( + model: torch.nn.Module, inps: Dict[str, torch.Tensor], memory_budget: float = -1.0, solver_name: str = "asyn" +): model = model.cpu().half() tracer = ColoTracer() assert is_compatible_with_meta() @@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module, f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}" ) - if solver_name == 'syn': + if solver_name == "syn": gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list) - elif solver_name == 'asyn': + elif solver_name == "asyn": gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list) else: raise TypeError(f"Unknown solver name {solver_name}!") gm.recompile() - optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn') + optimized_model = BaseOffloadModule(gm, region_manager, solver_name == "syn") return optimized_model diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py index 819ffbd96eb19098f168519ca6e3e0036fa3a638..ea92c714ce31d185b2da86963cada7e1bb2a4ac4 100644 --- a/colossalai/auto_parallel/offload/region.py +++ b/colossalai/auto_parallel/offload/region.py @@ -55,13 +55,13 @@ class Region: Map the parameters in the region to a contiguous memory space. """ - self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda') + self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device="cuda") offset = 0 for param in self.fp16_params: param.data = param.data.cuda() p_num = param.data.numel() - self.fp16_data[offset:offset + p_num].copy_(param.data.flatten()) - param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape) + self.fp16_data[offset : offset + p_num].copy_(param.data.flatten()) + param.data = self.fp16_data[offset : offset + p_num].view(param.data.shape) self.param_to_range[param] = (offset, offset + p_num) offset += p_num @@ -83,7 +83,7 @@ class Region: self.temp_fp32_data.record_stream(torch.cuda.current_stream()) if not self.in_mem_pool_flag: alloc_storage(self.fp16_data) - self.fp16_data[:self.param_num].copy_(self.temp_fp32_data) + self.fp16_data[: self.param_num].copy_(self.temp_fp32_data) self.fp16_data.record_stream(torch.cuda.current_stream()) self.__update_params_ptr() @@ -94,7 +94,7 @@ class Region: """ self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True) - self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True) + self.cpu_grad.copy_(self.fp16_data[: self.param_num], non_blocking=True) self.fp16_data.record_stream(torch.cuda.current_stream()) if not self.in_mem_pool_flag: self.free_cuda_data() diff --git a/colossalai/auto_parallel/offload/region_manager.py b/colossalai/auto_parallel/offload/region_manager.py index 30bfaf00d4939afadc3c2aaaec3f27ce70db4a20..146dd267967d8b26c16d950245381c699e86c5ca 100644 --- a/colossalai/auto_parallel/offload/region_manager.py +++ b/colossalai/auto_parallel/offload/region_manager.py @@ -1,10 +1,11 @@ -from typing import List, Any, Dict, Tuple +from typing import Any, Dict, List, Tuple + import torch from torch.fx import Graph, Node +from .region import Region from .solver import SolverFactory from .training_simulator import TrainingSimulator -from .region import Region from .util import NodeInfo @@ -19,14 +20,9 @@ class RegionManager: cnode (List[str], optional): Common node List, should be the subset of input. """ - def __init__(self, - graph: Graph, - solver_name: str = 'asyn', - memory_budget: float = -1.0, - cnode: List[str] = None): - + def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None): self.graph = graph - assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + assert graph.owning_module is not None, "The given graph is not associated with a owning_module" self.root_module = self.graph.owning_module self.nodes = list(graph.nodes) self.cnode = cnode @@ -39,7 +35,7 @@ class RegionManager: self.memory_budget = memory_budget self.solver_name = solver_name - self.require_pool: bool = solver_name == 'asyn' + self.require_pool: bool = solver_name == "asyn" self.reg_to_block: Dict[int, int] = dict() @@ -61,22 +57,19 @@ class RegionManager: self._post_process(solver.best_ts) def _pre_process(self): - init_region_list = self._linearize_graph() if len(self.shared_region_pairs) > 1: - raise NotImplementedError( - 'The current version only considers at most one pair of parameter sharing.') + raise NotImplementedError("The current version only considers at most one pair of parameter sharing.") elif len(self.shared_region_pairs) == 1: shared_regs = self.shared_region_pairs[0] - assert shared_regs[0].shared_rid == shared_regs[1].r_id \ - and shared_regs[1].shared_rid == shared_regs[0].r_id + assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id fst_id = shared_regs[0].r_id lst_id = shared_regs[1].r_id - regs_left_out = init_region_list[:fst_id + 1] + regs_left_out = init_region_list[: fst_id + 1] regs_right_out = init_region_list[lst_id:] - hold_regs = init_region_list[fst_id + 1:lst_id] + hold_regs = init_region_list[fst_id + 1 : lst_id] else: regs_left_out = [] regs_right_out = [] @@ -122,12 +115,9 @@ class RegionManager: it may not find a suitable region placement strategy for the given execution flow. """ - reg_flow = torch.cat( - [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0) - mem_block_num = torch.max( - torch.sum(reg_flow[:, self.rid_in_pool], dim=1)) - coexist_matrix = torch.logical_or( - ts.fwd_reg_flow, ts.bwd_reg_flow) + reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0) + mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1)) + coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow) block_to_regs = {} for block_idx in range(mem_block_num): @@ -135,8 +125,7 @@ class RegionManager: for reg in self.region_list: if reg.r_id in self.rid_in_pool: cur_reg_appears = coexist_matrix[:, reg.r_id] - cur_reg_coexists = torch.sum( - coexist_matrix[cur_reg_appears], dim=0).bool() + cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool() for block_idx in range(mem_block_num): if not any(cur_reg_coexists[block_to_regs[block_idx]]): block_to_regs[block_idx].append(reg.r_id) @@ -145,9 +134,12 @@ class RegionManager: if reg.r_id not in self.reg_to_block: raise NotImplementedError( - f'can not find a block from the memory pool to store parameters of the region') - self.memory_pool = torch.chunk(torch.zeros(int( - mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num)) + f"can not find a block from the memory pool to store parameters of the region" + ) + self.memory_pool = torch.chunk( + torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"), + chunks=int(mem_block_num), + ) def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]: """ @@ -178,10 +170,9 @@ class RegionManager: return region_list - def _search_block_size(self, - region_list: List[Region], - search_interval_byte: int = 1024, - search_range_byte: int = 128 * 1024 ** 2) -> int: + def _search_block_size( + self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2 + ) -> int: """ Search for a suitable memory block size. @@ -208,11 +199,10 @@ class RegionManager: acc_wasted += blk_size - left return acc_wasted - param_size_list = [ - region.param_size for region in region_list if region.r_id == region.shared_rid] + param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid] start_size = max(param_size_list) - min_mem_waste = float('+inf') + min_mem_waste = float("+inf") best_block_size = start_size for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): @@ -229,7 +219,7 @@ class RegionManager: Initialize region data, which maps the parameters in the region to a contiguous memory space. """ - self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32) + self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32) for region in self.region_list: pre_alloc_tensor = None @@ -244,8 +234,7 @@ class RegionManager: region.fp16_data = shared_region.fp16_data region.fp32_data = shared_region.fp32_data region.param_to_range = shared_region.param_to_range - region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach( - ) + region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach() torch.cuda.empty_cache() @@ -259,13 +248,14 @@ class RegionManager: former_reg, latter_reg = self.shared_region_pairs[0] assert latter_reg.param_num >= former_reg.param_num embedding_node = former_reg.nodes[-1] - assert embedding_node.op == 'call_module' and isinstance( - self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding) + assert embedding_node.op == "call_module" and isinstance( + self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding + ) if latter_reg.param_num > former_reg.param_num: for idx, n in enumerate(latter_reg.nodes): - if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target), - torch.nn.Linear)) or \ - (n.op == 'call_function' and n.target is torch.nn.functional.linear): + if ( + n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear) + ) or (n.op == "call_function" and n.target is torch.nn.functional.linear): cut_node_idx = idx + 1 break assert len(latter_reg.fp16_params) == 2 @@ -273,7 +263,7 @@ class RegionManager: for p in new_reg.fp16_params: self.param_region_map[p] = new_reg self.region_list.insert(new_reg.r_id, new_reg) - for reg in self.region_list[new_reg.r_id + 1:]: + for reg in self.region_list[new_reg.r_id + 1 :]: reg.r_id += 1 latter_reg.shared_rid = former_reg.r_id former_reg.shared_rid = latter_reg.r_id @@ -344,8 +334,8 @@ class RegionManager: target = n.target submod = self.root_module.get_submodule(target) if ( - len(list(submod.named_parameters(recurse=False))) != 0 - or len(list(submod.named_buffers(recurse=False))) != 0 + len(list(submod.named_parameters(recurse=False))) != 0 + or len(list(submod.named_buffers(recurse=False))) != 0 ): label = True @@ -362,14 +352,12 @@ class RegionManager: """ def _is_inplace(n: Node): - """Get the inplace argument from ``torch.fx.Node`` - """ + """Get the inplace argument from ``torch.fx.Node``""" inplace = False if n.op == "call_function": inplace = n.kwargs.get("inplace", False) elif n.op == "call_module": - inplace = getattr(n.graph.owning_module.get_submodule( - n.target), "inplace", False) + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) return inplace label = False @@ -378,28 +366,30 @@ class RegionManager: target = n.target submod = self.root_module.get_submodule(target) if ( - len(list(submod.named_parameters(recurse=False))) != 0 - or len(list(submod.named_buffers(recurse=False))) != 0 + len(list(submod.named_parameters(recurse=False))) != 0 + or len(list(submod.named_buffers(recurse=False))) != 0 ): label = True elif n.op == "call_function": label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any( - map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)) + map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes) + ) return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users)) def _exception_node_handling(): # TODO meta info prop bug - if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2: - n.meta['fwd_out'] = [] + if n.name.__contains__("transpose") and n.meta["fwd_out"][0].dim() <= 2: + n.meta["fwd_out"] = [] # make sure that item in cnode is valid if self.cnode: for name in self.cnode: try: - assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ - f"Common node {name} is not an input of the model." + assert ( + next(node for node in self.graph.nodes if node.name == name).op == "placeholder" + ), f"Common node {name} is not an input of the model." except StopIteration: raise ValueError(f"Common node name {name} not in graph.") else: @@ -428,8 +418,8 @@ class RegionManager: ns = [] border_n_idx = region.nodes.index(act_n) if border_n_idx < len(region.nodes): - ns = region.nodes[border_n_idx + 1:] - region.nodes = region.nodes[:border_n_idx + 1] + ns = region.nodes[border_n_idx + 1 :] + region.nodes = region.nodes[: border_n_idx + 1] region_list.append(region) region_id += 1 region = Region(r_id=region_id) @@ -448,19 +438,21 @@ class RegionManager: region = Region(r_id=region_id) # propagate common node attr if possible - if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode - ]) or _is_cop(n.target): + if len(n.all_input_nodes) == len( + [node for node in n.all_input_nodes if node.name in self.cnode] + ) or _is_cop(n.target): self.cnode.append(n.name) else: - deps[n] = len( - [user for user in n.users if user.op != "output"]) + deps[n] = len([user for user in n.users if user.op != "output"]) # propagate param node attr if possible - if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops - ]) or n.op == "get_attr": + if ( + len(n.all_input_nodes) + == len([node for node in n.all_input_nodes if node.name in self.only_param_ops]) + or n.op == "get_attr" + ): self.only_param_ops.append(n.name) - param_op_deps[n] = len( - [user for user in n.users if user.op != "output"]) + param_op_deps[n] = len([user for user in n.users if user.op != "output"]) # record last activation node if _is_act(n._meta_data): @@ -472,19 +464,16 @@ class RegionManager: return region_list def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region): - cur_n.node_info = NodeInfo(node_id) - if cur_n.op == 'call_module': + if cur_n.op == "call_module": target = cur_n.target submod = self.root_module.get_submodule(target) for p in list(submod.parameters(recurse=False)): - if p in self.param_region_map: cur_reg.shared_rid = self.param_region_map[p].r_id self.param_region_map[p].shared_rid = cur_reg.r_id - self.shared_region_pairs.append( - (self.param_region_map[p], cur_reg)) + self.shared_region_pairs.append((self.param_region_map[p], cur_reg)) else: self.param_region_map[p] = cur_reg @@ -499,12 +488,10 @@ class RegionManager: attr_itr = getattr(attr_itr, atom) if isinstance(attr_itr, torch.nn.Parameter): - if attr_itr in self.param_region_map: cur_reg.shared_rid = self.param_region_map[attr_itr].r_id self.param_region_map[attr_itr].shared_rid = cur_reg.r_id - self.shared_region_pairs.append( - (self.param_region_map[attr_itr], cur_reg)) + self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg)) else: self.param_region_map[attr_itr] = cur_reg diff --git a/colossalai/auto_parallel/offload/runtime.py b/colossalai/auto_parallel/offload/runtime.py index 764ac608826b2860f3294ec9a5883745dfa10f57..cc790dfb089129a6f2cb101753f3d6431052b8a0 100644 --- a/colossalai/auto_parallel/offload/runtime.py +++ b/colossalai/auto_parallel/offload/runtime.py @@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function): @staticmethod def forward(ctx, input_, fwd_info, bwd_info): ctx.bwd_info = bwd_info - d2h_rid = fwd_info.get('d2h_rid', None) + d2h_rid = fwd_info.get("d2h_rid", None) if d2h_rid is not None: free_region = GlobalRuntimeInfo().region_list[d2h_rid] assert isinstance(free_region, Region) free_region.free_cuda_data() - h2d_rid = fwd_info.get('h2d_rid', None) + h2d_rid = fwd_info.get("h2d_rid", None) if h2d_rid is not None: h2d_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(h2d_region, Region) @@ -38,8 +38,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): - - h2d_rid = ctx.bwd_info.get('h2d_rid', None) + h2d_rid = ctx.bwd_info.get("h2d_rid", None) if h2d_rid is not None: pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) @@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function): def forward(ctx, input_, fwd_info, bwd_info): ctx.bwd_info = bwd_info - sync_rid = fwd_info.get('sync_rid', None) + sync_rid = fwd_info.get("sync_rid", None) if sync_rid is not None: prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None) if prefetch_event: prefetch_event.wait() - h2d_rid = fwd_info.get('h2d_rid', None) + h2d_rid = fwd_info.get("h2d_rid", None) if h2d_rid is not None: pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) @@ -87,8 +86,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): - - sync_rid = ctx.bwd_info.get('sync_rid', None) + sync_rid = ctx.bwd_info.get("sync_rid", None) if sync_rid is not None: wait_region = GlobalRuntimeInfo().region_list[sync_rid] assert isinstance(wait_region, Region) @@ -98,7 +96,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function): else: wait_region.move_param_to_cuda() - h2d_rid = ctx.bwd_info.get('h2d_rid', None) + h2d_rid = ctx.bwd_info.get("h2d_rid", None) if h2d_rid is not None: pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) @@ -114,7 +112,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function): def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): - ''' + """ Convert Upload and Offload operation into runtime action. Argument: @@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): that need to be uploaded, or freed during forward pass. bwd_info(dict): information dict, which contains region indices that need to be uploaded during backward pass. - ''' + """ with torch._C.DisableTorchFunction(): ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) return ret def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): - ''' + """ Convert Prefetch and Offload operation into runtime action. Argument: @@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): that need to be prefetched, waited, or freed during forward pass. bwd_info(dict): information dict, which contains region indices that need to be prefetched or waited during backward pass. - ''' + """ with torch._C.DisableTorchFunction(): ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) return ret @@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R # forward upload fwd_info = {} if requires_upload_p_in_fwd(region_list[region.shared_rid]): - fwd_info['h2d_rid'] = region.r_id + fwd_info["h2d_rid"] = region.r_id # forward offload if r_idx > 0 and region_list[r_idx - 1].need_offload: - fwd_info['d2h_rid'] = r_idx - 1 + fwd_info["d2h_rid"] = r_idx - 1 bwd_info = {} # backward upload if r_idx > 0 and region_list[r_idx - 1].need_offload: - bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id + bwd_info["h2d_rid"] = region_list[r_idx - 1].r_id if fwd_info or bwd_info: with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', - convert_fwd_upload_bwd_offload_to_action, - args=(last_inp_node, fwd_info, bwd_info)) + new_node = mod_graph.create_node( + "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info) + ) replace_node_users(last_inp_node, new_node) last_inp_node = region.nodes[-1] @@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ first_region_with_p = [region for region in region_list if region.param_size][0] fwd_info = {"h2d_rid": first_region_with_p.r_id} with mod_graph.inserting_after(last_inp_node): - upload_apply_node = mod_graph.create_node('call_function', - convert_fwd_upload_bwd_offload_to_action, - args=(last_inp_node, fwd_info, {})) + upload_apply_node = mod_graph.create_node( + "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {}) + ) replace_node_users(last_inp_node, upload_apply_node) last_inp_node = upload_apply_node @@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ # forward prefetch fwd_info = {} if region.param_size: - fwd_info['sync_rid'] = region.r_id + fwd_info["sync_rid"] = region.r_id fwd_prefetch_region = region.fwd_prefetch_region if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]): - fwd_info['h2d_rid'] = fwd_prefetch_region.r_id + fwd_info["h2d_rid"] = fwd_prefetch_region.r_id # forward offload if r_idx > 0 and region_list[r_idx - 1].need_offload: - fwd_info['d2h_rid'] = r_idx - 1 + fwd_info["d2h_rid"] = r_idx - 1 bwd_info = {} # backward prefetch if r_idx > 0 and region_list[r_idx - 1].need_offload: - bwd_info['sync_rid'] = r_idx - 1 + bwd_info["sync_rid"] = r_idx - 1 if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region: - bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id + bwd_info["h2d_rid"] = region_list[r_idx - 1].bwd_prefetch_region.r_id if fwd_info or bwd_info: with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', - convert_fwd_prefetch_bwd_offload_to_action, - args=(last_inp_node, fwd_info, bwd_info)) + new_node = mod_graph.create_node( + "call_function", + convert_fwd_prefetch_bwd_offload_to_action, + args=(last_inp_node, fwd_info, bwd_info), + ) replace_node_users(last_inp_node, new_node) last_inp_node = region.nodes[-1] if region.bwd_prefetch_region: - bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id} + bwd_info = {"h2d_rid": region.bwd_prefetch_region.r_id} with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', - convert_fwd_prefetch_bwd_offload_to_action, - args=(last_inp_node, {}, bwd_info)) + new_node = mod_graph.create_node( + "call_function", convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info) + ) replace_node_users(last_inp_node, new_node) # gm.graph.print_tabular() return gm diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py index 161f7ff868981d913047d7073bb144694ab68591..a6b4904f2617eff77216dfaf3a4fb5fd2449e061 100644 --- a/colossalai/auto_parallel/offload/solver.py +++ b/colossalai/auto_parallel/offload/solver.py @@ -1,6 +1,6 @@ import time -from typing import List, Dict, Type from abc import ABC, abstractmethod +from typing import Dict, List, Type NOT_NVML = False try: @@ -10,10 +10,11 @@ except: import torch from torch.fx.node import Node + from colossalai.utils.cuda import get_current_device -from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator from .region import Region +from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator from .util import NodeInfo, NvDevicePower @@ -49,19 +50,14 @@ class Solver(ABC): It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time. """ - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0, - error_factor: float = 0.95) -> None: - + def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None: self.region_list = region_list self.error_factor: float = error_factor if memory_budget > 0: self.memory_budget = memory_budget * self.error_factor else: - self.memory_budget = torch.cuda.get_device_properties( - get_current_device()).total_memory * self.error_factor + self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() self.comp_power: float = self._extract_computing_power() @@ -94,7 +90,7 @@ class Solver(ABC): if extra_cost == 0: # means data transfer overhead can be completely overlapped - return (float('inf'), total_mem_saving, peak_mem_saving) + return (float("inf"), total_mem_saving, peak_mem_saving) return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving) def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool: @@ -122,9 +118,7 @@ class Solver(ABC): self.best_ts = best_ts self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem) - def _update_node_mem_info(self, - fwd_mem_info: Dict[Node, float], - bwd_mem_info: Dict[Node, float]): + def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]): """ Update the runtime memory information of the node. @@ -134,12 +128,10 @@ class Solver(ABC): """ for node, mem in fwd_mem_info.items(): - assert hasattr(node, 'node_info') and isinstance( - node.node_info, NodeInfo) + assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo) node.node_info.runtime_fwd_mem = mem for node, mem in bwd_mem_info.items(): - assert hasattr(node, 'node_info') and isinstance( - node.node_info, NodeInfo) + assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo) node.node_info.runtime_bwd_mem = mem def _extract_computing_power(self): @@ -159,12 +151,12 @@ class Solver(ABC): return NvDevicePower.RTX3080_FP16 * units elif device_name.__contains__("RTX 3090"): return NvDevicePower.RTX3090_FP16 * units - elif device_name.__contains__('V100'): + elif device_name.__contains__("V100"): return NvDevicePower.V100_FP16 * units elif device_name.__contains__("A100"): return NvDevicePower.A100_FP16 * units else: - raise TypeError(f'Unknown NVIDIA GPU device name {device_name}') + raise TypeError(f"Unknown NVIDIA GPU device name {device_name}") def _profile_bandwidth(self): """ @@ -172,9 +164,9 @@ class Solver(ABC): using data volumes ranging from 1KB to 1GB. """ - print('profiling bandwidth ......') + print("profiling bandwidth ......") link_to_bandwidth = {} - links = ['h2d', 'd2h'] + links = ["h2d", "d2h"] for link in links: t_size = 1024 @@ -182,24 +174,22 @@ class Solver(ABC): # from 1KB to 1GB for i in range(21): - if link == 'h2d': - src_tensor = torch.ones( - int(t_size), dtype=torch.int8, pin_memory=True) - dst_tensor = torch.ones( - (int(t_size)), dtype=torch.int8, device='cuda') - elif link == 'd2h': - src_tensor = torch.ones( - int(t_size), dtype=torch.int8, device='cuda') - dst_tensor = torch.ones( - (int(t_size)), dtype=torch.int8, pin_memory=True) + if link == "h2d": + src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True) + dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device="cuda") + elif link == "d2h": + src_tensor = torch.ones(int(t_size), dtype=torch.int8, device="cuda") + dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True) def func(): dst_tensor.copy_(src_tensor) size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3) - print(f'size: {t_size / 1024 ** 2:.3f} MB, ' - f'{src_tensor.device.type}-to-{dst_tensor.device.type} ' - f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s') + print( + f"size: {t_size / 1024 ** 2:.3f} MB, " + f"{src_tensor.device.type}-to-{dst_tensor.device.type} " + f"bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s" + ) t_size *= 2 @@ -208,10 +198,7 @@ class Solver(ABC): class SynGreedySolver(Solver): - - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0) -> None: + def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None: super().__init__(region_list, memory_budget) self.best_ts: SynTrainingSimulator = None @@ -258,7 +245,8 @@ class SynGreedySolver(Solver): else: raise NotImplementedError( f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " - f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") + f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!" + ) def _call_solver_l2l(self): """ @@ -270,7 +258,6 @@ class SynGreedySolver(Solver): region.is_syn = True def _try_to_offload(self, offload_region: Region): - # record previous information orig_need_offload = offload_region.need_offload assert not orig_need_offload @@ -297,23 +284,17 @@ class SynGreedySolver(Solver): ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts.execute() - extra_comm_cost = 2.0 * \ - ts._get_communication_overhead('h2d', offload_region.param_size) + extra_comm_cost = 2.0 * ts._get_communication_overhead("h2d", offload_region.param_size) # the shared region needs to be moved twice if offload_region.r_id < offload_region.shared_rid: extra_comm_cost *= 2.0 - profit = self._compute_offload_profit( - ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) return ts, profit class AsynGreedySolver(Solver): - - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0, - search_window_size: int = 3): + def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3): super().__init__(region_list, memory_budget) self.search_window_size = search_window_size @@ -331,7 +312,7 @@ class AsynGreedySolver(Solver): ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts.execute() self._update_state(ts) - print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB") + print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB") def _call_solver(self): """ @@ -358,18 +339,17 @@ class AsynGreedySolver(Solver): best_pref_ts = None # search when to prefetch the region offloaded - for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]: + for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]: if host_region.bwd_prefetch_region is not None: continue - temp_ts, profit = self._try_to_offload( - host_region, region) + temp_ts, profit = self._try_to_offload(host_region, region) if self._compare_profit(profit, max_prefetch_profit): region_to_region_map[region.r_id] = host_region max_prefetch_profit = profit best_pref_ts = temp_ts - if profit[0] == float('inf'): + if profit[0] == float("inf"): break if self._compare_profit(max_prefetch_profit, max_offload_profit): @@ -392,7 +372,8 @@ class AsynGreedySolver(Solver): else: raise NotImplementedError( f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " - f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") + f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!" + ) region_to_region_map.clear() @@ -452,7 +433,6 @@ class AsynGreedySolver(Solver): peak_mem_saving = 0 while len(self.region_to_region_map) and peak_mem_saving <= 0: - max_profit = (0,) best_ts = None undo_host_region = None @@ -464,8 +444,7 @@ class AsynGreedySolver(Solver): assert offload_region.need_offload assert not offload_region.is_syn - ts, profit = self._try_convert_to_syn_upload(host_region, - offload_region) + ts, profit = self._try_convert_to_syn_upload(host_region, offload_region) if self._compare_profit(profit, max_profit): undo_host_region = host_region @@ -474,7 +453,7 @@ class AsynGreedySolver(Solver): best_ts = ts if best_ts is None: - raise NotImplementedError('repair error!') + raise NotImplementedError("repair error!") assert not undo_offload_region.is_syn undo_offload_region.is_syn = True @@ -500,17 +479,13 @@ class AsynGreedySolver(Solver): ts.execute() extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0) - profit = self._compute_offload_profit( - ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) return ts, profit class SolverFactory: - solvers: Dict[str, Type[Solver]] = { - 'syn': SynGreedySolver, - 'asyn': AsynGreedySolver - } + solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver} @staticmethod def create(solver_name: str) -> Type[Solver]: diff --git a/colossalai/auto_parallel/offload/training_simulator.py b/colossalai/auto_parallel/offload/training_simulator.py index de58023ec2d6a4b4247ad838f4e9c2e9a56da692..728d8daf9a467ef393c2291bbc9c0fbbe674e579 100644 --- a/colossalai/auto_parallel/offload/training_simulator.py +++ b/colossalai/auto_parallel/offload/training_simulator.py @@ -1,7 +1,7 @@ import bisect -from typing import List, Dict -from collections import OrderedDict from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Dict, List from torch.fx.node import Node @@ -26,10 +26,7 @@ class TrainingSimulator(ABC): link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth. """ - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: self.region_list = region_list self.region_num = len(region_list) @@ -87,11 +84,7 @@ class TrainingSimulator(ABC): class SynTrainingSimulator(TrainingSimulator): - - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: super().__init__(region_list, comp_power, link_to_bw) def execute(self): @@ -115,8 +108,7 @@ class SynTrainingSimulator(TrainingSimulator): self.runtime_mem += region.param_size for node in region.nodes: - self.runtime_mem += calculate_fwd_tmp(node) + \ - calculate_fwd_out(node) + self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node) self.fwd_node_mem[node] = self.runtime_mem self.peak_mem = max(self.runtime_mem, self.peak_mem) self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem @@ -141,18 +133,15 @@ class SynTrainingSimulator(TrainingSimulator): self.runtime_mem += region.param_size for node in region.nodes.__reversed__(): - self.runtime_mem -= calculate_fwd_out(node) - self.runtime_mem += node.meta['bwd_mem_tmp'] + \ - node.meta['bwd_mem_out'] + self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] self.peak_mem = max(self.runtime_mem, self.peak_mem) # The memory savings of a node may be negative due to parameter prefetch. self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem self.bwd_node_mem[node] = self.runtime_mem - self.runtime_mem -= (node.meta['bwd_mem_tmp'] + - calculate_fwd_tmp(node)) + self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node) # free bwd_mem_out self.bwd_node_deps[node] = len(node.all_input_nodes) @@ -160,12 +149,14 @@ class SynTrainingSimulator(TrainingSimulator): if user_node in self.bwd_node_deps: self.bwd_node_deps[user_node] -= 1 if self.bwd_node_deps[user_node] <= 0: - self.runtime_mem -= user_node.meta['bwd_mem_out'] + self.runtime_mem -= user_node.meta["bwd_mem_out"] if self.runtime_mem < 0: - raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " - f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" - f"runtime memory computed less than 0, which is miscalculated!") + raise ValueError( + f"region id: {region.r_id}, node name: {node.name}, " + f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" + f"runtime memory computed less than 0, which is miscalculated!" + ) # release parameter and offload gradient in region if region.r_id == region.shared_rid: @@ -177,23 +168,16 @@ class SynTrainingSimulator(TrainingSimulator): class AsynTrainingSimulator(TrainingSimulator): - - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: super().__init__(region_list, comp_power, link_to_bw) self.iter_end_time: int = 0 # the last computation execution period - self.last_comp: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the last parameter prefetch execution period - self.last_h2d: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the last gradient offload execution period - self.last_d2h: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the forward computation execution period of the region self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the forward parameter prefetch execution period of the region @@ -204,10 +188,8 @@ class AsynTrainingSimulator(TrainingSimulator): self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the gradient offload execution period of the region # which is divided into those that are waiting and those that have been released - self.bwd_reg_to_offl_waiting: OrderedDict[int, - ExecutionPeriod] = OrderedDict() - self.bwd_reg_to_offl_freed: OrderedDict[int, - ExecutionPeriod] = OrderedDict() + self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict() + self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the region buffer, which records regions that are offloaded but not released self.reg_buffer_to_free: List[int] = [] @@ -217,10 +199,8 @@ class AsynTrainingSimulator(TrainingSimulator): # the region execution flow, # where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU # when the execution reaches the i-th region. - self.fwd_reg_flow = torch.zeros( - (self.region_num, self.region_num)).bool() - self.bwd_reg_flow = torch.zeros( - (self.region_num, self.region_num)).bool() + self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool() + self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool() def execute(self): """ @@ -232,7 +212,7 @@ class AsynTrainingSimulator(TrainingSimulator): for reg in self.region_list: if reg.param_size and reg.r_id < self.region_num - 1: - for nr in self.region_list[reg.r_id + 1:]: + for nr in self.region_list[reg.r_id + 1 :]: if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]): reg.fwd_prefetch_region = nr break @@ -249,8 +229,7 @@ class AsynTrainingSimulator(TrainingSimulator): self.runtime_mem -= self.region_list[reg_id].param_size self.bwd_reg_to_offl_waiting.clear() - self.iter_end_time = max( - self.last_comp.end_time, self.last_d2h.end_time) + self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time) def _insert_h2d_exec(self, region: Region, is_fwd: bool = True): """ @@ -258,10 +237,8 @@ class AsynTrainingSimulator(TrainingSimulator): """ pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time) - pref_end_time = pref_start_time + \ - 2.0 * self._get_communication_overhead('h2d', region.param_size) - pref_ep = ExecutionPeriod( - start_time=pref_start_time, end_time=pref_end_time) + pref_end_time = pref_start_time + 2.0 * self._get_communication_overhead("h2d", region.param_size) + pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time) if is_fwd: self.fwd_reg_to_pref[region.r_id] = pref_ep else: @@ -276,18 +253,16 @@ class AsynTrainingSimulator(TrainingSimulator): if is_fwd: reg_to_comp = self.fwd_reg_to_comp reg_to_pref = self.fwd_reg_to_pref - flop_key = 'fwd_flop' + flop_key = "fwd_flop" else: reg_to_comp = self.bwd_reg_to_comp reg_to_pref = self.bwd_reg_to_pref - flop_key = 'bwd_flop' - comp_start_time = max(self.last_comp.end_time, reg_to_pref.get( - region.r_id, ExecutionPeriod(0, 0)).end_time) - comp_end_time = comp_start_time + \ - sum([self._get_computing_overhead(node.meta.get(flop_key, 0)) - for node in region.nodes]) - comp_ep = ExecutionPeriod( - start_time=comp_start_time, end_time=comp_end_time) + flop_key = "bwd_flop" + comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time) + comp_end_time = comp_start_time + sum( + [self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes] + ) + comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time) reg_to_comp[region.r_id] = comp_ep self.last_comp = comp_ep @@ -297,10 +272,8 @@ class AsynTrainingSimulator(TrainingSimulator): """ offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time) - offl_end_time = offl_start_time + \ - self._get_communication_overhead('d2h', region.param_size) - offl_ep = ExecutionPeriod( - start_time=offl_start_time, end_time=offl_end_time) + offl_end_time = offl_start_time + self._get_communication_overhead("d2h", region.param_size) + offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time) self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep self.last_d2h = offl_ep @@ -332,20 +305,17 @@ class AsynTrainingSimulator(TrainingSimulator): self.fwd_reg_flow[region.r_id, region.r_id] = True else: self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1] - self.fwd_reg_flow[region.r_id, - self.reg_buffer_to_free] = False + self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False self.reg_buffer_to_free.clear() # prefetch parameters of the next region fwd_prefetch_region = region.fwd_prefetch_region if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]): self.runtime_mem += fwd_prefetch_region.param_size - self.fwd_reg_flow[region.r_id, - fwd_prefetch_region.r_id] = True + self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True for node in region.nodes: - self.runtime_mem += calculate_fwd_tmp(node) + \ - calculate_fwd_out(node) + self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node) self.peak_mem = max(self.runtime_mem, self.peak_mem) self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem @@ -354,8 +324,7 @@ class AsynTrainingSimulator(TrainingSimulator): if region.need_offload: self.runtime_mem -= region.param_size - assert len( - self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}' + assert len(self.reg_buffer_to_free) <= 1, f"{len(self.reg_buffer_to_free)}" self.reg_buffer_to_free.append(region.r_id) def _eval_bwd_cost_per_region(self, region: Region): @@ -398,8 +367,7 @@ class AsynTrainingSimulator(TrainingSimulator): self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1] else: self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1] - self.bwd_reg_flow[region.r_id, - self.reg_buffer_to_free] = False + self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False # free gradients in the buffer while len(self.reg_buffer_to_free): @@ -415,8 +383,7 @@ class AsynTrainingSimulator(TrainingSimulator): bwd_prefetch_region = region.bwd_prefetch_region if bwd_prefetch_region: self.runtime_mem += bwd_prefetch_region.param_size - self.bwd_reg_flow[region.r_id, - bwd_prefetch_region.r_id] = True + self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True # add the gradient of the parameter if region.r_id < region.shared_rid: @@ -426,10 +393,8 @@ class AsynTrainingSimulator(TrainingSimulator): self.runtime_mem += region.param_size for node in region.nodes.__reversed__(): - self.runtime_mem -= calculate_fwd_out(node) - self.runtime_mem += node.meta['bwd_mem_tmp'] + \ - node.meta['bwd_mem_out'] + self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] self.peak_mem = max(self.runtime_mem, self.peak_mem) # The memory savings of a node may be negative due to parameter prefetch. @@ -437,8 +402,7 @@ class AsynTrainingSimulator(TrainingSimulator): self.bwd_node_mem[node] = self.runtime_mem - self.runtime_mem -= (node.meta['bwd_mem_tmp'] + - calculate_fwd_tmp(node)) + self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node) # free bwd_mem_out self.bwd_node_deps[node] = len(node.all_input_nodes) @@ -446,12 +410,14 @@ class AsynTrainingSimulator(TrainingSimulator): if user_node in self.bwd_node_deps: self.bwd_node_deps[user_node] -= 1 if self.bwd_node_deps[user_node] <= 0: - self.runtime_mem -= user_node.meta['bwd_mem_out'] + self.runtime_mem -= user_node.meta["bwd_mem_out"] if self.runtime_mem < 0: - raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " - f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" - f"runtime memory computed less than 0, which is miscalculated!") + raise ValueError( + f"region id: {region.r_id}, node name: {node.name}, " + f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" + f"runtime memory computed less than 0, which is miscalculated!" + ) # release parameters of the region if requires_release_p_in_bwd(self.region_list[region.shared_rid]): diff --git a/colossalai/auto_parallel/offload/util.py b/colossalai/auto_parallel/offload/util.py index 6b010512cc9c99b6dff4acf2e7bd97d12d8146c0..cb65da79c5a27ac369367d72d6429fa4e02cb62a 100644 --- a/colossalai/auto_parallel/offload/util.py +++ b/colossalai/auto_parallel/offload/util.py @@ -35,7 +35,6 @@ class NvDevicePower: class GlobalRuntimeInfo(metaclass=SingletonMeta): - def __init__(self): self.h2d_stream = torch.cuda.Stream() self.d2h_stream = torch.cuda.Stream() @@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float: # forward for region in region_list: for node in region.nodes: - runtime_mem = runtime_mem + \ - calculate_fwd_tmp(node) + calculate_fwd_out(node) + runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node) act_peak_mem = max(runtime_mem, act_peak_mem) # backward bwd_deps = {} for region in region_list.__reversed__(): for node in region.nodes.__reversed__(): runtime_mem -= calculate_fwd_out(node) - runtime_mem = runtime_mem + \ - node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out'] + runtime_mem = runtime_mem + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] act_peak_mem = max(runtime_mem, act_peak_mem) - runtime_mem = runtime_mem - \ - node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node) + runtime_mem = runtime_mem - node.meta["bwd_mem_tmp"] - calculate_fwd_tmp(node) # free bwd_mem_out bwd_deps[node] = len(node.all_input_nodes) @@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float: if user_node in bwd_deps: bwd_deps[user_node] -= 1 if bwd_deps[user_node] <= 0: - runtime_mem -= user_node.meta['bwd_mem_out'] + runtime_mem -= user_node.meta["bwd_mem_out"] return act_peak_mem @@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float: def requires_upload_p_in_fwd(shared_reg: Region): - return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid - and shared_reg.need_offload) + return (shared_reg.r_id >= shared_reg.shared_rid) or ( + shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload + ) def requires_release_p_in_bwd(shared_reg: Region): - return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid - and shared_reg.need_offload) + return (shared_reg.r_id >= shared_reg.shared_rid) or ( + shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload + ) def requires_offload_g_in_bwd(region: Region): diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py index ffda58e0689f1b0e7fa962dd328eee7453e559ec..ba290ee839d8bbe053c8af32a7935387c6618c6a 100644 --- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -14,18 +14,20 @@ from colossalai.tensor.sharding_spec import ShardingSpec shape_consistency_manager = ShapeConsistencyManager() -def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, - target_sharding_spec: ShardingSpec) -> ShardMetaInfo: +def _construct_shard_meta_info( + node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec +) -> ShardMetaInfo: # get comm_action_sequence and total_cost from shape_consistency_manager _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( - origin_sharding_spec, target_sharding_spec) + origin_sharding_spec, target_sharding_spec + ) meta_info = ShardMetaInfo() # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel # get mem cost for ShardMetaInfo mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) # extract user that has _meta_data and extract element length - input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) + input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data")) element_length = input_node._meta_data.element_size() mem_cost.fwd.activation *= element_length @@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, meta_info.memory_cost = mem_cost # get computation cost for ShardMetaInfo - meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, - total_cost['backward'] * element_length, - total_cost['total'] * element_length) + meta_info.compute_cost = TrainCycleItem( + total_cost["forward"] * element_length, + total_cost["backward"] * element_length, + total_cost["total"] * element_length, + ) # get tensor shape for ShardMetaInfo origin_sharding_spec: ShardingSpec @@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, input_shape = origin_sharding_spec.get_sharded_shape_per_device() output_shape = target_sharding_spec.get_sharded_shape_per_device() - meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_in = [torch.rand(input_shape, device="meta")] meta_info.fwd_buffer = [] - meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + meta_info.fwd_out = [torch.rand(output_shape, device="meta")] return meta_info @@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) - # extract node index and user node index args = node.args node_index, user_node_index = args[3], args[4] - origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ - user_node_index] + origin_sharding_spec, target_sharding_spec = ( + origin_spec_dict[node_index], + sharding_spec_dict[node_index][user_node_index], + ) return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) @@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S # this case is for all_reduce, there will be no memory cost meta_info = ShardMetaInfo() meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost) - output_node = next(n for n in node.users if hasattr(n, '_meta_data')) + output_node = next(n for n in node.users if hasattr(n, "_meta_data")) element_length = output_node._meta_data.element_size() total_cost = comm_action.comm_spec.get_comm_cost() - meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, - total_cost['backward'] * element_length, - total_cost['total'] * element_length) + meta_info.compute_cost = TrainCycleItem( + total_cost["forward"] * element_length, + total_cost["backward"] * element_length, + total_cost["total"] * element_length, + ) input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device() - meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_in = [torch.rand(input_shape, device="meta")] meta_info.fwd_buffer = [] - meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + meta_info.fwd_out = [torch.rand(output_shape, device="meta")] else: # this case will be handled by shape consistency manager - origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[ - 'tgt_spec'] + origin_sharding_spec, target_sharding_spec = ( + comm_action.comm_spec["src_spec"], + comm_action.comm_spec["tgt_spec"], + ) meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) return meta_info -def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, - comm_actions_dict: Dict) -> GraphModule: +def comm_metainfo_pass( + gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict +) -> GraphModule: """ The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph. """ for node in gm.graph.nodes: if node.target == runtime_apply: - setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) + setattr(node, "best_strategy_info", _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) elif node.target == runtime_comm_spec_apply: - setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) + setattr(node, "best_strategy_info", _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) else: pass return gm diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index bc0960483980b9242af531d4b309a9c74f735c16..9b000549de6ca8f80de4955353f42920de31fc64 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -21,16 +21,15 @@ def _normalize_tuple(x): @compatibility(is_backward_compatible=False) class MetaInfoProp: - def __init__(self, module: GraphModule) -> None: self.module = module self.func_dict = { - 'placeholder': self.placeholder_handler, - 'get_attr': self.get_attr_handler, - 'output': self.output_handler, - 'call_function': self.node_handler, - 'call_module': self.node_handler, - 'call_method': self.node_handler, + "placeholder": self.placeholder_handler, + "get_attr": self.get_attr_handler, + "output": self.output_handler, + "call_function": self.node_handler, + "call_module": self.node_handler, + "call_method": self.node_handler, } def _set_data_ptr(self, x): @@ -46,7 +45,7 @@ class MetaInfoProp: """ Check if the node is inplace operation. """ - if node.op == 'call_module': + if node.op == "call_module": return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD elif node.op == "call_function": return node.target in OUTPUT_SAVED_OPS @@ -66,7 +65,7 @@ class MetaInfoProp: Handle the placeholder node. """ graph_info = GraphInfo() - out = _normalize_tuple(getattr(node, '_meta_data', None)) + out = _normalize_tuple(getattr(node, "_meta_data", None)) graph_info.fwd_out = list(out) if out[0] is not None else [] node.meta = {**asdict(graph_info)} @@ -96,7 +95,7 @@ class MetaInfoProp: """ Handle other kind of nodes """ - assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}" + assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}" graph_info = GraphInfo() meta_info = node.best_strategy_info meta_info: ShardMetaInfo @@ -126,7 +125,8 @@ class MetaInfoProp: for tensor in par.meta.get("fwd_out", []): tensor: torch.Tensor target_input_tensor = next( - (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None) + (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None + ) if target_input_tensor is not None: target_input_tensor.data_ptr = tensor.data_ptr @@ -148,7 +148,7 @@ class MetaInfoProp: graph_info.fwd_tmp = buffer_tensors graph_info.fwd_out = output_tensors - # fetch other memory informations + # fetch other memory information memory_cost = meta_info.memory_cost graph_info.fwd_mem_tmp = memory_cost.fwd.temp graph_info.fwd_mem_out = memory_cost.fwd.activation diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py index a473bb6e973de453f47c7bc16c2e83b8e7fe86df..27afe72c0db84f506aab3ba544e71937703b2a4c 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -1,18 +1,10 @@ -from copy import deepcopy from typing import Dict, List import torch from torch.fx.node import Node from colossalai._analyzer.fx.node_util import MetaInfo -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - OperationData, - OperationDataType, - TrainCycleItem, -) -from colossalai.device.device_mesh import DeviceMesh +from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType from colossalai.tensor.comm_spec import CommSpec from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec @@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec) -def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, - user_node_index: int): +def runtime_apply_for_iterable_object( + node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int +): """ This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list is converted into the user node expected form. """ rst = [] - for index, (origin_sharding_spec, - target_sharding_spec) in enumerate(zip(origin_dict[node_index], - input_dict[node_index][user_node_index])): + for index, (origin_sharding_spec, target_sharding_spec) in enumerate( + zip(origin_dict[node_index], input_dict[node_index][user_node_index]) + ): rst.append( - shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec, - target_sharding_spec)) + shape_consistency_manager.apply_for_autoparallel_runtime( + node[index], origin_sharding_spec, target_sharding_spec + ) + ) rst = type(node)(rst) return rst @@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_ if isinstance(comm_action.comm_spec, CommSpec): rst = comm_action.comm_spec.covert_spec_to_action(tensor) else: - origin_sharding_spec = comm_action.comm_spec['src_spec'] - tgt_sharding_spec = comm_action.comm_spec['tgt_spec'] + origin_sharding_spec = comm_action.comm_spec["src_spec"] + tgt_sharding_spec = comm_action.comm_spec["tgt_spec"] rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec) return rst @@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]): node_to_index_dict = {} index = 0 for node in nodes: - if node.target == 'sharding_spec_convert_dict': + if node.target == "sharding_spec_convert_dict": input_dict_node = node continue - if node.target == 'origin_node_sharding_spec_dict': + if node.target == "origin_node_sharding_spec_dict": origin_dict_node = node continue - if node.target == 'comm_actions_dict': + if node.target == "comm_actions_dict": comm_actions_dict_node = node continue - if not hasattr(node, 'best_strategy'): + if not hasattr(node, "best_strategy"): continue node_to_index_dict[node] = index index += 1 @@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes) for node in nodes: - if not hasattr(node, 'best_strategy') or node.op == 'output': + if not hasattr(node, "best_strategy") or node.op == "output": continue for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes): if isinstance(node.sharding_spec, (list, tuple)): assert isinstance( - node.target_sharding_specs, - (list, - tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list' + node.target_sharding_specs, (list, tuple) + ), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list" total_difference = 0 - for sharding_spec, target_sharding_spec in zip(node.sharding_spec, - node.target_sharding_specs[user_node_index]): + for sharding_spec, target_sharding_spec in zip( + node.sharding_spec, node.target_sharding_specs[user_node_index] + ): total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec) if total_difference == 0: continue with mod_graph.inserting_before(user_node): - shape_consistency_node = mod_graph.create_node('call_function', - runtime_apply_for_iterable_object, - args=(node, origin_dict_node, input_dict_node, - node_to_index_dict[node], user_node_index)) + shape_consistency_node = mod_graph.create_node( + "call_function", + runtime_apply_for_iterable_object, + args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index), + ) else: - assert isinstance(node.sharding_spec, - ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.' + assert isinstance( + node.sharding_spec, ShardingSpec + ), "node.sharding_spec should be type of ShardingSpec, tuple or list." if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0: continue with mod_graph.inserting_before(user_node): - shape_consistency_node = mod_graph.create_node('call_function', - runtime_apply, - args=(node, origin_dict_node, input_dict_node, - node_to_index_dict[node], user_node_index)) - if hasattr(user_node.meta['info'], 'activation_checkpoint'): - MetaInfo(shape_consistency_node, - mod_dir=user_node.meta['info'].mod_dir, - activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint)) + shape_consistency_node = mod_graph.create_node( + "call_function", + runtime_apply, + args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index), + ) + if hasattr(user_node.meta["info"], "activation_checkpoint"): + MetaInfo( + shape_consistency_node, + mod_dir=user_node.meta["info"].mod_dir, + activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint), + ) new_args = list(user_node.args) new_kwargs = dict(user_node.kwargs) # the origin node may be a positional argument or key word argument of user node @@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): _, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes) for node in nodes: - if not hasattr(node, 'best_strategy') or node.op == 'output': + if not hasattr(node, "best_strategy") or node.op == "output": continue comm_actions = node.best_strategy.communication_actions for op_data, comm_action in comm_actions.items(): - if comm_action.comm_type == CommType.HOOK: continue if comm_action.comm_type == CommType.BEFORE: @@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): else: comm_object = node.args[comm_action.arg_index] with mod_graph.inserting_before(node): - comm_spec_apply_node = mod_graph.create_node('call_function', - runtime_comm_spec_apply, - args=(comm_object, comm_actions_dict_node, - node_to_index_dict[node], op_data.name)) + comm_spec_apply_node = mod_graph.create_node( + "call_function", + runtime_comm_spec_apply, + args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name), + ) # the origin node may be a positional argument or key word argument of user node if comm_action.key_for_kwarg is not None: # substitute the origin node with comm_spec_apply_node @@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): elif comm_action.comm_type == CommType.AFTER: with mod_graph.inserting_after(node): - comm_spec_apply_node = mod_graph.create_node('call_function', - runtime_comm_spec_apply, - args=(node, comm_actions_dict_node, - node_to_index_dict[node], op_data.name)) + comm_spec_apply_node = mod_graph.create_node( + "call_function", + runtime_comm_spec_apply, + args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name), + ) user_list = list(node.users.keys()) for user in user_list: if user == comm_spec_apply_node: @@ -211,15 +212,17 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): # substitute the origin node with comm_spec_apply_node new_kwargs[str(node)] = comm_spec_apply_node user.kwargs = new_kwargs - if hasattr(node.meta['info'], 'activation_checkpoint'): - MetaInfo(comm_spec_apply_node, - mod_dir=node.meta['info'].mod_dir, - activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) + if hasattr(node.meta["info"], "activation_checkpoint"): + MetaInfo( + comm_spec_apply_node, + mod_dir=node.meta["info"].mod_dir, + activation_checkpoint=tuple(node.meta["info"].activation_checkpoint), + ) return gm -def _act_annotataion_pass(gm: torch.fx.GraphModule): +def _act_annotation_pass(gm: torch.fx.GraphModule): """ This pass is used to add the act annotation to the new inserted nodes. """ @@ -227,21 +230,21 @@ def _act_annotataion_pass(gm: torch.fx.GraphModule): nodes = tuple(mod_graph.nodes) for node in nodes: - if not hasattr(node.meta, 'activation_checkpoint'): - from .runtime_preparation_pass import size_processing + if not hasattr(node.meta, "activation_checkpoint"): + pass user_act_annotation = -1 input_act_annotation = -1 for user_node in node.users.keys(): - if 'activation_checkpoint' in user_node.meta: - user_act_annotation = user_node.meta['activation_checkpoint'] + if "activation_checkpoint" in user_node.meta: + user_act_annotation = user_node.meta["activation_checkpoint"] break for input_node in node._input_nodes.keys(): - if 'activation_checkpoint' in input_node.meta: - input_act_annotation = input_node.meta['activation_checkpoint'] + if "activation_checkpoint" in input_node.meta: + input_act_annotation = input_node.meta["activation_checkpoint"] break if user_act_annotation == input_act_annotation and user_act_annotation != -1: - node.meta['activation_checkpoint'] = user_act_annotation + node.meta["activation_checkpoint"] = user_act_annotation return gm diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 08af846b221db60b3950a0cf285238f616b17711..65c3d8e0cbeb0cca8b0100d407aba7d3b3443333 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -1,19 +1,12 @@ import operator -from copy import deepcopy from typing import Dict, List, Union import torch -from torch.fx import symbolic_trace from torch.fx.node import Node from colossalai._analyzer.fx.node_util import MetaInfo from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - OperationDataType, - ShardingStrategy, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.comm_spec import _all_reduce @@ -25,11 +18,13 @@ from .constants import SHAPE_ARGUMENT_OPS shape_consistency_manager = ShapeConsistencyManager() -def size_processing(size: Union[int, torch.Size], - dim_partition_dict: Dict[int, List[int]], - device_mesh_info: Dict[int, int], - target_dim: int = None, - node_name: str = None): +def size_processing( + size: Union[int, torch.Size], + dim_partition_dict: Dict[int, List[int]], + device_mesh_info: Dict[int, int], + target_dim: int = None, + node_name: str = None, +): """ This method will be invoked during runtime to convert size node value depending on distributed information. """ @@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size], return size -def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], - strategies_constructor: StrategiesConstructor): +def solution_annotation_pass( + gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor +): """ This method is used to stick the solution strategy to the nodes and add the information required in runtime into graph as placeholder nodes. @@ -70,14 +66,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): strategies_vector = node.strategies_vector # stick the solution strategy to the corresponding node - setattr(node, 'best_strategy', strategies_vector[strategy_index]) - setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))) + setattr(node, "best_strategy", strategies_vector[strategy_index]) + setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))) origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name( - str(node)) + str(node) + ) # attach the corresponding metainfo if node has the attribute `strategies_info` - if hasattr(node, 'strategies_info'): - setattr(node, 'best_strategy_info', node.strategies_info[strategy_index]) + if hasattr(node, "strategies_info"): + setattr(node, "best_strategy_info", node.strategies_info[strategy_index]) # the dict to get input sharding specs of user node sharding_spec_convert_dict = {} @@ -92,15 +89,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name)) target_sharding_specs.append(target_sharding_spec) sharding_spec_convert_dict[index] = target_sharding_specs - setattr(node, 'target_sharding_specs', target_sharding_specs) + setattr(node, "target_sharding_specs", target_sharding_specs) # the get_attr node strategy is kind of pending strategy, which means we will change it # to the same strategy of the user node. - if node.op == 'get_attr': - assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.' + if node.op == "get_attr": + assert len(target_sharding_specs) == 1, f"sharing weight is not supported in current version." target_node = node.strategies_vector.successor_nodes[0] node_name = str(node) - if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP: + if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP: node_name = str(target_node) target_node = target_node.strategies_vector.successor_nodes[0] user_strategy = target_node.best_strategy @@ -122,11 +119,11 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], # add above dicts into graph for node in nodes: - if node.op != 'placeholder': + if node.op != "placeholder": with mod_graph.inserting_before(node): - input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') - origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') - comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict') + input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict") + origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict") + comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict") break return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict @@ -144,11 +141,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh # DeviceMesh information instructs the scaling of the size value device_mesh_info = {} - for dim, dim_size in enumerate(device_mesh.mesh_shape): + for dim, dim_size in enumerate(device_mesh.shape): device_mesh_info[dim] = dim_size def _extract_target_dim(node): - ''' + """ A helper function to extract the target dimension from size node. There are two usages of torch.Tensor.size: 1. tensor.size() @@ -156,7 +153,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh If a target_dim is assigned, then the output will be in type of int, instead of torch.Size. Otherwise, the output will be in type of torch.Size and this function will return None. - ''' + """ target_dim = None if len(node.args) > 1: target_dim = node.args[1] @@ -165,19 +162,21 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh return target_dim def _post_processing(node, size_processing_node): - ''' + """ This function is used to process the dependency between the size node and its users after inserting the size_process_node. - ''' - # store original node and processing node pair in node_pairs dictioanry + """ + # store original node and processing node pair in node_pairs dictionary # It will be used to replace the original node with processing node in slice object node_pairs[node] = size_processing_node size_processing_node._meta_data = node._meta_data - if hasattr(node.meta['info'], 'activation_checkpoint'): - MetaInfo(size_processing_node, - mod_dir=node.meta['info'].mod_dir, - activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) + if hasattr(node.meta["info"], "activation_checkpoint"): + MetaInfo( + size_processing_node, + mod_dir=node.meta["info"].mod_dir, + activation_checkpoint=tuple(node.meta["info"].activation_checkpoint), + ) user_list = list(node.users.keys()) for user in user_list: @@ -196,10 +195,10 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh user.kwargs = new_kwargs def _update_slice_object_args(slice_object): - ''' + """ This function is used to update the slice object argument list. If the slice object contains the Node argument, then the size node will be replaced with - ''' + """ if isinstance(slice_object, slice): start = slice_object.start stop = slice_object.stop @@ -220,8 +219,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}") for node in nodes: - - if node.op == 'call_method' and node.target == 'size': + if node.op == "call_method" and node.target == "size": # extract useful information from size node # dim_partition_dict will instruct the size value on which # dimension should be enlarged. @@ -232,14 +230,14 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh # insert size_processing node with mod_graph.inserting_after(node): - size_processing_node = mod_graph.create_node('call_function', - size_processing, - args=(node, dim_partition_dict, device_mesh_info, - target_dim, node.name)) + size_processing_node = mod_graph.create_node( + "call_function", + size_processing, + args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name), + ) _post_processing(node, size_processing_node) - if node.op == 'call_function' and node.target == operator.getitem: - + if node.op == "call_function" and node.target == operator.getitem: getitem_index = node.args[1] # slice object is quite special in torch.fx graph, # On one side, we treat slice object same as type of int, @@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) nodes = tuple(mod_graph.nodes) def _extract_info_from_sharding_spec(sharding_spec): - ''' + """ This function is used to extract the dim_partition_dict and device_mesh from sharding spec instance or a list of sharding spec. - ''' + """ if isinstance(sharding_spec, ShardingSpec): dim_partition_dict = sharding_spec.dim_partition_dict device_mesh = sharding_spec.device_mesh return dim_partition_dict, device_mesh if sharding_spec is None: return None, None - assert isinstance(sharding_spec, - (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None' + assert isinstance( + sharding_spec, (tuple, list) + ), "sharding_spec should be type of ShardingSpec, tuple, list or None" device_mesh = sharding_spec[0].device_mesh dim_partition_dict = [] @@ -322,8 +321,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) else: new_args.append(arg) else: - assert isinstance(arg, - (int, tuple, list)), 'The argument in view node should be either type of Node or int.' + assert isinstance( + arg, (int, tuple, list) + ), "The argument in view node should be either type of Node or int." if isinstance(arg, (tuple, list)): new_args.extend(arg) else: @@ -332,7 +332,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node): new_args = _process_node_arguments(node) - if node.op == 'call_method': + if node.op == "call_method": args_to_process = list(new_args[1:]) else: args_to_process = list(new_args) @@ -350,7 +350,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) args_to_process = tuple(args_to_process) - if node.op == 'call_method': + if node.op == "call_method": new_args = (new_args[0],) + args_to_process else: new_args = args_to_process @@ -358,9 +358,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) node.args = new_args def _filter_node_with_shape_args(node): - if node.op == 'call_method': + if node.op == "call_method": target = getattr(node.args[0]._meta_data.__class__, node.target) - elif node.op == 'call_function': + elif node.op == "call_function": target = node.target else: target = None @@ -371,7 +371,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) for node in nodes: # skip the placeholder node added in _solution_annotation pass - if not hasattr(node, 'sharding_spec'): + if not hasattr(node, "sharding_spec"): continue output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec) @@ -388,19 +388,25 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes """ mod_graph = gm.graph nodes = tuple(mod_graph.nodes) - # This stream is created for overlaping the communication and computation. + # This stream is created for overlapping the communication and computation. reduction_stream = torch.cuda.Stream() def _add_hook_for_grad_communication(node, param, name=None): - comm_actions = node.best_strategy.communication_actions def _filter_param_to_hook(node, op_data, comm_action, name): - - if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK: + if ( + node.op == "call_module" + and op_data.type == OperationDataType.PARAM + and op_data.name == name + and comm_action.comm_type == CommType.HOOK + ): return True - if node.op == 'get_attr' and isinstance( - node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: + if ( + node.op == "get_attr" + and isinstance(node._meta_data, torch.nn.parameter.Parameter) + and comm_action.comm_type == CommType.HOOK + ): return True return False @@ -410,7 +416,6 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes if _filter_param_to_hook(node, operation_data, comm_action, name=name): def wrapper(param, comm_spec, stream, overlap): - def hook_fn(grad): if overlap: with torch.cuda.stream(stream): @@ -426,22 +431,26 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes # apply the sharding spec of parameters if target_sharding_spec.dim_partition_dict != {}: origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) - setattr(param, 'sharding_spec', origin_sharding_spec) + setattr(param, "sharding_spec", origin_sharding_spec) # TODO: build a ColoParameter class to manager the distributed parameters # we could use .data here, because all the operations just happen before the real training # loop, so we don't need to track these operations in the autograd graph. param = torch.nn.Parameter( - shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, - target_sharding_spec).detach().clone()) + shape_consistency_manager.apply_for_autoparallel_runtime( + param.data, param.sharding_spec, target_sharding_spec + ) + .detach() + .clone() + ) return param for node in nodes: - if node.op == 'call_module': + if node.op == "call_module": target_module = node.graph.owning_module.get_submodule(node.target) # TODO: we need to do more actions to take care of the shared parameters. - if hasattr(target_module, 'processed') and target_module.processed: + if hasattr(target_module, "processed") and target_module.processed: continue - setattr(target_module, 'processed', True) + setattr(target_module, "processed", True) for name, param in target_module.named_parameters(): target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) param = _shard_param(param, target_sharding_spec) @@ -453,7 +462,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes # apply the sharding spec of buffers for name, buffer in target_module.named_buffers(): origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {}) - setattr(buffer, 'sharding_spec', origin_sharding_spec) + setattr(buffer, "sharding_spec", origin_sharding_spec) target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec) sharded_buffer_dict[name] = buffer_sharded @@ -461,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes for name, buffer_sharded in sharded_buffer_dict.items(): setattr(target_module, name, buffer_sharded.detach().clone()) - if node.op == 'get_attr': + if node.op == "get_attr": root = node.graph.owning_module atoms = node.target.split(".") attr_len = len(atoms) @@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule): """ replace the origin kernel into kernel with implicit communication inside. """ - pass -def runtime_preparation_pass(gm: torch.fx.GraphModule, - solution: List[int], - device_mesh: DeviceMesh, - strategies_constructor: StrategiesConstructor, - overlap=False): - gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass( - gm, solution, strategies_constructor) +def runtime_preparation_pass( + gm: torch.fx.GraphModule, + solution: List[int], + device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor, + overlap=False, +): + gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass( + gm, solution, strategies_constructor + ) gm = size_value_converting_pass(gm, device_mesh) gm = node_args_converting_pass(gm, device_mesh) # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed. diff --git a/colossalai/auto_parallel/tensor_shard/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py index 99c1249340602daee1a1314f102bc600eae6667d..e9c2c8664a61d04a25dd76da398c900a863bd828 100644 --- a/colossalai/auto_parallel/tensor_shard/constants.py +++ b/colossalai/auto_parallel/tensor_shard/constants.py @@ -3,9 +3,22 @@ import operator import torch __all__ = [ - 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', - 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP', - 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST' + "ELEMENTWISE_MODULE_OP", + "ELEMENTWISE_FUNC_OP", + "RESHAPE_FUNC_OP", + "CONV_MODULE_OP", + "CONV_FUNC_OP", + "LINEAR_MODULE_OP", + "LINEAR_FUNC_OP", + "BATCHNORM_MODULE_OP", + "POOL_MODULE_OP", + "NON_PARAM_FUNC_OP", + "BCAST_FUNC_OP", + "EMBEDDING_MODULE_OP", + "LAYERNORM_MODULE_OP", + "ELEMENTWISE_METHOD_OP", + "RESHAPE_METHOD_OP", + "INFINITY_COST", ] ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] @@ -18,13 +31,13 @@ ELEMENTWISE_FUNC_OP = [ torch.nn.functional.relu, torch.nn.functional.dropout, # softmax should not be here - torch.nn.functional.softmax + torch.nn.functional.softmax, ] ELEMENTWISE_METHOD_OP = [ torch.Tensor.to, torch.Tensor.type, # TODO: contiguous maybe need some extra processes. - torch.Tensor.contiguous + torch.Tensor.contiguous, ] RESHAPE_FUNC_OP = [ torch.flatten, @@ -42,15 +55,36 @@ RESHAPE_METHOD_OP = [ torch.Tensor.transpose, ] BCAST_FUNC_OP = [ - torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub, - operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow + torch.add, + torch.sub, + torch.mul, + torch.div, + torch.floor_divide, + torch.true_divide, + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.truediv, + torch.matmul, + operator.pow, + torch.pow, ] CONV_MODULE_OP = [ - torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, - torch.nn.ConvTranspose3d + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, ] CONV_FUNC_OP = [ - torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d + torch.conv1d, + torch.conv2d, + torch.conv3d, + torch.conv_transpose1d, + torch.conv_transpose2d, + torch.conv_transpose3d, ] EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding] LINEAR_MODULE_OP = [torch.nn.Linear] @@ -85,7 +119,7 @@ NON_PARAM_FUNC_OP = [ operator.floordiv, operator.truediv, # softmax should not be here - torch.nn.functional.softmax + torch.nn.functional.softmax, ] INFINITY_COST = 1e13 diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py index b406ca6fb7e0fd28a9a6d3e98365b093f73f7171..d82f0ef53f6605ccc3c0b8537dfe49d5526e6cc6 100644 --- a/colossalai/auto_parallel/tensor_shard/initialize.py +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -3,7 +3,6 @@ from typing import Dict, List, Tuple import torch import torch.distributed as dist import torch.nn as nn -from torch.fx import GraphModule from torch.fx.graph import Graph from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen @@ -14,27 +13,32 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pas from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction -from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.sharding_spec import ShardingSpec class ModuleWrapper(nn.Module): - ''' + """ This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict into the forward function. - ''' - - def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]], - origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]): - ''' + """ + + def __init__( + self, + module: ColoGraphModule, + sharding_spec_dict: Dict[int, List[ShardingSpec]], + origin_spec_dict: Dict[int, ShardingSpec], + comm_actions_dict: Dict[int, Dict[str, CommAction]], + ): + """ Args: module: the original module sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node. origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor. comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor. - ''' + """ super(ModuleWrapper, self).__init__() self.module = module self.sharding_spec_dict = sharding_spec_dict @@ -42,67 +46,68 @@ class ModuleWrapper(nn.Module): self.comm_actions_dict = comm_actions_dict def forward(self, *args, **kwargs): - return self.module(*args, - sharding_spec_convert_dict=self.sharding_spec_dict, - origin_node_sharding_spec_dict=self.origin_spec_dict, - comm_actions_dict=self.comm_actions_dict, - **kwargs) + return self.module( + *args, + sharding_spec_convert_dict=self.sharding_spec_dict, + origin_node_sharding_spec_dict=self.origin_spec_dict, + comm_actions_dict=self.comm_actions_dict, + **kwargs, + ) def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable): - ''' + """ This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func. - ''' + """ # TODO: implement this function - pass def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]): - ''' + """ This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape from the alpha_beta_dict. These two values will be used to estimate the communication cost. - ''' + """ # TODO: implement this function - pass -def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, - shard_option: str): - ''' +def build_strategy_constructor( + graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, shard_option: str +): + """ This method is used to build the strategy_constructor for the given graph. After this method, each node in the graph will have a strategies_vector which is constructed by the related node handler. - ''' - if solver_preference == 'standard': + """ + if solver_preference == "standard": solver_preference = SolverPerference.STANDARD - elif solver_preference == 'tp': + elif solver_preference == "tp": solver_preference = SolverPerference.TP - elif solver_preference == 'dp': + elif solver_preference == "dp": solver_preference = SolverPerference.DP else: - raise ValueError(f'Invalid solver_preference: {solver_preference}') + raise ValueError(f"Invalid solver_preference: {solver_preference}") - if dataloader_option == 'replicated': + if dataloader_option == "replicated": dataloader_option = DataloaderOption.REPLICATED - elif dataloader_option == 'distributed': + elif dataloader_option == "distributed": dataloader_option = DataloaderOption.DISTRIBUTED else: - raise ValueError(f'Invalid dataloader_option: {dataloader_option}') + raise ValueError(f"Invalid dataloader_option: {dataloader_option}") - if shard_option == 'standard': + if shard_option == "standard": shard_option = ShardOption.STANDARD - elif shard_option == 'shard': + elif shard_option == "shard": shard_option = ShardOption.SHARD - elif shard_option == 'shard_last_axis': + elif shard_option == "shard_last_axis": shard_option = ShardOption.SHARD_LAST_AXIS - elif shard_option == 'full_shard': + elif shard_option == "full_shard": shard_option = ShardOption.FULL_SHARD else: - raise ValueError(f'Invalid shard_option: {shard_option}') + raise ValueError(f"Invalid shard_option: {shard_option}") - solver_options = SolverOptions(solver_perference=solver_preference, - dataloader_option=dataloader_option, - shard_option=shard_option) + solver_options = SolverOptions( + solver_perference=solver_preference, dataloader_option=dataloader_option, shard_option=shard_option + ) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() @@ -110,10 +115,10 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_pre def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0): - ''' + """ This method is used to solve the best solution for the given graph. The solution is a list of integers, each integer represents the best strategy index of the corresponding node. - ''' + """ # temporarily we use all nodes as liveness list, we count the backward memory cost together with # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase. # graph_analyser = GraphAnalyser(gm) @@ -127,23 +132,23 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc return solution -def transform_to_sharded_model(gm: ColoGraphModule, - meta_args: Dict, - solution: List[int], - device_mesh: DeviceMesh, - strategies_constructor: StrategiesConstructor, - overlap: bool = False): - ''' +def transform_to_sharded_model( + gm: ColoGraphModule, + meta_args: Dict, + solution: List[int], + device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor, + overlap: bool = False, +): + """ This method is used to transform the original graph to the sharded graph. The model parameters will be sharded according to the solution and the grad hooks will be added to the sharded graph using the runtime_preparation_pass. The communication node will be added into the graph using the runtime_apply_pass. - ''' - gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, - solution, - device_mesh, - strategies_constructor, - overlap=overlap) + """ + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh, strategies_constructor, overlap=overlap + ) gm = runtime_apply_pass(gm) shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict) gm.recompile() @@ -152,12 +157,14 @@ def transform_to_sharded_model(gm: ColoGraphModule, return gm, sharding_spec_dicts -def initialize_device_mesh(world_size: int = -1, - physical_devices: List[int] = None, - alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, - logical_mesh_shape: Tuple[int] = None, - logical_mesh_id: torch.Tensor = None): - ''' +def initialize_device_mesh( + world_size: int = -1, + physical_devices: List[int] = None, + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, + logical_mesh_shape: Tuple[int] = None, + logical_mesh_id: torch.Tensor = None, +): + """ This method is used to initialize the device mesh. Args: @@ -170,7 +177,7 @@ def initialize_device_mesh(world_size: int = -1, logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical mesh shape. logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id. - ''' + """ # if world_size is not set, use the world size from torch.distributed if world_size == -1: world_size = dist.get_world_size() @@ -201,27 +208,31 @@ def initialize_device_mesh(world_size: int = -1, # extract alpha and beta values for the chosen logical mesh shape mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id) - device_mesh = DeviceMesh(physical_mesh_id=physical_mesh, - logical_mesh_id=logical_mesh_id, - mesh_alpha=mesh_alpha, - mesh_beta=mesh_beta, - init_process_group=True) + device_mesh = DeviceMesh( + physical_mesh_id=physical_mesh, + logical_mesh_id=logical_mesh_id, + mesh_alpha=mesh_alpha, + mesh_beta=mesh_beta, + init_process_group=True, + ) return device_mesh -def initialize_model(model: nn.Module, - meta_args: Dict[str, torch.Tensor], - device_mesh: DeviceMesh, - memory_budget: float = -1.0, - overlap: bool = False, - solver_preference: str = 'standard', - dataloader_option: str = 'replicated', - shard_option: str = 'standard', - save_solver_solution: bool = False, - load_solver_solution: bool = False, - solution_path: str = None, - return_solution: bool = False): - ''' +def initialize_model( + model: nn.Module, + meta_args: Dict[str, torch.Tensor], + device_mesh: DeviceMesh, + memory_budget: float = -1.0, + overlap: bool = False, + solver_preference: str = "standard", + dataloader_option: str = "replicated", + shard_option: str = "standard", + save_solver_solution: bool = False, + load_solver_solution: bool = False, + solution_path: str = None, + return_solution: bool = False, +): + """ This method is used to initialize the sharded model which could be used as normal pytorch model. Args: @@ -246,7 +257,7 @@ def initialize_model(model: nn.Module, return_solution(optional): if the return_solution is True, the solution will be returned. The returned solution will be used to debug or help to analyze the sharding result. Therefore, we will not just return a series of integers, but return the best strategies. - ''' + """ tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_args) @@ -256,11 +267,13 @@ def initialize_model(model: nn.Module, shape_prop_pass(gm, *meta_args.values()) gm.recompile() - strategies_constructor = build_strategy_constructor(graph, - device_mesh, - solver_preference=solver_preference, - dataloader_option=dataloader_option, - shard_option=shard_option) + strategies_constructor = build_strategy_constructor( + graph, + device_mesh, + solver_preference=solver_preference, + dataloader_option=dataloader_option, + shard_option=shard_option, + ) if load_solver_solution: solution = torch.load(solution_path) else: @@ -268,8 +281,9 @@ def initialize_model(model: nn.Module, if save_solver_solution: torch.save(solution, solution_path) - gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor, - overlap) + gm, sharding_spec_dicts = transform_to_sharded_model( + gm, meta_args, solution, device_mesh, strategies_constructor, overlap + ) model_to_return = ModuleWrapper(gm, *sharding_spec_dicts) @@ -277,28 +291,30 @@ def initialize_model(model: nn.Module, solution_to_return = [] nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] for index, node in enumerate(nodes): - solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}') + solution_to_return.append(f"{node.name} {node.strategies_vector[solution[index]].name}") return model_to_return, solution_to_return else: return model_to_return -def autoparallelize(model: nn.Module, - meta_args: Dict[str, torch.Tensor] = None, - data_loader: torch.utils.data.DataLoader = None, - data_process_func: callable = None, - alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, - logical_mesh_shape: Tuple[int] = None, - logical_mesh_id: torch.Tensor = None, - solver_preference: str = 'standard', - dataloader_option: str = 'replicated', - shard_option: str = 'standard', - save_solver_solution: bool = False, - load_solver_solution: bool = False, - solver_solution_path: str = None, - return_solution: bool = False, - memory_budget: float = -1.0): - ''' +def autoparallelize( + model: nn.Module, + meta_args: Dict[str, torch.Tensor] = None, + data_loader: torch.utils.data.DataLoader = None, + data_process_func: callable = None, + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, + logical_mesh_shape: Tuple[int] = None, + logical_mesh_id: torch.Tensor = None, + solver_preference: str = "standard", + dataloader_option: str = "replicated", + shard_option: str = "standard", + save_solver_solution: bool = False, + load_solver_solution: bool = False, + solver_solution_path: str = None, + return_solution: bool = False, + memory_budget: float = -1.0, +): + """ This method is used to initialize the device mesh, extract the meta_args, and use them to create a sharded model. @@ -329,24 +345,26 @@ def autoparallelize(model: nn.Module, return_solution(optional): if the return_solution is True, the solution will be returned. memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0, the memory budget will be infinity. - ''' - device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, - logical_mesh_shape=logical_mesh_shape, - logical_mesh_id=logical_mesh_id) + """ + device_mesh = initialize_device_mesh( + alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape, logical_mesh_id=logical_mesh_id + ) if meta_args is None: meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func) - rst_to_unpack = initialize_model(model, - meta_args, - device_mesh, - solver_preference=solver_preference, - dataloader_option=dataloader_option, - shard_option=shard_option, - save_solver_solution=save_solver_solution, - load_solver_solution=load_solver_solution, - solution_path=solver_solution_path, - return_solution=return_solution, - memory_budget=memory_budget) + rst_to_unpack = initialize_model( + model, + meta_args, + device_mesh, + solver_preference=solver_preference, + dataloader_option=dataloader_option, + shard_option=shard_option, + save_solver_solution=save_solver_solution, + load_solver_solution=load_solver_solution, + solution_path=solver_solution_path, + return_solution=return_solution, + memory_budget=memory_budget, + ) if return_solution: model, solution = rst_to_unpack diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index 9903ca54e52cb70559cce2c68169c84ca08bef9c..aa2e5e9c40c0aa9dbe0b360f6907744ed348fa31 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -25,11 +25,33 @@ from .view_handler import ViewHandler from .where_handler import WhereHandler __all__ = [ - 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', - 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', - 'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler', - 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', - 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', - 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler', - 'SplitHandler' + "LinearFunctionHandler", + "LinearModuleHandler", + "BMMFunctionHandler", + "AddBMMFunctionHandler", + "LayerNormModuleHandler", + "BatchNormModuleHandler", + "ConvModuleHandler", + "ConvFunctionHandler", + "UnaryElementwiseHandler", + "DefaultReshapeHandler", + "PlaceholderHandler", + "OutputHandler", + "WhereHandler", + "NormPoolingHandler", + "BinaryElementwiseHandler", + "MatMulHandler", + "operator_registry", + "ADDMMFunctionHandler", + "GetItemHandler", + "GetattrHandler", + "ViewHandler", + "PermuteHandler", + "TensorConstructorHandler", + "EmbeddingModuleHandler", + "EmbeddingFunctionHandler", + "SumHandler", + "SoftmaxHandler", + "TransposeHandler", + "SplitHandler", ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py index da0d199c5e05b37340cb4ddcfee0b52a9102fadf..47c654d6aa436fb145420aaa666e8bba2eeaeaeb 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py @@ -2,15 +2,13 @@ from typing import Dict, List, Union import torch -from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager - -from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator -__all__ = ['ADDMMFunctionHandler'] +__all__ = ["ADDMMFunctionHandler"] @operator_registry.register(torch.addmm) @@ -30,25 +28,26 @@ class ADDMMFunctionHandler(NodeHandler): return data_type def get_operation_data_mapping(self) -> Dict[str, OperationData]: - # input operand input_data = self.node.args[1]._meta_data - physical_input_operand = OperationData(name=str(self.node.args[1]), - type=self._infer_op_data_type(input_data), - data=input_data) + physical_input_operand = OperationData( + name=str(self.node.args[1]), type=self._infer_op_data_type(input_data), data=input_data + ) # other operand other_data = self.node.args[2]._meta_data - physical_other_operand = OperationData(name=str(self.node.args[2]), - type=self._infer_op_data_type(other_data), - data=other_data) + physical_other_operand = OperationData( + name=str(self.node.args[2]), type=self._infer_op_data_type(other_data), data=other_data + ) # bias physical shape bias_logical_shape = self.node._meta_data.shape bias_data = self.node.args[0]._meta_data - physical_bias_operand = OperationData(name=str(self.node.args[0]), - type=self._infer_op_data_type(bias_data), - data=bias_data, - logical_shape=bias_logical_shape) + physical_bias_operand = OperationData( + name=str(self.node.args[0]), + type=self._infer_op_data_type(bias_data), + data=bias_data, + logical_shape=bias_logical_shape, + ) # output physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) @@ -57,7 +56,7 @@ class ADDMMFunctionHandler(NodeHandler): "input": physical_input_operand, "other": physical_other_operand, "output": physical_output, - 'bias': physical_bias_operand + "bias": physical_bias_operand, } return mapping @@ -66,26 +65,27 @@ class ADDMMFunctionHandler(NodeHandler): op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm')) + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="addmm") + ) return generators def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: # convert bias from its logical sharding spec to its physical sharding spec op_data_mapping = self.get_operation_data_mapping() - bias_op_data = op_data_mapping['bias'] + bias_op_data = op_data_mapping["bias"] bias_physical_shape = bias_op_data.data.shape bias_logical_shape = bias_op_data.logical_shape bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - bias_sharding_spec, bias_logical_shape, bias_physical_shape) + bias_sharding_spec, bias_logical_shape, bias_physical_shape + ) strategy.sharding_specs[bias_op_data] = bias_sharding_spec if len(removed_dims) > 0: - comm_action = comm_actions_for_oprands(node=self.node, - removed_dims=removed_dims, - op_data=bias_op_data, - sharding_spec=bias_sharding_spec) + comm_action = comm_actions_for_oprands( + node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec + ) strategy.communication_actions[bias_op_data] = comm_action return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py index cb1bb36b78796db3d5656213518376f8f365dce0..df4b1d6cef3fcc3473d692e208728e21f256588b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py @@ -2,12 +2,12 @@ from typing import Dict, List import torch -from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector -from .node_handler import MetaInfoModuleHandler, ModuleHandler +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import MetaInfoModuleHandler from .registry import operator_registry from .strategy import BatchNormStrategyGenerator, StrategyGenerator -__all__ = ['BatchNormModuleHandler'] +__all__ = ["BatchNormModuleHandler"] @operator_registry.register(torch.nn.BatchNorm1d) @@ -27,30 +27,37 @@ class BatchNormModuleHandler(MetaInfoModuleHandler): def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=self.named_parameters['weight'].shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) + physical_other_operand = OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=self.named_parameters["weight"].shape, + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) - physical_running_mean_operand = OperationData(name="running_mean", - type=OperationDataType.BUFFER, - data=self.named_buffers['running_mean'], - logical_shape=self.named_buffers['running_mean'].shape) + physical_running_mean_operand = OperationData( + name="running_mean", + type=OperationDataType.BUFFER, + data=self.named_buffers["running_mean"], + logical_shape=self.named_buffers["running_mean"].shape, + ) - physical_running_var_operand = OperationData(name="running_var", - type=OperationDataType.BUFFER, - data=self.named_buffers['running_var'], - logical_shape=self.named_buffers['running_var'].shape) + physical_running_var_operand = OperationData( + name="running_var", + type=OperationDataType.BUFFER, + data=self.named_buffers["running_var"], + logical_shape=self.named_buffers["running_var"].shape, + ) physical_num_batches_tracked_operand = OperationData( name="num_batches_tracked", type=OperationDataType.BUFFER, - data=self.named_buffers['num_batches_tracked'], - logical_shape=self.named_buffers['num_batches_tracked'].shape) + data=self.named_buffers["num_batches_tracked"], + logical_shape=self.named_buffers["num_batches_tracked"].shape, + ) mapping = { "input": physical_input_operand, @@ -58,12 +65,12 @@ class BatchNormModuleHandler(MetaInfoModuleHandler): "output": physical_output, "running_mean": physical_running_mean_operand, "running_var": physical_running_var_operand, - "num_batches_tracked": physical_num_batches_tracked_operand + "num_batches_tracked": physical_num_batches_tracked_operand, } - if self.named_parameters['bias'] is not None: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + if self.named_parameters["bias"] is not None: + physical_bias_operand = OperationData( + name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py index db8f0b54ddeeb1c5250951f0c9e8bfef364eb16d..f8c137348353f760aee338b6a94e5df97dd90699 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -4,15 +4,14 @@ import torch from torch.fx.node import Node from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy -from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager from ..constants import BCAST_FUNC_OP from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator -__all__ = ['BinaryElementwiseHandler'] +__all__ = ["BinaryElementwiseHandler"] @operator_registry.register(BCAST_FUNC_OP) @@ -38,7 +37,7 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler): # The meta_data of node type argument could also possibly be a non-tensor object. if not isinstance(meta_data, torch.Tensor): assert isinstance(meta_data, (int, float)) - meta_data = torch.Tensor([meta_data]).to('meta') + meta_data = torch.Tensor([meta_data]).to("meta") non_tensor = True else: @@ -46,7 +45,7 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler): # but we can deem it as meta data # as it won't affect the strategy generation assert isinstance(self.node.args[idx], (int, float)) - meta_data = torch.Tensor([self.node.args[idx]]).to('meta') + meta_data = torch.Tensor([self.node.args[idx]]).to("meta") non_tensor = True return meta_data, non_tensor @@ -58,24 +57,27 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler): # and filter the non-tensor op_data in post_process. self.non_tensor_list = [] # assert False - input_op_data = OperationData(name=str(self.node.args[0]), - type=_get_op_data_type(input_meta_data), - data=input_meta_data, - logical_shape=bcast_shape) - other_op_data = OperationData(name=str(self.node.args[1]), - type=_get_op_data_type(other_meta_data), - data=other_meta_data, - logical_shape=bcast_shape) - output_op_data = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_meta_data, - logical_shape=bcast_shape) + input_op_data = OperationData( + name=str(self.node.args[0]), + type=_get_op_data_type(input_meta_data), + data=input_meta_data, + logical_shape=bcast_shape, + ) + other_op_data = OperationData( + name=str(self.node.args[1]), + type=_get_op_data_type(other_meta_data), + data=other_meta_data, + logical_shape=bcast_shape, + ) + output_op_data = OperationData( + name=str(self.node), type=OperationDataType.OUTPUT, data=output_meta_data, logical_shape=bcast_shape + ) if non_tensor_input: self.non_tensor_list.append(input_op_data) if non_tensor_other: self.non_tensor_list.append(other_op_data) - mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data} return mapping def get_strategy_generator(self) -> List[StrategyGenerator]: @@ -100,14 +102,14 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler): logical_shape = op_data.logical_shape sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - sharding_spec, logical_shape, physical_shape) + sharding_spec, logical_shape, physical_shape + ) strategy.sharding_specs[op_data] = sharding_spec if len(removed_dims) > 0: - comm_action = comm_actions_for_oprands(node=self.node, - removed_dims=removed_dims, - op_data=op_data, - sharding_spec=sharding_spec) + comm_action = comm_actions_for_oprands( + node=self.node, removed_dims=removed_dims, op_data=op_data, sharding_spec=sharding_spec + ) strategy.communication_actions[op_data] = comm_action return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py index da2b733c9f7afda075c10dc3dd17a0d4f42fbc01..5c22ac7bef117be918a1cfda03ae562e48741e19 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py @@ -2,15 +2,13 @@ from typing import Dict, List, Union import torch -from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager - -from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator -__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler'] +__all__ = ["BMMFunctionHandler", "AddBMMFunctionHandler"] def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): @@ -19,14 +17,14 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): node handler to reduce code redundancy. """ # input operand - physical_input_operand = OperationData(name=str(node.args[input_idx]), - type=OperationDataType.ARG, - data=node.args[input_idx]._meta_data) + physical_input_operand = OperationData( + name=str(node.args[input_idx]), type=OperationDataType.ARG, data=node.args[input_idx]._meta_data + ) # other operand - physical_other_operand = OperationData(name=str(node.args[other_idx]), - type=OperationDataType.ARG, - data=node.args[other_idx]._meta_data) + physical_other_operand = OperationData( + name=str(node.args[other_idx]), type=OperationDataType.ARG, data=node.args[other_idx]._meta_data + ) # output physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data) @@ -35,11 +33,13 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): if bias_idx is not None: # bias physical shape bias_logical_shape = node._meta_data.shape - physical_bias_operand = OperationData(name=str(node.args[bias_idx]), - type=OperationDataType.ARG, - data=node.args[bias_idx]._meta_data, - logical_shape=bias_logical_shape) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name=str(node.args[bias_idx]), + type=OperationDataType.ARG, + data=node.args[bias_idx]._meta_data, + logical_shape=bias_logical_shape, + ) + mapping["bias"] = physical_bias_operand return mapping @@ -91,20 +91,20 @@ class AddBMMFunctionHandler(NodeHandler): # convert bias from its logical sharding spec to its physical sharding spec op_data_mapping = self.get_operation_data_mapping() - if 'bias' in op_data_mapping: - bias_op_data = op_data_mapping['bias'] + if "bias" in op_data_mapping: + bias_op_data = op_data_mapping["bias"] bias_physical_shape = bias_op_data.data.shape bias_logical_shape = bias_op_data.logical_shape bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - bias_sharding_spec, bias_logical_shape, bias_physical_shape) + bias_sharding_spec, bias_logical_shape, bias_physical_shape + ) strategy.sharding_specs[bias_op_data] = bias_sharding_spec if len(removed_dims) > 0: - comm_action = comm_actions_for_oprands(node=self.node, - removed_dims=removed_dims, - op_data=bias_op_data, - sharding_spec=bias_sharding_spec) + comm_action = comm_actions_for_oprands( + node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec + ) strategy.communication_actions[bias_op_data] = comm_action return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py index 272b1c85630a8ab15145e701740a44e20d5103b8..fd7c1f837a5a69e2bbf36330e6d0c65bbfdf059b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py @@ -3,13 +3,13 @@ from typing import Dict, List import torch import torch.nn.functional as F -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import transpose_partition_dim -from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler from .registry import operator_registry from .strategy import ConvStrategyGenerator, StrategyGenerator -__all__ = ['ConvModuleHandler', 'ConvFunctionHandler'] +__all__ = ["ConvModuleHandler", "ConvFunctionHandler"] @operator_registry.register(torch.nn.Conv1d) @@ -29,25 +29,29 @@ class ConvModuleHandler(MetaInfoModuleHandler): def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) logical_shape_for_weight = list(self.named_parameters["weight"].shape) - logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ - 1], logical_shape_for_weight[0] - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=torch.Size(logical_shape_for_weight)) + logical_shape_for_weight[0], logical_shape_for_weight[1] = ( + logical_shape_for_weight[1], + logical_shape_for_weight[0], + ) + physical_other_operand = OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=torch.Size(logical_shape_for_weight), + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} if "bias" in self.named_parameters: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy): @@ -77,9 +81,9 @@ class ConvFunctionHandler(MetaInfoNodeHandler): def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) # check if the other operand is a parameter if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): @@ -88,26 +92,30 @@ class ConvFunctionHandler(MetaInfoNodeHandler): data_type = OperationDataType.ARG logical_shape_for_weight = list(self.node.args[1]._meta_data.shape) - logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ - 1], logical_shape_for_weight[0] - physical_other_operand = OperationData(name=str(self.node.args[1]), - type=data_type, - data=self.node.args[1]._meta_data, - logical_shape=torch.Size(logical_shape_for_weight)) + logical_shape_for_weight[0], logical_shape_for_weight[1] = ( + logical_shape_for_weight[1], + logical_shape_for_weight[0], + ) + physical_other_operand = OperationData( + name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data, + logical_shape=torch.Size(logical_shape_for_weight), + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None: + if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None: # check if the other operand is a parameter if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): data_type = OperationDataType.PARAM else: data_type = OperationDataType.ARG - physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), - type=data_type, - data=self.node.kwargs["bias"]._meta_data) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy): diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py index 0c5b9f39e1fba75b44308d57569c3b8c0b5087c0..feb1032a6c0f2077f6a525523cde765b94bc8d1c 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py @@ -3,11 +3,11 @@ from typing import Dict, List import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import DefaultReshapeGenerator, StrategyGenerator -__all__ = ['DefaultReshapeHandler'] +__all__ = ["DefaultReshapeHandler"] @operator_registry.register(torch.flatten) @@ -54,17 +54,15 @@ class DefaultReshapeHandler(MetaInfoNodeHandler): input_data = self.node.args[0]._meta_data input_logical_shape = self.infer_logical_shape(input_data) - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=data_type, - data=input_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=data_type, data=input_data, logical_shape=input_logical_shape + ) output_data = self.node._meta_data output_logical_shape = self.infer_logical_shape(output_data) - physical_output = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_data, - logical_shape=output_logical_shape) + physical_output = OperationData( + name=str(self.node), type=OperationDataType.OUTPUT, data=output_data, logical_shape=output_logical_shape + ) mapping = {"input": physical_input_operand, "output": physical_output} diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py index e154105b672de30b5675ca56147fb6b68205a469..f29c3a0b7d5d28dfa9e96db36562702c25b613ea 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py @@ -12,11 +12,12 @@ from .node_handler import ModuleHandler, NodeHandler from .registry import operator_registry from .strategy import EmbeddingStrategyGenerator, StrategyGenerator -__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler'] +__all__ = ["EmbeddingModuleHandler", "EmbeddingFunctionHandler"] -def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str, - output_name: str) -> List[ShardingStrategy]: +def _convert_logical_sharding_to_physical_sharding_spec_for_embedding( + strategy: ShardingStrategy, input_name: str, output_name: str +) -> List[ShardingStrategy]: """ This function converts the logical sharding spec to the physical sharding spec for both the input and output of the embedding operation. @@ -56,27 +57,31 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) try: # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={0: i}, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, + dim_mapping={0: i}, + physical_shape=input_op_data.data.shape, + inplace=True, + ) if last_logical_output_dims in output_sharding_spec.dim_partition_dict: dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims} else: dim_mapping = {0: i} - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) - strategy_copy.name = f'{strategy.name}_{i}' + strategy_copy.name = f"{strategy.name}_{i}" sharding_strategies.append(strategy_copy) except ShardingNotDivisibleError as e: logger.debug( - f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' + f"Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}" ) else: # the generated sharding strategy does not shard the non-matrix dimension, @@ -87,20 +92,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) # after updating, the logical shape will be replaced by the physical shape - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={}, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, dim_mapping={}, physical_shape=input_op_data.data.shape, inplace=True + ) if last_logical_output_dims in output_sharding_spec.dim_partition_dict: dim_mapping = {last_logical_output_dims: last_physical_output_dims} else: dim_mapping = {} - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) sharding_strategies.append(strategy_copy) return sharding_strategies @@ -125,14 +131,16 @@ class EmbeddingModuleHandler(ModuleHandler): # Finally, the input will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=input_meta_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=input_meta_data, + logical_shape=input_logical_shape, + ) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight']) + physical_other_operand = OperationData( + name="weight", type=OperationDataType.PARAM, data=self.named_parameters["weight"] + ) # Same as input, in nn.Embedding operation, all the dimensions of output will be treated as # (batch dimension, embedding dimension), and then the sharding spec will be generated based @@ -141,10 +149,12 @@ class EmbeddingModuleHandler(ModuleHandler): # Finally, the output will be transformed back to its original shape in self.post_process output_meta_data = self.node._meta_data output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape - physical_output = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_meta_data, - logical_shape=output_logical_shape) + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape, + ) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} @@ -155,12 +165,11 @@ class EmbeddingModuleHandler(ModuleHandler): Convert the sharding spec from the logical shape to the physical shape. """ # create multiple sharding strategies for the inputs - # as input can be multi-dimensinal and the partition dim is only 2D, + # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, - input_name=str( - self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding( + strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies @@ -183,10 +192,12 @@ class EmbeddingFunctionHandler(NodeHandler): # Finally, the input will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape, + ) # check if the other operand is a parameter if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): @@ -194,9 +205,9 @@ class EmbeddingFunctionHandler(NodeHandler): else: data_type = OperationDataType.ARG - physical_other_operand = OperationData(name=str(self.node.args[1]), - type=data_type, - data=self.node.args[1]._meta_data) + physical_other_operand = OperationData( + name=str(self.node.args[1]), type=data_type, data=self.node.args[1]._meta_data + ) # Same as input, in F.embedding operation, all the dimensions of output will be treated as # (batch dimension, embedding dimension), and then the sharding spec will be generated based @@ -221,10 +232,9 @@ class EmbeddingFunctionHandler(NodeHandler): Convert the sharding spec from the logical shape to the physical shape. """ # create multiple sharding strategies for the inputs - # as input can be multi-dimensinal and the partition dim is only 2D, + # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, - input_name=str( - self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding( + strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py index 53addb873d1d1a014352058f8ec127f6bf7c4d91..dcf0a1760a2cdbf3791bf13103dd8940b0c1f24f 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py @@ -4,7 +4,7 @@ from ..sharding_strategy import OperationData, OperationDataType from .node_handler import NodeHandler from .strategy import GetattrGenerator, StrategyGenerator -__all__ = ['GetattrHandler'] +__all__ = ["GetattrHandler"] class GetattrHandler(NodeHandler): diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py index 3466e9dd9940e748da4bc8abb3488aacf98cd8ff..bd342c12eda97d83217b6f4d8b6b846974b67657 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py @@ -8,7 +8,7 @@ from .node_handler import NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator -__all__ = ['GetItemHandler'] +__all__ = ["GetItemHandler"] @operator_registry.register(operator.getitem) @@ -30,9 +30,9 @@ class GetItemHandler(NodeHandler): def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1]) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py index 452381169b74d093188e0f8d7775037f8bf5019c..ce6b20fa1d2407531c6461880de866161eac85d4 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py @@ -3,11 +3,11 @@ from typing import Dict, List import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoModuleHandler, ModuleHandler +from .node_handler import MetaInfoModuleHandler from .registry import operator_registry from .strategy import LayerNormGenerator, StrategyGenerator -__all__ = ['LayerNormModuleHandler'] +__all__ = ["LayerNormModuleHandler"] @operator_registry.register(torch.nn.LayerNorm) @@ -25,20 +25,22 @@ class LayerNormModuleHandler(MetaInfoModuleHandler): def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=self.named_parameters['weight'].shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) + physical_other_operand = OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=self.named_parameters["weight"].shape, + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if self.named_parameters['bias'] is not None: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + if self.named_parameters["bias"] is not None: + physical_bias_operand = OperationData( + name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py index 59091dab519f4e4458461b84e444b9a034f4df98..4177af4eaf71b21bd0ec9ae8f386d95b212e41a1 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -3,27 +3,24 @@ from typing import Dict, List, Union import torch import torch.nn.functional as F -from colossalai.auto_parallel.tensor_shard.utils import ( - check_sharding_spec_validity, - transpose_partition_dim, - update_partition_dim, -) +from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim from colossalai.logging import get_dist_logger from colossalai.tensor.sharding_spec import ShardingNotDivisibleError -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector -from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler from .registry import operator_registry from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator -__all__ = ['LinearModuleHandler', 'LinearFunctionHandler'] +__all__ = ["LinearModuleHandler", "LinearFunctionHandler"] -def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy, - weight_name: str) -> ShardingStrategy: +def _update_sharding_spec_for_transposed_weight_for_linear( + strategy: ShardingStrategy, weight_name: str +) -> ShardingStrategy: """ This function is a helper function used by both module node handler and function node handler. This function will - convert the sharding spec for the transposed weight to the correct partititon spec. + convert the sharding spec for the transposed weight to the correct partition spec. Args: strategy (ShardingStrategy): the strategy generated by the strategy generator. @@ -32,16 +29,17 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr # switch the dimensions of the transposed weight sharding_spec = strategy.get_sharding_spec_by_name(weight_name) op_data = strategy.get_op_data_by_name(weight_name) - assert op_data.logical_shape[0] == op_data.data.shape[1] and \ - op_data.logical_shape[1] == op_data.data.shape[0], \ - "Expected the logical shape of the linear operator's weight is equal to transposed physical shape" + assert ( + op_data.logical_shape[0] == op_data.data.shape[1] and op_data.logical_shape[1] == op_data.data.shape[0] + ), "Expected the logical shape of the linear operator's weight is equal to transposed physical shape" dim_size = len(op_data.logical_shape) transpose_partition_dim(sharding_spec, 0, dim_size - 1) return strategy -def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str, - output_name: str) -> List[ShardingStrategy]: +def _convert_logical_sharding_to_physical_sharding_spec_for_linear( + strategy: ShardingStrategy, input_name: str, output_name: str +) -> List[ShardingStrategy]: """ This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output should have the same sharding spec. @@ -99,22 +97,26 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha input_dim_mapping = {0: i} input_dim_mapping.update(input_last_dim_mapping) - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping=input_dim_mapping, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, + dim_mapping=input_dim_mapping, + physical_shape=input_op_data.data.shape, + inplace=True, + ) output_dim_mapping = {0: i} output_dim_mapping.update(output_last_dim_mapping) - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=output_dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) - strategy_copy.name = f'{strategy.name}_{i}' + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=output_dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) + strategy_copy.name = f"{strategy.name}_{i}" sharding_strategies.append(strategy_copy) except ShardingNotDivisibleError as e: logger.debug( - f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' + f"Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}" ) else: # the generated sharding strategy does not shard the non-matrix dimension, @@ -127,17 +129,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha # after updating, the logical shape will be replaced by the physical shape input_dim_mapping = {} input_dim_mapping.update(input_last_dim_mapping) - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping=input_dim_mapping, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, + dim_mapping=input_dim_mapping, + physical_shape=input_op_data.data.shape, + inplace=True, + ) output_dim_mapping = {} output_dim_mapping.update(output_last_dim_mapping) - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=output_dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=output_dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) sharding_strategies.append(strategy_copy) return sharding_strategies @@ -152,10 +158,13 @@ class LinearModuleHandler(MetaInfoModuleHandler): op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, - self.device_mesh, - linear_projection_type='linear', - solver_perference=self.solver_perference)) + LinearProjectionStrategyGenerator( + op_data_mapping, + self.device_mesh, + linear_projection_type="linear", + solver_perference=self.solver_perference, + ) + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -163,28 +172,34 @@ class LinearModuleHandler(MetaInfoModuleHandler): # the strategies will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=input_meta_data, - logical_shape=input_logical_shape) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=self.named_parameters['weight'].shape[::-1]) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=input_meta_data, + logical_shape=input_logical_shape, + ) + physical_other_operand = OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=self.named_parameters["weight"].shape[::-1], + ) output_meta_data = self.node._meta_data output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape - physical_output = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_meta_data, - logical_shape=output_logical_shape) + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape, + ) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if 'bias' in self.named_parameters is not None: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + if "bias" in self.named_parameters is not None: + physical_bias_operand = OperationData( + name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: @@ -194,14 +209,14 @@ class LinearModuleHandler(MetaInfoModuleHandler): 2. the input and output sharding specs are updated to physical shape. """ # switch the dimensions of the transposed weight - strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight') + strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name="weight") # create multiple sharding strategies for the inputs - # as input can be multi-dimensinal and the partition dim is only 2D, + # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at dim 0 to one of the first few dimensions of the input - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, - input_name=str(self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear( + strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies @@ -215,7 +230,8 @@ class LinearFunctionHandler(MetaInfoNodeHandler): op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear") + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -223,10 +239,12 @@ class LinearFunctionHandler(MetaInfoNodeHandler): # the strategies will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape, + ) # check if the other operand is a parameter if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): @@ -234,10 +252,12 @@ class LinearFunctionHandler(MetaInfoNodeHandler): else: data_type = OperationDataType.ARG - physical_other_operand = OperationData(name=str(self.node.args[1]), - type=data_type, - data=self.node.args[1]._meta_data, - logical_shape=self.node.args[1]._meta_data.shape[::-1]) + physical_other_operand = OperationData( + name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data, + logical_shape=self.node.args[1]._meta_data.shape[::-1], + ) output_meta_data = self.node._meta_data output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape physical_output = OperationData( @@ -249,27 +269,28 @@ class LinearFunctionHandler(MetaInfoNodeHandler): mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None: + if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None: # check if the other operand is a parameter if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): data_type = OperationDataType.PARAM else: data_type = OperationDataType.ARG - physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), - type=data_type, - data=self.node.kwargs["bias"]._meta_data) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy): # switch the dimensions of the transposed weight - strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, - weight_name=str(self.node.args[1])) + strategy = _update_sharding_spec_for_transposed_weight_for_linear( + strategy=strategy, weight_name=str(self.node.args[1]) + ) # create multiple sharding strategies for the inputs - # as input can be multi-dimensinal and the partition dim is only 2D, + # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at dim 0 to one of the first few dimensions of the input - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, - input_name=str(self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear( + strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py index f3c9d0cbf8267e2321415ae29887e308a9af35b2..4fab5f7f05eb3ea4dcfeeb49a48ae8230acef97e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py @@ -16,7 +16,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import recover_sharding_spec_for_broadcast_shape -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import ( BatchedMatMulStrategyGenerator, @@ -37,6 +37,7 @@ class MatMulType(Enum): MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D """ + DOT = 0 MM = 1 MV = 2 @@ -48,8 +49,8 @@ def get_matmul_type(input_dim: int, other_dim: int): Determine which type of matmul operation should be executed for the given tensor dimensions. Args: - input_dim (int): the number of dimensions for the input tenosr - other_dim (int): the number of dimensions for the other tenosr + input_dim (int): the number of dimensions for the input tensor + other_dim (int): the number of dimensions for the other tensor """ if input_dim == 1 and other_dim == 1: matmul_type = MatMulType.DOT @@ -92,26 +93,26 @@ class Padder(BmmTransform): def apply(self, shape_mapping: Dict[str, List[int]]): mapping_copy = deepcopy(shape_mapping) - input_shape = mapping_copy['input'] - other_shape = mapping_copy['other'] + input_shape = mapping_copy["input"] + other_shape = mapping_copy["other"] if len(input_shape) == 1: # if the input is a 1D tensor, 1 is prepended to its shape # and it will be removed afterwards input_shape.insert(0, 1) - self.padded_dim_mapping['input'] = -2 - self.padded_dim_mapping['output'] = -2 + self.padded_dim_mapping["input"] = -2 + self.padded_dim_mapping["output"] = -2 elif len(other_shape) == 1: # if the other is a 1D tensor, 1 is appended to its shape # and it will be removed afterwards other_shape = other_shape.append(1) - self.padded_dim_mapping['other'] = -1 - self.padded_dim_mapping['output'] = -1 + self.padded_dim_mapping["other"] = -1 + self.padded_dim_mapping["output"] = -1 return mapping_copy def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): - input_op_data = op_data_mapping['input'] - other_op_data = op_data_mapping['other'] + op_data_mapping["input"] + op_data_mapping["other"] def _remove_padded_dim(key, strategy): op_data = op_data_mapping[key] @@ -131,7 +132,7 @@ class Padder(BmmTransform): # compute unpadded tensor shape tensor_shape.pop(padded_dim) - assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}' + assert tensor_shape == list(op_data.data.shape), f"{tensor_shape} vs {list(op_data.data.shape)}" # update sharding spec sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list) @@ -142,15 +143,15 @@ class Padder(BmmTransform): strategy_copy = strategy.clone() # only one of input and other will be padded - if 'input' in self.padded_dim_mapping: - _remove_padded_dim('input', strategy_copy) - _remove_padded_dim('output', strategy_copy) - elif 'other' in self.padded_dim_mapping: - _remove_padded_dim('other', strategy_copy) - _remove_padded_dim('output', strategy_copy) + if "input" in self.padded_dim_mapping: + _remove_padded_dim("input", strategy_copy) + _remove_padded_dim("output", strategy_copy) + elif "other" in self.padded_dim_mapping: + _remove_padded_dim("other", strategy_copy) + _remove_padded_dim("output", strategy_copy) strategies.append(strategy_copy) - except ShardingSpecException as e: + except ShardingSpecException: pass return strategies @@ -167,8 +168,8 @@ class Broadcaster(BmmTransform): mapping_copy = shape_mapping.copy() # get shapes - input_shape = mapping_copy['input'] - other_shape = mapping_copy['other'] + input_shape = mapping_copy["input"] + other_shape = mapping_copy["other"] # sanity check assert len(input_shape) > 1 and len(other_shape) > 1 @@ -179,16 +180,16 @@ class Broadcaster(BmmTransform): # store the broadcast dim info input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2]) other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2]) - self.broadcast_dim_info['input'] = input_broadcast_dim_info - self.broadcast_dim_info['other'] = other_broadcast_dim_info + self.broadcast_dim_info["input"] = input_broadcast_dim_info + self.broadcast_dim_info["other"] = other_broadcast_dim_info # create the full logical shape input_shape = bcast_non_matrix_dims + input_shape[-2:] other_shape = bcast_non_matrix_dims + other_shape[-2:] assert len(input_shape) == len(other_shape) - mapping_copy['input'] = input_shape - mapping_copy['other'] = other_shape + mapping_copy["input"] = input_shape + mapping_copy["other"] = other_shape return mapping_copy @@ -206,7 +207,7 @@ class Broadcaster(BmmTransform): # e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8] # the dim 0 of [1, 2, 4] is multiplied to 4 tensor_shape[dim_idx] = 1 - elif broadcast_type == BroadcastType.PADDDING: + elif broadcast_type == BroadcastType.PADDING: # if the dim is padded # we remove its sharding tensor_shape[dim_idx] = None @@ -216,17 +217,18 @@ class Broadcaster(BmmTransform): physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( logical_sharding_spec=sharding_spec, logical_shape=sharding_spec.entire_shape, - physical_shape=tensor_shape_before_broadcast) + physical_shape=tensor_shape_before_broadcast, + ) strategy.sharding_specs[op_data] = physical_sharding_spec # enumerate all sharding strategies strategies = [] try: strategy_copy = strategy.clone() - _remove_sharding_on_broadcast_dim('input', strategy_copy) - _remove_sharding_on_broadcast_dim('other', strategy_copy) + _remove_sharding_on_broadcast_dim("input", strategy_copy) + _remove_sharding_on_broadcast_dim("other", strategy_copy) strategies.append(strategy_copy) - except ShardingSpecException as e: + except ShardingSpecException: pass return strategies @@ -241,20 +243,20 @@ class Viewer(BmmTransform): def apply(self, shape_mapping: Dict[str, List[int]]): mapping_copy = shape_mapping.copy() - self.batch_dims_before_view = list(mapping_copy['input'][:-2]) + self.batch_dims_before_view = list(mapping_copy["input"][:-2]) # get shapes - input_shape = shape_mapping['input'] - other_shape = shape_mapping['other'] + input_shape = shape_mapping["input"] + other_shape = shape_mapping["other"] # view to 3d tensor assert len(input_shape) >= 3 and len(other_shape) >= 3 input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:] other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:] output_shape = input_shape[:2] + other_shape[2:] - mapping_copy['input'] = input_shape - mapping_copy['other'] = other_shape - mapping_copy['output'] = output_shape + mapping_copy["input"] = input_shape + mapping_copy["other"] = other_shape + mapping_copy["output"] = output_shape return mapping_copy def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): @@ -268,13 +270,13 @@ class Viewer(BmmTransform): dim_partition_dict = sharding_spec.dim_partition_dict entire_shape = sharding_spec.entire_shape - # upddate the dimension index for the matrix dimensions + # update the dimension index for the matrix dimensions if 2 in dim_partition_dict: dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2) if 1 in dim_partition_dict: dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1) - # map the logical batch dim to phyiscal batch dim + # map the logical batch dim to physical batch dim if 0 in dim_partition_dict: batch_dim_shard = dim_partition_dict.pop(0) dim_partition_dict[physical_batch_dim] = batch_dim_shard @@ -291,11 +293,11 @@ class Viewer(BmmTransform): # create a new strategy strategy_copy = strategy.clone() try: - _update_sharding_spec('input', strategy_copy, i) - _update_sharding_spec('other', strategy_copy, i) - _update_sharding_spec('output', strategy_copy, i) + _update_sharding_spec("input", strategy_copy, i) + _update_sharding_spec("other", strategy_copy, i) + _update_sharding_spec("output", strategy_copy, i) strategies.append(strategy_copy) - except ShardingSpecException as e: + except ShardingSpecException: continue return strategies @@ -312,14 +314,14 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms): 3. reshape to 3 dimensions """ - shape_mapping = {'input': input_shape, 'other': other_shape} + shape_mapping = {"input": input_shape, "other": other_shape} for transform in transforms: shape_mapping = transform.apply(shape_mapping) - input_shape = shape_mapping.get('input', None) - other_shape = shape_mapping.get('other', None) - output_shape = shape_mapping.get('output', None) + input_shape = shape_mapping.get("input", None) + other_shape = shape_mapping.get("other", None) + output_shape = shape_mapping.get("output", None) return input_shape, other_shape, output_shape @@ -364,7 +366,8 @@ class MatMulHandler(MetaInfoNodeHandler): generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh)) elif self.matmul_type == MatMulType.MM: generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear") + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -372,7 +375,7 @@ class MatMulHandler(MetaInfoNodeHandler): MatMulType.DOT: self._get_logical_shape_for_dot, MatMulType.MM: self._get_logical_shape_for_mm, MatMulType.MV: self._get_logical_shape_for_mv, - MatMulType.BMM: self._get_logical_shape_for_bmm + MatMulType.BMM: self._get_logical_shape_for_bmm, } logical_shapes = logical_shape_func[self.matmul_type]() op_data_mapping = self._get_op_data_mapping(*logical_shapes) @@ -390,20 +393,26 @@ class MatMulHandler(MetaInfoNodeHandler): output_logical_shape = torch.Size(output_logical_shape) # create op data - input_op_data = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.input_meta_data, - logical_shape=input_logical_shape) - other_op_data = OperationData(name=str(self.node.args[1]), - type=OperationDataType.ARG, - data=self.other_meta_data, - logical_shape=other_logical_shape) - output_op_data = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=self.output_meta_data, - logical_shape=output_logical_shape) - - mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + input_op_data = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.input_meta_data, + logical_shape=input_logical_shape, + ) + other_op_data = OperationData( + name=str(self.node.args[1]), + type=OperationDataType.ARG, + data=self.other_meta_data, + logical_shape=other_logical_shape, + ) + output_op_data = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.output_meta_data, + logical_shape=output_logical_shape, + ) + + mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data} return mapping def _get_logical_shape_for_dot(self): @@ -414,7 +423,7 @@ class MatMulHandler(MetaInfoNodeHandler): def _get_logical_shape_for_mm(self): """ - We need to handle the input tensor for a matrix-matrix multiplcation as the input + We need to handle the input tensor for a matrix-matrix multiplication as the input tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape (e.g. [4] -> [1, 4]). """ @@ -460,9 +469,11 @@ class MatMulHandler(MetaInfoNodeHandler): dim_partition_dict[0] = shard # re-init the sharding spec - input_sharding_spec.__init__(input_sharding_spec.device_mesh, - entire_shape=input_physical_shape, - dim_partition_dict=dim_partition_dict) + input_sharding_spec.__init__( + input_sharding_spec.device_mesh, + entire_shape=input_physical_shape, + dim_partition_dict=dim_partition_dict, + ) return strategy else: return strategy @@ -481,7 +492,8 @@ class MatMulHandler(MetaInfoNodeHandler): recovered_stragies.extend(output) else: raise TypeError( - f"Found unexpected output type {type(output)} from the recover method of BmmTransform") + f"Found unexpected output type {type(output)} from the recover method of BmmTransform" + ) strategies = recovered_stragies for index, strategies in enumerate(strategies): strategies.name = f"{strategies.name}_{index}" diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index ab391ebfaf80960ef49a4e9c4761c76f82567d25..d2bad39dcbb9164bd6c3e6e4d42323d25b88d385 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -8,7 +8,6 @@ from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo, from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, - OperationDataType, ShardingSpec, ShardingStrategy, StrategiesVector, @@ -23,21 +22,23 @@ from .strategy import StrategyGenerator class NodeHandler(ABC): - ''' + """ The NodeHandler is an abstract class used to generate every possible strategies for an operator node. Args: node (Node): the input node in node argument list. device_mesh (DeviceMesh): A logical view of a physical mesh. strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. - ''' - - def __init__(self, - node: Node, - device_mesh: DeviceMesh, - strategies_vector: StrategiesVector, - shard_option: ShardOption = ShardOption.STANDARD, - solver_perference: SolverPerference = SolverPerference.STANDARD) -> None: + """ + + def __init__( + self, + node: Node, + device_mesh: DeviceMesh, + strategies_vector: StrategiesVector, + shard_option: ShardOption = ShardOption.STANDARD, + solver_perference: SolverPerference = SolverPerference.STANDARD, + ) -> None: self.node = node self.predecessor_node = list(node._input_nodes.keys()) self.successor_node = list(node.users.keys()) @@ -68,22 +69,23 @@ class NodeHandler(ABC): current_sharding_spec = strategy.sharding_specs[op_data] # get the sharding specs for this node generated # in its own node handler - assert hasattr(node, 'strategies_vector'), \ - f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.' + assert hasattr( + node, "strategies_vector" + ), f"The predecessor node {node_name} has no strategy vector to compute the resharding cost." prev_strategy_vector = node.strategies_vector prev_sharding_specs = [ prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector ] - # create data structrure to store costs + # create data structure to store costs if node not in resharding_costs: resharding_costs[node] = [] def _compute_resharding_cost( - prev_sharding_spec: Union[ShardingSpec, - List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec, - List[ShardingSpec]], - data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem: + prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]], + current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]], + data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], + ) -> TrainCycleItem: """ This is a helper function to compute the resharding cost for a specific strategy of a node. """ @@ -94,30 +96,35 @@ class NodeHandler(ABC): dtype = data.dtype size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() _, _, consistency_cost = shape_consistency_manager.shape_consistency( - prev_sharding_spec, current_sharding_spec) - - resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes, - bwd=consistency_cost["backward"] * size_per_elem_bytes, - total=consistency_cost["total"] * size_per_elem_bytes) + prev_sharding_spec, current_sharding_spec + ) + + resharding_cost = TrainCycleItem( + fwd=consistency_cost["forward"] * size_per_elem_bytes, + bwd=consistency_cost["backward"] * size_per_elem_bytes, + total=consistency_cost["total"] * size_per_elem_bytes, + ) return resharding_cost else: # This raise is used to check if we have missed any type of data. # It could be merged into Parameter branch, which means we won't handle # non-tensor arguments. - raise ValueError(f'Unsupported data type {type(data)}') + raise ValueError(f"Unsupported data type {type(data)}") else: - assert isinstance(prev_sharding_spec, (tuple, list)), \ - f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \ - or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}' + assert isinstance( + prev_sharding_spec, (tuple, list) + ), f"prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \ + or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}" fwd_cost = 0 bwd_cost = 0 total_cost = 0 - for index, (prev_sharding_spec_item, - current_sharding_spec_item) in enumerate(zip(prev_sharding_spec, - current_sharding_spec)): - item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item, - data[index]) + for index, (prev_sharding_spec_item, current_sharding_spec_item) in enumerate( + zip(prev_sharding_spec, current_sharding_spec) + ): + item_cost = _compute_resharding_cost( + prev_sharding_spec_item, current_sharding_spec_item, data[index] + ) fwd_cost += item_cost.fwd bwd_cost += item_cost.bwd total_cost += item_cost.total @@ -138,17 +145,17 @@ class NodeHandler(ABC): This function is used to get the target function for the node handler. The target function is used to analyze the costs of strategies. """ - if self.node.op in ('placeholder', 'get_attr', 'output'): + if self.node.op in ("placeholder", "get_attr", "output"): return None - if self.node.op == 'call_module': + if self.node.op == "call_module": target = self.node.graph.owning_module.get_submodule(self.node.target) - elif self.node.op == 'call_function': + elif self.node.op == "call_function": target = self.node.target - elif self.node.op == 'call_method': + elif self.node.op == "call_method": target = getattr(self.node.args[0]._meta_data.__class__, self.node.target) else: - raise ValueError(f'Unsupported node type: {self.node.op}') + raise ValueError(f"Unsupported node type: {self.node.op}") return target @@ -188,7 +195,7 @@ class NodeHandler(ABC): remove_strategy_list = [] for strategy in self.strategies_vector: shard_axis_list = [] - last_axis = len(self.device_mesh.mesh_shape) - 1 + last_axis = len(self.device_mesh.shape) - 1 for op_data, sharding_spec in strategy.sharding_specs.items(): if op_data.data is not None and isinstance(op_data.data, torch.Tensor): for dim, shard_axes in sharding_spec.dim_partition_dict.items(): @@ -212,7 +219,7 @@ class NodeHandler(ABC): return self.strategies_vector def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: - # tranform the strategy generated + # transform the strategy generated # e.g. to process the sharding strategy for the transposed weights return strategy @@ -221,7 +228,6 @@ class NodeHandler(ABC): """ Define which generators should be used by this NodeHandler object. """ - pass @abstractmethod def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -244,7 +250,6 @@ class NodeHandler(ABC): "output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data), } """ - pass class MetaInfoNodeHandler(NodeHandler): @@ -278,19 +283,19 @@ class MetaInfoNodeHandler(NodeHandler): else: logger = get_dist_logger() - logger.warning(f'The target function {target} is not patched yet, ') + logger.warning(f"The target function {target} is not patched yet, ") return self.strategies_vector class ModuleHandler(NodeHandler): - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # set attributes to access module parameters for convenience - assert self.node.graph.owning_module is not None, \ - f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.' + assert ( + self.node.graph.owning_module is not None + ), f"The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object." module = self.node.graph.owning_module.get_submodule(self.node.target) named_parameters = list(module.named_parameters(recurse=False)) named_buffers = list(module.named_buffers(recurse=False)) @@ -333,6 +338,6 @@ class MetaInfoModuleHandler(ModuleHandler): else: logger = get_dist_logger() - logger.warning(f'The target function {target} is not patched yet') + logger.warning(f"The target function {target} is not patched yet") return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py index 4e71ccba95a7e6457309a455986400dc49893d18..facf19560596020138cd287d403539914ae3a98e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py @@ -3,11 +3,11 @@ from typing import Dict, List import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoModuleHandler, ModuleHandler +from .node_handler import MetaInfoModuleHandler from .registry import operator_registry from .strategy import NormalPoolStrategyGenerator, StrategyGenerator -__all__ = ['NormPoolingHandler'] +__all__ = ["NormPoolingHandler"] @operator_registry.register(torch.nn.MaxPool1d) @@ -30,9 +30,9 @@ class NormPoolingHandler(MetaInfoModuleHandler): def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py index ed120a8c3d6df9b5d10f44f2b86be1c3cf283c10..89906a205e87f2e955712999fb89805e66354f97 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py @@ -8,7 +8,7 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect from .node_handler import NodeHandler from .strategy import OutputGenerator, StrategyGenerator -__all__ = ['OutputHandler'] +__all__ = ["OutputHandler"] class OutputHandler(NodeHandler): @@ -16,8 +16,9 @@ class OutputHandler(NodeHandler): A OutputHandler which deals with the sharding strategies for Output Node. """ - def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, - output_option: str) -> None: + def __init__( + self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, output_option: str + ) -> None: super().__init__(node, device_mesh, strategies_vector) self.output_option = output_option @@ -35,11 +36,11 @@ class OutputHandler(NodeHandler): for index, input_node in enumerate(self.predecessor_node): input_meta_data = input_node._meta_data physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data) - name_key = f'input_{index}' + name_key = f"input_{index}" mapping[name_key] = physical_inputs output_meta_data.append(input_meta_data) - assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.' + assert len(output_meta_data) > 0, f"Output node {self.node} has no input node." if len(output_meta_data) == 1: output_meta_data = output_meta_data[0] else: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py index 91e4a5105a08ff7d28cebba41f4962daa951259c..75f07168e47b440843e38ab8ad0b1abf2bc136ca 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py @@ -7,7 +7,7 @@ from .node_handler import NodeHandler from .registry import operator_registry from .strategy import PermuteGenerator, StrategyGenerator -__all__ = ['PermuteHandler'] +__all__ = ["PermuteHandler"] @operator_registry.register(torch.Tensor.permute) @@ -34,14 +34,14 @@ class PermuteHandler(NodeHandler): physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) permute_dims = [] - if self.node.op == 'call_method': + if self.node.op == "call_method": # torch.Tensor.permute (input, *dims) for arg in self.node.args: if isinstance(arg, torch.fx.Node): if isinstance(arg._meta_data, int): permute_dims.append(arg._meta_data) else: - assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.' + assert isinstance(arg, int), "The argument in permute node should be either type of Node or int." permute_dims.append(arg) else: # torch.permute (input, dims) @@ -51,8 +51,8 @@ class PermuteHandler(NodeHandler): permute_dims.extend(arg._meta_data) else: assert isinstance( - arg, - (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].' + arg, (tuple, list) + ), "The argument in permute node should be type of Node, Tuple[int] or List[int]." permute_dims.extend(arg) num_dims = self.node._meta_data.dim() @@ -61,7 +61,7 @@ class PermuteHandler(NodeHandler): if permute_dims[i] < 0: permute_dims[i] += num_dims - physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims)) + physical_shape_operand = OperationData(name="permute_dims", type=OperationDataType.ARG, data=list(permute_dims)) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -69,7 +69,7 @@ class PermuteHandler(NodeHandler): mapping = { "input": physical_input_operand, "permute_dims": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py index e4f40fc935a404dd8625c82fbb4dc7511c9fc839..461bc2935780e6d497469e2569a12abf43dff97b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py @@ -8,7 +8,7 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect from .node_handler import NodeHandler from .strategy import PlaceholderGenerator, StrategyGenerator -__all__ = ['PlaceholderHandler'] +__all__ = ["PlaceholderHandler"] class PlaceholderHandler(NodeHandler): @@ -16,8 +16,9 @@ class PlaceholderHandler(NodeHandler): A PlaceholderHandler which deals with the sharding strategies for Placeholder Node. """ - def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, - placeholder_option: str) -> None: + def __init__( + self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, placeholder_option: str + ) -> None: super().__init__(node, device_mesh, strategies_vector) self.placeholder_option = placeholder_option @@ -25,7 +26,8 @@ class PlaceholderHandler(NodeHandler): op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)) + PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option) + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py index 8e06cec4f463a8600b2abe1a7f6713ec2ffb2931..f663fc9695d3d1b37591429adc80d71f744cdf12 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py @@ -1,12 +1,9 @@ class Registry: - # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): if isinstance(source, (list, tuple)): # support register a list of items for this func @@ -19,7 +16,7 @@ class Registry: return wrapper def get(self, source): - assert source in self.store, f'{source} not found in the {self.name} registry' + assert source in self.store, f"{source} not found in the {self.name} registry" target = self.store[source] return target @@ -27,4 +24,4 @@ class Registry: return source in self.store -operator_registry = Registry('operator') +operator_registry = Registry("operator") diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py index 743a1f90eaafa869b3a62882648cbde53f9e3166..6e883ea6473672204044b431fbe3a3265a63286b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py @@ -7,7 +7,7 @@ from .node_handler import NodeHandler from .registry import operator_registry from .strategy import SoftmaxGenerator, StrategyGenerator -__all__ = ['SoftmaxHandler'] +__all__ = ["SoftmaxHandler"] @operator_registry.register(torch.nn.Softmax) @@ -34,14 +34,14 @@ class SoftmaxHandler(NodeHandler): input_data = self.node.args[0]._meta_data physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) - softmax_dim = self.node.kwargs['dim'] + softmax_dim = self.node.kwargs["dim"] num_dims = self.node.args[0]._meta_data.dim() # recover negative value to positive if softmax_dim < 0: softmax_dim += num_dims - physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim) + physical_dim_operand = OperationData(name="softmax_dim", type=OperationDataType.ARG, data=softmax_dim) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -49,7 +49,7 @@ class SoftmaxHandler(NodeHandler): mapping = { "input": physical_input_operand, "softmax_dim": physical_dim_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py index 653d158b7c36ee1ff27791add2edad4093ce8675..4c32529a5d5b822dc89dbf97c5e19889153515af 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py @@ -7,7 +7,7 @@ from .node_handler import NodeHandler from .registry import operator_registry from .strategy import SplitGenerator, StrategyGenerator -__all__ = ['SplitHandler'] +__all__ = ["SplitHandler"] @operator_registry.register(torch.Tensor.split) @@ -38,7 +38,7 @@ class SplitHandler(NodeHandler): split_dim = self.node.args[2] else: if self.node.kwargs: - split_dim = self.node.kwargs['dim'] + split_dim = self.node.kwargs["dim"] else: split_dim = 0 @@ -48,7 +48,7 @@ class SplitHandler(NodeHandler): split_dim += num_dims split_info = (split_size, split_dim) - physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info) + physical_shape_operand = OperationData(name="split_info", type=OperationDataType.ARG, data=split_info) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -56,7 +56,7 @@ class SplitHandler(NodeHandler): mapping = { "input": physical_input_operand, "split_info": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py index db1f31521c86ef1842e93d9bbdbc58953e11934d..1fc7f613716b4cc895f10b61e9a5d0c91932a22e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py @@ -29,11 +29,31 @@ from .unary_elementwise_generator import UnaryElementwiseGenerator from .where_generator import WhereGenerator __all__ = [ - 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator', - 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator', - 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', - 'LayerNormGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator', - 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', 'TensorConstructorGenerator', - 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator', 'ViewGenerator', 'PermuteGenerator', - 'TransposeGenerator', 'SplitGenerator', 'DefaultReshapeGenerator' + "StrategyGenerator", + "DotProductStrategyGenerator", + "MatVecStrategyGenerator", + "LinearProjectionStrategyGenerator", + "BatchedMatMulStrategyGenerator", + "ConvStrategyGenerator", + "UnaryElementwiseGenerator", + "BatchNormStrategyGenerator", + "GetItemStrategyGenerator", + "TensorStrategyGenerator", + "TensorTupleStrategyGenerator", + "LayerNormGenerator", + "PlaceholderGenerator", + "OutputGenerator", + "WhereGenerator", + "NormalPoolStrategyGenerator", + "BinaryElementwiseStrategyGenerator", + "GetattrGenerator", + "TensorConstructorGenerator", + "EmbeddingStrategyGenerator", + "SumGenerator", + "SoftmaxGenerator", + "ViewGenerator", + "PermuteGenerator", + "TransposeGenerator", + "SplitGenerator", + "DefaultReshapeGenerator", ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py index 1f3812429fc274064163f5859d0fedb04f8115fb..9c766b1014c80454abe0123fc9a1464cf92b63af 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py @@ -14,7 +14,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator -__all__ = ['BatchNormStrategyGenerator'] +__all__ = ["BatchNormStrategyGenerator"] class BatchNormStrategyGenerator(StrategyGenerator): @@ -24,34 +24,37 @@ class BatchNormStrategyGenerator(StrategyGenerator): To keep the math consistency, there are two way to do BatchNorm if the input shards on batch dimension: 1. We gather the input partitions through batch dimension, then do the normal BatchNorm. - 2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help + 2. We do the SyncBatchNorm on the each input partition separately, the SyncBN op will help us to keep the computing correctness. In this generator, both methods will be considered. """ def validate(self) -> bool: - ''' + """ In sanity check, we need make sure the input data having correct dimension size. For BatchNorm1d, the dim of input data should be 3([N, C, L]). For BatchNorm2d, the dim of input data should be 4([N, C, H, W]). For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]). - ''' - input_op_data = self.op_data['input'] + """ + input_op_data = self.op_data["input"] assert input_op_data.data.dim() in ( - 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + 3, + 4, + 5, + ), f"We suppose the dim of input fed into conv op should in range of [3, 5]." def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. - Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. - ''' + Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. + """ # TODO: a constant coefficient need to be added. # 1D: (L) * N * Cin # 2D: (H * W) * N * Cin # 3D: (H * W * D) * N * Cin - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() if self.has_bias: # bias add is an element wise operation, so the cost is equal to product of output shape. bias_compute_cost = reduce(operator.mul, sharded_output_shape) @@ -69,23 +72,24 @@ class BatchNormStrategyGenerator(StrategyGenerator): def update_memory_cost(self, strategy: ShardingStrategy): forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output"), - 'running_mean': self._compute_size_in_bytes(strategy, "running_mean"), - 'running_var': self._compute_size_in_bytes(strategy, "running_var"), + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), + "running_mean": self._compute_size_in_bytes(strategy, "running_mean"), + "running_var": self._compute_size_in_bytes(strategy, "running_var"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - forward_size_mapping['bias'] = bias_size + forward_size_mapping["bias"] = bias_size backward_size_mapping = copy.deepcopy(forward_size_mapping) backward_size_mapping.pop("output") # compute fwd cost incurred # fwd_cost = input + other + bias + output fwd_activation_cost = sum( - [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]) + [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)] + ) fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)]) fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost) @@ -93,36 +97,29 @@ class BatchNormStrategyGenerator(StrategyGenerator): # compute bwd cost incurred # bwd_cost = input_grad + other_grad + bias_grad bwd_activation_cost = sum( - [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]) + [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)] + ) bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost, - buffer=fwd_buffer_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost, + buffer=fwd_buffer_cost, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @ignore_sharding_exception def split_input_channel(self, mesh_dim_0): - name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' + name = f"RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}" dim_partition_dict_mapping = { - "input": { - 1: [mesh_dim_0] - }, - "other": { - 0: [mesh_dim_0] - }, - "output": { - 1: [mesh_dim_0] - }, - "running_mean": { - 0: [mesh_dim_0] - }, - "running_var": { - 0: [mesh_dim_0] - }, + "input": {1: [mesh_dim_0]}, + "other": {0: [mesh_dim_0]}, + "output": {1: [mesh_dim_0]}, + "running_mean": {0: [mesh_dim_0]}, + "running_var": {0: [mesh_dim_0]}, "num_batches_tracked": {}, } if self.has_bias: @@ -132,29 +129,21 @@ class BatchNormStrategyGenerator(StrategyGenerator): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}" dim_partition_dict_mapping = { - "input": { - 1: [mesh_dim_0, mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "output": { - 1: [mesh_dim_0, mesh_dim_1] - }, - "running_mean": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "running_var": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {1: [mesh_dim_0, mesh_dim_1]}, + "other": {0: [mesh_dim_0, mesh_dim_1]}, + "output": {1: [mesh_dim_0, mesh_dim_1]}, + "running_mean": {0: [mesh_dim_0, mesh_dim_1]}, + "running_var": {0: [mesh_dim_0, mesh_dim_1]}, "num_batches_tracked": {}, } if self.has_bias: @@ -164,13 +153,15 @@ class BatchNormStrategyGenerator(StrategyGenerator): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def non_split(self): - name = f'RR = RR x R' + name = f"RR = RR x R" dim_partition_dict_mapping = { "input": {}, "other": {}, @@ -186,21 +177,19 @@ class BatchNormStrategyGenerator(StrategyGenerator): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_batch(self, mesh_dim_0): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, + "input": {0: [mesh_dim_0]}, "other": {}, - "output": { - 0: [mesh_dim_0] - }, + "output": {0: [mesh_dim_0]}, "running_mean": {}, "running_var": {}, "num_batches_tracked": {}, @@ -212,33 +201,32 @@ class BatchNormStrategyGenerator(StrategyGenerator): # set communication action # For SyncBN case, we don't need to do communication for weight and bias. - # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation + # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation # to SyncBN operation instead of inserting a communication node. output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.IMPLICIT) + comm_type=CommType.IMPLICIT, + ) # TODO: Temporary solution has no communication cost, # above action should be added after the SyncBN replace pass completed. communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, "other": {}, - "output": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "output": {0: [mesh_dim_0, mesh_dim_1]}, "running_mean": {}, "running_var": {}, "num_batches_tracked": {}, @@ -250,25 +238,28 @@ class BatchNormStrategyGenerator(StrategyGenerator): # set communication action # For SyncBN case, we don't need to do communication for gradients of weight and bias. - # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation + # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation # to SyncBN operation instead of inserting a communication node. output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.IMPLICIT) + comm_type=CommType.IMPLICIT, + ) # TODO: Temporary solution has no communication cost, # above action should be added after the SyncBN replace pass completed. communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN" dim_partition_dict_mapping = { "input": { 0: [mesh_dim_0], @@ -298,26 +289,29 @@ class BatchNormStrategyGenerator(StrategyGenerator): # set communication action # For SyncBN case, we don't need to do communication for gradients of weight and bias. - # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation + # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation # to SyncBN operation instead of inserting a communication node. output_comm_action = self.get_communication_action( sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0], - comm_type=CommType.IMPLICIT) + comm_type=CommType.IMPLICIT, + ) # TODO: Temporary solution has no communication cost, # above action should be added after the SyncBN replace pass completed. communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: - ''' + """ Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. - ''' + """ strategy_list = [] # RS = RS x S diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py index fd7f811c8972412eaec88bb1dcfc639cdf1fe630..c7da0034ec3bf504e52dd4d15f934e5b0c242e05 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py @@ -14,7 +14,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException from .strategy_generator import StrategyGenerator -__all__ = ['BinaryElementwiseStrategyGenerator'] +__all__ = ["BinaryElementwiseStrategyGenerator"] class BinaryElementwiseStrategyGenerator(StrategyGenerator): @@ -26,36 +26,37 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator): """ def validate(self) -> bool: - assert len(self.op_data) == 3, \ - f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}' + assert ( + len(self.op_data) == 3 + ), f"BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}" for name, op_data in self.op_data.items(): if not isinstance(op_data.data, (torch.Tensor, int, float)): - raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.') + raise TypeError(f"The operation data {name} is not a torch.Tensor/int/float.") def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() # since elementwise ops are not compute-intensive, # we approximate the backward compute cost # to be twice the fwd compute cost fwd_compute_cost = reduce(operator.mul, shape) bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: # all input, output and outputs have the same shape - shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() # compute fwd memory cost in bytes # as the elementwise ops are not memory-intensive - # we approximate the fwd memroy cost to be the output + # we approximate the fwd memory cost to be the output # and the backward memory cost to be grad of input and other - input_bytes = self._compute_size_in_bytes(strategy, 'input') - other_bytes = self._compute_size_in_bytes(strategy, 'other') - output_bytes = self._compute_size_in_bytes(strategy, 'output') + input_bytes = self._compute_size_in_bytes(strategy, "input") + other_bytes = self._compute_size_in_bytes(strategy, "other") + output_bytes = self._compute_size_in_bytes(strategy, "output") fwd_memory_cost = MemoryCost(activation=output_bytes) bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes) total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes) @@ -66,7 +67,7 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator): def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): # we check for the output logical shape to get the number of dimensions dim_partition_list = [] - dim_size = len(self.op_data['output'].logical_shape) + dim_size = len(self.op_data["output"].logical_shape) # enumerate all the 2D sharding cases sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) @@ -86,21 +87,22 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator): # convert these dim partition dict to sharding strategy for dim_partition_dict in dim_partition_list: - dim_partition_dict_mapping = dict(input=dim_partition_dict, - other=dim_partition_dict, - output=dim_partition_dict) + dim_partition_dict_mapping = dict( + input=dim_partition_dict, other=dim_partition_dict, output=dim_partition_dict + ) try: sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) communication_action_mapping = {} # get name - sharding_seq = sharding_spec_mapping['input'].sharding_sequence - name = f'{sharding_seq} = {sharding_seq} {sharding_seq}' + sharding_seq = sharding_spec_mapping["input"].sharding_sequence + name = f"{sharding_seq} = {sharding_seq} {sharding_seq}" sharding_strategy = self.get_sharding_strategy( name=name, sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(sharding_strategy) except ShardingSpecException: continue diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py index c2154b3104d3d52e994a2add25ddc796792e1c66..5208f61543bb38ebd4f151cad32bf86db581b2f7 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py @@ -1,11 +1,9 @@ import copy import operator -import warnings from functools import reduce from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, CommType, MemoryCost, ShardingStrategy, @@ -24,29 +22,32 @@ class ConvStrategyGenerator(StrategyGenerator): """ def validate(self) -> bool: - ''' + """ In sanity check, we need make sure the input data having correct dimension size. For Conv1d, the dim of input data should be 3([N, C, L]). For Conv2d, the dim of input data should be 4([N, C, H, W]). For Conv3d, the dim of input data should be 5([N, C, H, W, D]). - ''' - input_op_data = self.op_data['input'] + """ + input_op_data = self.op_data["input"] assert input_op_data.data.dim() in ( - 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + 3, + 4, + 5, + ), f"We suppose the dim of input fed into conv op should in range of [3, 5]." def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. - Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. - ''' - # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. + """ + # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size. # 1D: (L) * N * Cout * Cin * kernel # 2D: (H * W) * N * Cout * Cin * kernel # 3D: (H * W * D) * N * Cout * Cin * kernel - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() if self.has_bias: # bias add is an element wise operation, so the cost is equal to product of output shape. bias_compute_cost = reduce(operator.mul, sharded_output_shape) @@ -76,14 +77,14 @@ class ConvStrategyGenerator(StrategyGenerator): def update_memory_cost(self, strategy: ShardingStrategy): forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - forward_size_mapping['bias'] = bias_size + forward_size_mapping["bias"] = bias_size backward_size_mapping = copy.deepcopy(forward_size_mapping) backward_size_mapping.pop("output") @@ -100,26 +101,20 @@ class ConvStrategyGenerator(StrategyGenerator): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @ignore_sharding_exception def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, - "other": { - 1: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0]}, + "other": {1: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], 1: [mesh_dim_1]}, } if self.has_bias: dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]} @@ -132,7 +127,8 @@ class ConvStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} if self.is_param("other"): @@ -140,7 +136,8 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -148,38 +145,41 @@ class ConvStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: - if self.is_param('bias'): + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_batch(self, mesh_dim_0): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x RR" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, + "input": {0: [mesh_dim_0]}, "other": {}, "output": { 0: [mesh_dim_0], @@ -196,7 +196,8 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -204,42 +205,45 @@ class ConvStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: - if self.is_param('bias'): + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R" dim_partition_dict_mapping = { "input": { 0: [mesh_dim_0], 1: [mesh_dim_1], }, - "other": { - 0: [mesh_dim_1] - }, + "other": {0: [mesh_dim_1]}, "output": { 0: [mesh_dim_0], }, @@ -254,7 +258,8 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} @@ -263,7 +268,8 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -271,7 +277,8 @@ class ConvStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: if self.is_param("bias"): @@ -279,23 +286,27 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}" dim_partition_dict_mapping = { "input": { @@ -322,23 +333,27 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) input_comm_action = self.get_communication_action( sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"output": output_comm_action, "input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_in_channel_weight_in_channel(self, mesh_dim_0): - name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' + name = f"RR = RS{mesh_dim_0} x S{mesh_dim_0}R" dim_partition_dict_mapping = { "input": { @@ -360,17 +375,20 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_weight_out_channel(self, mesh_dim_0): - name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' + name = f"RS{mesh_dim_0} = RR x RS{mesh_dim_0}" dim_partition_dict_mapping = { "input": {}, @@ -395,17 +413,20 @@ class ConvStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def non_split(self): - name = f'RR = RR x RR' + name = f"RR = RR x RR" dim_partition_dict_mapping = { "input": {}, @@ -418,13 +439,13 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) + return self.get_sharding_strategy( + name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={} + ) @ignore_sharding_exception def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR" dim_partition_dict_mapping = { "input": { @@ -447,14 +468,16 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action @@ -464,23 +487,27 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R" dim_partition_dict_mapping = { "input": { 1: [mesh_dim_0, mesh_dim_1], @@ -501,17 +528,20 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}" dim_partition_dict_mapping = { "input": {}, "other": { @@ -535,13 +565,16 @@ class ConvStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategies = [] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py index 82a04ab52e739ae3db29efde2a66f30ff24cb8d0..385a8886f2318220530bdc2448bcee8e83a5b8f0 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py @@ -1,11 +1,9 @@ import copy import operator -import warnings from functools import reduce from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, CommType, MemoryCost, ShardingStrategy, @@ -27,16 +25,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator): return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. Note: The computation cost for the embedding handler is estimated as dense computing now. It may not be accurate. - ''' + """ # TODO: estimate the embedding computation cost as sparse operation - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() input_size_product = reduce(operator.mul, sharded_input_shape) other_size_product = reduce(operator.mul, sharded_other_shape) @@ -55,9 +53,9 @@ class EmbeddingStrategyGenerator(StrategyGenerator): def update_memory_cost(self, strategy: ShardingStrategy): forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -75,14 +73,15 @@ class EmbeddingStrategyGenerator(StrategyGenerator): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @ignore_sharding_exception def non_split(self): - name = f'RR = R x RR' + name = f"RR = R x RR" dim_partition_dict_mapping = { "input": {}, @@ -92,18 +91,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) + return self.get_sharding_strategy( + name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={} + ) @ignore_sharding_exception def split_input(self, mesh_dim_0): - name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR' + name = f"S{mesh_dim_0}R = S{mesh_dim_0} x RR" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, + "input": {0: [mesh_dim_0]}, "other": {}, "output": { 0: [mesh_dim_0], @@ -118,7 +115,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -126,17 +124,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}" dim_partition_dict_mapping = { "input": { @@ -159,7 +160,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} if self.is_param("other"): @@ -167,7 +169,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -175,22 +178,23 @@ class EmbeddingStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, "other": {}, "output": { 0: [mesh_dim_0, mesh_dim_1], @@ -207,7 +211,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -215,17 +220,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_embedding_dim(self, mesh_dim_0): - name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}' + name = f"RS{mesh_dim_0} = R x RS{mesh_dim_0}" dim_partition_dict_mapping = { "input": {}, @@ -245,17 +253,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}" dim_partition_dict_mapping = { "input": {}, @@ -275,13 +286,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategies = [] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py index bbeb9a639c835869634417e5ceff1a2cad082339..cc8d5771f28e9cb012672713e81eb449e96c944e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py @@ -10,7 +10,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException from .strategy_generator import StrategyGenerator -__all__ = ['GetattrGenerator'] +__all__ = ["GetattrGenerator"] class GetattrGenerator(StrategyGenerator): @@ -26,10 +26,10 @@ class GetattrGenerator(StrategyGenerator): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' - forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + """ + forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")} # compute fwd cost incurred # fwd_cost = output @@ -47,7 +47,7 @@ class GetattrGenerator(StrategyGenerator): def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): # we check for the output logical shape to get the number of dimensions dim_partition_list = [] - dim_size = len(self.op_data['output'].logical_shape) + dim_size = len(self.op_data["output"].logical_shape) # enumerate all the 2D sharding cases sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) @@ -78,7 +78,8 @@ class GetattrGenerator(StrategyGenerator): sharding_strategy = self.get_sharding_strategy( name=name, sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(sharding_strategy) except ShardingSpecException: continue diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py index 0aeb2e0d4079ea1b302d580554ed7ca24ab7096d..6f01d9cc7f8ef06a0e4b1672a8f98e7db68364ca 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py @@ -1,19 +1,13 @@ import copy from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem from colossalai.logging import get_dist_logger -from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.sharding_spec import ShardingSpecException from .strategy_generator import FollowingStrategyGenerator -__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator'] +__all__ = ["GetItemStrategyGenerator", "TensorStrategyGenerator", "TensorTupleStrategyGenerator"] class GetItemStrategyGenerator(FollowingStrategyGenerator): @@ -35,12 +29,12 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -58,27 +52,29 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost class TensorStrategyGenerator(GetItemStrategyGenerator): - ''' + """ Deal with case 1 and 2. - ''' + """ def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] - getitem_index = self.op_data['index'].data + getitem_index = self.op_data["index"].data for index, strategy in enumerate(self.predecessor_node.strategies_vector): try: logger = get_dist_logger() dim_partition_dict_mapping = {} communication_action_mapping = {} dim_partition_dict_for_input = copy.deepcopy( - strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict) + strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict + ) int_index = False if isinstance(getitem_index, int): @@ -120,9 +116,11 @@ class TensorStrategyGenerator(GetItemStrategyGenerator): name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) except ShardingSpecException as e: logger.debug(e) continue @@ -137,9 +135,9 @@ class TensorStrategyGenerator(GetItemStrategyGenerator): class TensorTupleStrategyGenerator(GetItemStrategyGenerator): - ''' + """ Deal with case 3. - ''' + """ def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] @@ -158,13 +156,15 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator): sharding_spec_mapping["input"] = sharding_spec_for_input input_sharding_info = f"get the {index} element from (" for sharding_spec in sharding_spec_for_input: - input_sharding_info += f'{sharding_spec.sharding_sequence}, ' + input_sharding_info += f"{sharding_spec.sharding_sequence}, " input_sharding_info += ")" name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py index fbb6070f7e82c9a41848c626c6271d1a7b9d73ee..e5b7e6f25d4d9007c7e7900632345190336dc160 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py @@ -18,7 +18,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator -__all__ = ['LayerNormGenerator'] +__all__ = ["LayerNormGenerator"] class LayerNormGenerator(StrategyGenerator): @@ -31,21 +31,21 @@ class LayerNormGenerator(StrategyGenerator): return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. - Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. - ''' - # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. + """ + # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size. # TODO: a constant coefficient need to be added. - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_weight_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() if self.has_bias: # bias add is an element wise operation, so the cost is equal to product of output shape. bias_compute_cost = reduce(operator.mul, sharded_weight_shape) # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization. - input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)] + input_batch_shape = sharded_input_shape[: -len(sharded_weight_shape)] input_batch_product = reduce(operator.mul, input_batch_shape, 1) norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1) forward_compute_cost = input_batch_product * norm_kernel_product @@ -62,18 +62,18 @@ class LayerNormGenerator(StrategyGenerator): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - forward_size_mapping['bias'] = bias_size + forward_size_mapping["bias"] = bias_size backward_size_mapping = copy.deepcopy(forward_size_mapping) backward_size_mapping.pop("output") @@ -90,8 +90,9 @@ class LayerNormGenerator(StrategyGenerator): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -120,7 +121,8 @@ class LayerNormGenerator(StrategyGenerator): sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=total_mesh_dim_list, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: @@ -128,12 +130,15 @@ class LayerNormGenerator(StrategyGenerator): sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=total_mesh_dim_list, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) communication_action_mapping["bias"] = bias_comm_action - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy @@ -155,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator): @ignore_sharding_exception def non_split(self): - name = f'RR = RR x R' + name = f"RR = RR x R" dim_partition_dict_mapping = { "input": {}, "other": {}, @@ -168,14 +173,16 @@ class LayerNormGenerator(StrategyGenerator): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: - ''' + """ Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector. - ''' + """ strategy_list = [] input_data_dim = len(self.op_data["input"].logical_shape) weight_data_dim = len(self.op_data["other"].logical_shape) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index 1ce5a08f2d6b70d20f10476309034ab1a26b75d1..fb182afb917561b6836a18c7abc483d499ab8464 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -1,5 +1,4 @@ import operator -from ast import arg from functools import reduce from typing import List @@ -24,14 +23,14 @@ class MatMulStrategyGenerator(StrategyGenerator): def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - size_mapping['bias'] = bias_size + size_mapping["bias"] = bias_size # compute fwd cost incurred # fwd_cost = input + other + bias + output @@ -41,45 +40,47 @@ class MatMulStrategyGenerator(StrategyGenerator): # compute bwd cost incurred # bwd_cost = input_grad + bias_grad - bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']]) + bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ["input", "other", "bias"]]) bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + 0) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + 0 + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost class DotProductStrategyGenerator(MatMulStrategyGenerator): - def validate(self) -> bool: - input_op_data = self.op_data['input'] - other_op_data = self.op_data['other'] + input_op_data = self.op_data["input"] + other_op_data = self.op_data["other"] assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1 def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() fwd_compute_cost = sharded_input_shape[0] bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) return compute_cost @ignore_sharding_exception def no_split(self): - name = f'R = R dot R' - dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}} + name = f"R = R dot R" + dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_one_dim(self, mesh_dim): - name = f'R = S{mesh_dim} dot S{mesh_dim}' + name = f"R = S{mesh_dim} dot S{mesh_dim}" # get sharding spec dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}} @@ -87,14 +88,17 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator): # get communication action output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] @@ -112,19 +116,18 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator): class MatVecStrategyGenerator(MatMulStrategyGenerator): - def validate(self) -> bool: - input_op_data = self.op_data['input'] - other_op_data = self.op_data['other'] + input_op_data = self.op_data["input"] + other_op_data = self.op_data["other"] assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1 def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() fwd_compute_cost = sharded_input_shape[0] bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) return compute_cost @ignore_sharding_exception @@ -133,67 +136,69 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): dim_partition_dict = {"input": {}, "other": {}, "output": {}} if self.has_bias: - dim_partition_dict['bias'] = {} + dim_partition_dict["bias"] = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) + return self.get_sharding_strategy( + name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={} + ) @ignore_sharding_exception def split_input_batch(self, mesh_dim): - name = f'S{mesh_dim}R = S{mesh_dim}R x R' + name = f"S{mesh_dim}R = S{mesh_dim}R x R" # get sharding spec dim_partition_dict = { - "input": { - 0: [mesh_dim] - }, + "input": {0: [mesh_dim]}, "other": {}, - "output": { - 0: [mesh_dim] - }, + "output": {0: [mesh_dim]}, } if self.has_bias: - dim_partition_dict['bias'] = {} + dim_partition_dict["bias"] = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication action communication_action_mapping = {} - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=1) - communication_action_mapping['other'] = other_comm_action + arg_index=1, + ) + communication_action_mapping["other"] = other_comm_action if self.has_bias: - if self.is_param('bias'): + if self.is_param("bias"): bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=2) - communication_action_mapping['bias'] = bias_comm_action + arg_index=2, + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] @@ -209,12 +214,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): - - def __init__(self, - operation_data_mapping, - device_mesh, - linear_projection_type='linear', - solver_perference=SolverPerference.STANDARD): + def __init__( + self, + operation_data_mapping, + device_mesh, + linear_projection_type="linear", + solver_perference=SolverPerference.STANDARD, + ): super().__init__(operation_data_mapping, device_mesh) self.linear_projection_type = linear_projection_type self.solver_perference = solver_perference @@ -224,17 +230,17 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # C: [M, N], A: [M, P], B: [P, N] # fwd cost = MNP (only count mul) # bwd: 2 x fwd_cost - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() dim_m_val = reduce(operator.mul, sharded_input_shape[:-1]) dim_n_val = sharded_other_shape[-1] dim_p_val = sharded_other_shape[0] fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=bwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=bwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) strategy.compute_cost = compute_cost def dp_strategies(self) -> List[ShardingStrategy]: @@ -301,28 +307,21 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): @ignore_sharding_exception def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): # handle case SS = SR x RS - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, - "other": { - -1: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0]}, + "other": {-1: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], -1: [mesh_dim_1]}, } # linear bias only has one dimension, but addmm bias has same dimensions # as the output logically. - if self.linear_projection_type == 'linear': - dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]} - elif self.linear_projection_type == 'addmm': - dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]} + if self.linear_projection_type == "linear": + dim_partition_dict_mapping["bias"] = {-1: [mesh_dim_1]} + elif self.linear_projection_type == "addmm": + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0], -1: [mesh_dim_1]} else: - raise ('Unsupported linear projection type') + raise ("Unsupported linear projection type") sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) @@ -333,75 +332,75 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) - communication_action_mapping['input'] = input_comm_action - communication_action_mapping['other'] = other_comm_action + communication_action_mapping["input"] = input_comm_action + communication_action_mapping["other"] = other_comm_action # we only add allreduce comm action for linear bias, because # allreduce comm action for addmm bias will be considered in post processing - if self.has_bias and self.linear_projection_type == 'linear': - if self.is_param('bias'): + if self.has_bias and self.linear_projection_type == "linear": + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') - communication_action_mapping['bias'] = bias_comm_action + key_for_kwarg="bias", + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): # handle the case SR = SS x SR - name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R" # get sharding spec mapping dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, - "other": { - 0: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0], -1: [mesh_dim_1]}, + "other": {0: [mesh_dim_1]}, "bias": {}, - "output": { - 0: [mesh_dim_0] - }, + "output": {0: [mesh_dim_0]}, } # linear bias only has one dimension, but addmm bias has same dimensions # as the output logically. - if self.linear_projection_type == 'linear': - dim_partition_dict_mapping['bias'] = {} - elif self.linear_projection_type == 'addmm': - dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]} + if self.linear_projection_type == "linear": + dim_partition_dict_mapping["bias"] = {} + elif self.linear_projection_type == "addmm": + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]} else: - raise ('Unsupported linear projection type') + raise ("Unsupported linear projection type") sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) @@ -412,66 +411,64 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) - communication_action_mapping['other'] = other_comm_action - communication_action_mapping['output'] = output_comm_action + communication_action_mapping["other"] = other_comm_action + communication_action_mapping["output"] = output_comm_action # we only add allreduce comm action for linear bias, because # allreduce comm action for addmm bias will be considered in post processing - if self.has_bias and self.linear_projection_type == 'linear': - if self.is_param('bias'): + if self.has_bias and self.linear_projection_type == "linear": + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') - communication_action_mapping['bias'] = bias_comm_action + key_for_kwarg="bias", + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}" # get sharding specs dim_partition_dict_mapping = { - "input": { - -1: [mesh_dim_0] - }, - "other": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, - "bias": { - -1: [mesh_dim_1] - }, - "output": { - -1: [mesh_dim_1] - }, + "input": {-1: [mesh_dim_0]}, + "other": {0: [mesh_dim_0], -1: [mesh_dim_1]}, + "bias": {-1: [mesh_dim_1]}, + "output": {-1: [mesh_dim_1]}, } # We don't have to do anything special for bias here, because @@ -482,34 +479,34 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping["input"] = input_comm_action - communication_action_mapping['output'] = output_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping["output"] = output_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def recompute_split_both_contract(self, mesh_dim): - name = f'RR = RS{mesh_dim} x S{mesh_dim}R' + name = f"RR = RS{mesh_dim} x S{mesh_dim}R" # get sharding spec dim_partition_dict_mapping = { - "input": { - -1: [mesh_dim] - }, - "other": { - 0: [mesh_dim] - }, + "input": {-1: [mesh_dim]}, + "other": {0: [mesh_dim]}, "bias": {}, "output": {}, } @@ -520,32 +517,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication action communication_action_mapping = {} output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) - communication_action_mapping['output'] = output_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping["output"] = output_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_rhs_space_only(self, mesh_dim): - name = f'RS{mesh_dim} = RR x RS{mesh_dim}' + name = f"RS{mesh_dim} = RR x RS{mesh_dim}" # get sharding spec dim_partition_dict_mapping = { "input": {}, - "other": { - -1: [mesh_dim] - }, - "bias": { - -1: [mesh_dim] - }, - "output": { - -1: [mesh_dim] - }, + "other": {-1: [mesh_dim]}, + "bias": {-1: [mesh_dim]}, + "output": {-1: [mesh_dim]}, } # We don't have to do anything special for bias here, because # the bias is already the same sharding spec as the output. @@ -554,93 +548,94 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) - communication_action_mapping['input'] = input_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping["input"] = input_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR" # get sharding spec dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, "other": {}, "bias": {}, - "output": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "output": {0: [mesh_dim_0, mesh_dim_1]}, } # linear bias only has one dimension, but addmm bias has same dimensions # as the output logically. - if self.linear_projection_type == 'linear': - dim_partition_dict_mapping['bias'] = {} - elif self.linear_projection_type == 'addmm': - dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]} + if self.linear_projection_type == "linear": + dim_partition_dict_mapping["bias"] = {} + elif self.linear_projection_type == "addmm": + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]} else: - raise ('Unsupported linear projection type') + raise ("Unsupported linear projection type") sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action communication_action_mapping = {} - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=1) - communication_action_mapping['other'] = other_comm_action + arg_index=1, + ) + communication_action_mapping["other"] = other_comm_action # we only add allreduce comm action for linear bias, because # allreduce comm action for addmm bias will be considered in post processing - if self.has_bias and self.linear_projection_type == 'linear': - if self.is_param('bias'): + if self.has_bias and self.linear_projection_type == "linear": + if self.is_param("bias"): bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - key_for_kwarg='bias') - communication_action_mapping['bias'] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + key_for_kwarg="bias", + ) + communication_action_mapping["bias"] = bias_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R" # get sharding spec dim_partition_dict_mapping = { - "input": { - -1: [mesh_dim_0, mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {-1: [mesh_dim_0, mesh_dim_1]}, + "other": {0: [mesh_dim_0, mesh_dim_1]}, "bias": {}, "output": {}, } @@ -652,32 +647,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication action communication_action_mapping = {} output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.AFTER) - communication_action_mapping['output'] = output_comm_action + comm_type=CommType.AFTER, + ) + communication_action_mapping["output"] = output_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}" # get sharding spec dim_partition_dict_mapping = { "input": {}, - "other": { - -1: [mesh_dim_0, mesh_dim_1] - }, - "bias": { - -1: [mesh_dim_0, mesh_dim_1] - }, - "output": { - -1: [mesh_dim_0, mesh_dim_1] - }, + "other": {-1: [mesh_dim_0, mesh_dim_1]}, + "bias": {-1: [mesh_dim_0, mesh_dim_1]}, + "output": {-1: [mesh_dim_0, mesh_dim_1]}, } # We don't have to do anything special for bias here, because @@ -687,20 +679,23 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication action communication_action_mapping = {} input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['input'] = input_comm_action + arg_index=0, + ) + communication_action_mapping["input"] = input_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def non_split(self): - name = f'RR = RR x RR' + name = f"RR = RR x RR" # get sharding spec dim_partition_dict_mapping = { @@ -717,22 +712,24 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # get communication action communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def validate(self) -> bool: assert "input" in self.op_data assert "other" in self.op_data # make sure the other has 2 dim - input_data = self.op_data['input'] - other_data = self.op_data['other'] + input_data = self.op_data["input"] + other_data = self.op_data["other"] assert input_data.data.dim() > 0 and other_data.data.dim() == 2 assert other_data.logical_shape[0] == input_data.logical_shape[-1] if self.has_bias: - bias_data = self.op_data['bias'] + bias_data = self.op_data["bias"] assert bias_data.logical_shape[-1] == other_data.logical_shape[-1] @@ -757,37 +754,38 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): def _pop_batch_dim_sharding_for_output(self, dim_partition_dict): # remove partition dict for dim 0 - dim_partition_dict['output'].pop(0, None) + dim_partition_dict["output"].pop(0, None) # decrease the remaining dim index by 1 temp_dim_partition = {} - keys = list(dim_partition_dict['output'].keys()) + keys = list(dim_partition_dict["output"].keys()) for key in keys: - val = dim_partition_dict['output'].pop(key) + val = dim_partition_dict["output"].pop(key) temp_dim_partition[key - 1] = val - dim_partition_dict['output'].update(temp_dim_partition) + dim_partition_dict["output"].update(temp_dim_partition) def validate(self) -> bool: - input_op_data = self.op_data['input'] - other_op_data = self.op_data['other'] + input_op_data = self.op_data["input"] + other_op_data = self.op_data["other"] assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3 - if 'bias' in self.op_data: - bias_op_data = self.op_data['bias'] + if "bias" in self.op_data: + bias_op_data = self.op_data["bias"] assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2 def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul, - self.op_data['output'].data.shape) + fwd_compute_cost = self.op_data["input"].data.shape[-1] * reduce( + operator.mul, self.op_data["output"].data.shape + ) bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) strategy.compute_cost = compute_cost @ignore_sharding_exception def split_one_batch_dim(self, mesh_dim): - name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}' + name = f"Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}" # get sharding_spec dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}} @@ -799,30 +797,27 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): communication_action_mapping = {} if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}' + name = f"Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, + "other": {0: [mesh_dim_0, mesh_dim_1]}, "bias": {}, - "output": { - 0: [mesh_dim_0, mesh_dim_1] - } + "output": {0: [mesh_dim_0, mesh_dim_1]}, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -832,35 +827,28 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): communication_action_mapping = {} if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}' + name = f"Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0] - }, - "bias": { - 0: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - } + "input": {0: [mesh_dim_0], 1: [mesh_dim_1]}, + "other": {0: [mesh_dim_0]}, + "bias": {0: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], 1: [mesh_dim_1]}, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -869,46 +857,40 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=1) - communication_action_mapping['other'] = other_comm_action + arg_index=1, + ) + communication_action_mapping["other"] = other_comm_action if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action # for addbmm case, other is the third argument instead of second. - communication_action_mapping['other'].arg_index += 1 + communication_action_mapping["other"].arg_index += 1 - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}' + name = f"Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0] - }, - "other": { - 0: [mesh_dim_0], - 2: [mesh_dim_1] - }, - "bias": { - 1: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - 2: [mesh_dim_1] - } + "input": {0: [mesh_dim_0]}, + "other": {0: [mesh_dim_0], 2: [mesh_dim_1]}, + "bias": {1: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], 2: [mesh_dim_1]}, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -917,43 +899,41 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['input'] = input_comm_action + arg_index=0, + ) + communication_action_mapping["input"] = input_comm_action if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.BEFORE) - communication_action_mapping['bias'] = bias_comm_action + comm_type=CommType.BEFORE, + ) + communication_action_mapping["bias"] = bias_comm_action # for addbmm case, other is the second argument instead of first. - communication_action_mapping['input'].arg_index += 1 + communication_action_mapping["input"].arg_index += 1 - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}' + name = f"Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0], - 2: [mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0], 2: [mesh_dim_1]}, + "other": {0: [mesh_dim_0], 1: [mesh_dim_1]}, "bias": {}, "output": { 0: [mesh_dim_0], - } + }, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -962,29 +942,33 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # get communication actions communication_action_mapping = {} output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.AFTER) - communication_action_mapping['output'] = output_comm_action + comm_type=CommType.AFTER, + ) + communication_action_mapping["output"] = output_comm_action if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] device_mesh_is_1d = True - if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape: + if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape: device_mesh_is_1d = False if device_mesh_is_1d: @@ -992,10 +976,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): # Sb = Sb x Sb # can be None as it is only for 1D device mesh # only for 1D device mesh - if len(self.device_mesh.mesh_shape) == 1: + if len(self.device_mesh.shape) == 1: mesh_dim = 0 else: - mesh_dim = self.device_mesh.mesh_shape.index(1) + mesh_dim = self.device_mesh.shape.index(1) strategy_list.append(self.split_one_batch_dim(mesh_dim)) else: # for 2D device mesh diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py index 9df6d2fbfa127b71eba66256fb27f204fb1da5fe..b307e38b5b6d15bd760b9fc03d0bbe010106b6df 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py @@ -17,32 +17,35 @@ class NormalPoolStrategyGenerator(StrategyGenerator): """ NormalPoolStrategyGenerator is a generic class to generate strategies for pool operation like MaxPoolxd. The reason we call this normal pool is AvgPoolxd and MaxPoolxd are taking the kernel size element from image, - and reduce them depening on the operation type. + and reduce them depending on the operation type. """ def validate(self) -> bool: - ''' + """ In sanity check, we need make sure the input data having correct dimension size. For Pool1d, the dim of input data should be 3([N, C, L]). For Pool2d, the dim of input data should be 4([N, C, H, W]). For Pool3d, the dim of input data should be 5([N, C, H, W, D]). - ''' - input_op_data = self.op_data['input'] + """ + input_op_data = self.op_data["input"] assert input_op_data.data.dim() in ( - 3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].' + 3, + 4, + 5, + ), f"We suppose the dim of input fed into Pool op should in range of [3, 5]." def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem: - ''' + """ Compute the computation cost per device with this specific strategy. - Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. - ''' - # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. + """ + # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size. # 1D: (Lout) * N * C * kernel # 2D: (H * W) * N * Cout * Cin * kernel # 3D: (H * W * D) * N * Cout * Cin * kernel - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() kernel_size = self.op_data["other"].data if isinstance(kernel_size, int): @@ -61,8 +64,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator): def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -88,12 +91,16 @@ class NormalPoolStrategyGenerator(StrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' + name = ( + f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' + ) communication_action_mapping = {} - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py index 69d1642d4f808038d0eeb58547a7b1c0604c85eb..33fb1ac5c5be779a5f65b44431e4ba84c878fab9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py @@ -12,7 +12,7 @@ from colossalai.device.device_mesh import DeviceMesh from .strategy_generator import OutputStrategyGenerator -__all__ = ['OutputGenerator'] +__all__ = ["OutputGenerator"] class OutputGenerator(OutputStrategyGenerator): @@ -20,8 +20,13 @@ class OutputGenerator(OutputStrategyGenerator): OutputGenerator is a generic class to generate strategies for Output Node. """ - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - predecessor_nodes: List[Node], output_option: str): + def __init__( + self, + operation_data_mapping: Dict[str, OperationData], + device_mesh: DeviceMesh, + predecessor_nodes: List[Node], + output_option: str, + ): super().__init__(operation_data_mapping, device_mesh, predecessor_nodes) self.output_option = output_option @@ -33,9 +38,9 @@ class OutputGenerator(OutputStrategyGenerator): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ fwd_mem_cost = MemoryCost(activation=0, parameter=0) bwd_mem_cost = MemoryCost(activation=0, parameter=0) @@ -65,16 +70,18 @@ class OutputGenerator(OutputStrategyGenerator): else: dim_partition_dict_for_output = tuple(dim_partition_dict_for_output) - dim_partition_dict_mapping['output'] = dim_partition_dict_for_output + dim_partition_dict_mapping["output"] = dim_partition_dict_for_output communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Replica Output' + name = "Replica Output" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]: @@ -82,19 +89,15 @@ class OutputGenerator(OutputStrategyGenerator): Generate distributed strategy for output node. """ # TODO: need to take care of the case when the first element of output only need to be sharded. - output_op_data = self.op_data['output'] + output_op_data = self.op_data["output"] if isinstance(output_op_data.data, tuple): length = len(output_op_data.data) dim_partition_dict_mapping = { - "output": [{ - 0: mesh_list - }] * length, + "output": [{0: mesh_list}] * length, } else: dim_partition_dict_mapping = { - "output": { - 0: mesh_list - }, + "output": {0: mesh_list}, } for index, _ in enumerate(self.predecessor_nodes): mapping_name = f"input_{index}" @@ -103,19 +106,21 @@ class OutputGenerator(OutputStrategyGenerator): communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Distributed Output' + name = "Distributed Output" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] mesh_list = [0, 1] - if self.output_option == 'replicated': + if self.output_option == "replicated": strategy_list.append(self.replica_strategy()) - elif self.output_option == 'distributed': + elif self.output_option == "distributed": strategy_list.append(self.distributed_strategy(mesh_list)) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py index 779a7ced93bb503c390bd89382d087230e48d2f0..df0862a396d2553ed7939dd09bea45d2317df211 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py @@ -10,7 +10,7 @@ from colossalai.device.device_mesh import DeviceMesh from .strategy_generator import StrategyGenerator -__all__ = ['PlaceholderGenerator'] +__all__ = ["PlaceholderGenerator"] class PlaceholderGenerator(StrategyGenerator): @@ -18,8 +18,9 @@ class PlaceholderGenerator(StrategyGenerator): PlaceholderGenerator is a generic class to generate strategies for placeholder node. """ - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - placeholder_option: str): + def __init__( + self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, placeholder_option: str + ): super().__init__(operation_data_mapping, device_mesh) self.placeholder_option = placeholder_option @@ -31,10 +32,10 @@ class PlaceholderGenerator(StrategyGenerator): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' - forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + """ + forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")} # compute fwd cost incurred # fwd_cost = output @@ -58,11 +59,13 @@ class PlaceholderGenerator(StrategyGenerator): communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Replica Placeholder' + name = "Replica Placeholder" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy @@ -71,29 +74,31 @@ class PlaceholderGenerator(StrategyGenerator): Generate distributed strategy for placeholder node. """ dim_partition_dict_mapping = { - "output": { - 0: mesh_list - }, + "output": {0: mesh_list}, } communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Distributed Placeholder' + name = "Distributed Placeholder" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] - if self.placeholder_option == 'distributed': + if self.placeholder_option == "distributed": mesh_list = [0, 1] distributed_strategy = self.distributed_placeholder(mesh_list) strategy_list.append(distributed_strategy) else: - assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported' + assert ( + self.placeholder_option == "replicated" + ), f"placeholder_option {self.placeholder_option} is not supported" replicated_strategy = self.replica_placeholder() strategy_list.append(replicated_strategy) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py index 24f75e352935f149e02c399b2c8e90c0f3ddc2f7..48f454553ac745fc5bfc84a59def4db7790897f0 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py @@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.utils import ( from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.sharding_spec import ShardingSpec -__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator'] +__all__ = ["ReshapeGenerator", "ViewGenerator", "PermuteGenerator", "TransposeGenerator", "SplitGenerator"] class ReshapeGenerator(FollowingStrategyGenerator): @@ -33,12 +33,12 @@ class ReshapeGenerator(FollowingStrategyGenerator): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -56,8 +56,9 @@ class ReshapeGenerator(FollowingStrategyGenerator): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -77,8 +78,8 @@ class ViewGenerator(ReshapeGenerator): communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] - origin_shape = self.op_data['input'].data.shape - tgt_shape = self.op_data['tgt_shape'].data + origin_shape = self.op_data["input"].data.shape + tgt_shape = self.op_data["tgt_shape"].data reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) @@ -86,8 +87,9 @@ class ViewGenerator(ReshapeGenerator): keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict) if keep_sharding_status: - dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input, - reshape_mapping_dict) + dim_partition_dict_for_output = infer_output_dim_partition_dict( + dim_partition_dict_for_input, reshape_mapping_dict + ) else: dim_partition_dict_for_output = {} @@ -119,7 +121,8 @@ class ViewGenerator(ReshapeGenerator): communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, logical_process_axis=total_mesh_dim_list, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) # it will gather the input through gather_dim during forward phase. input_comm_action.comm_spec.gather_dim = shard_dim # it will split the input activation grad through shard_dim during backward phase. @@ -127,10 +130,10 @@ class ViewGenerator(ReshapeGenerator): elif len(total_mesh_dim_list) >= 2: source_spec = sharding_spec_mapping["input"] - target_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=source_spec.entire_shape, - dim_partition_dict={}) - comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + target_spec = ShardingSpec( + device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={} + ) + comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec} input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) else: @@ -139,9 +142,11 @@ class ViewGenerator(ReshapeGenerator): if input_comm_action is not None: communication_action_mapping["input"] = input_comm_action - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -159,7 +164,7 @@ class PermuteGenerator(ReshapeGenerator): communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] - permute_dims = self.op_data['permute_dims'].data + permute_dims = self.op_data["permute_dims"].data dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict dim_partition_dict_for_output = {} for dim_index, permute_dim in enumerate(permute_dims): @@ -177,9 +182,11 @@ class PermuteGenerator(ReshapeGenerator): # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -199,7 +206,7 @@ class TransposeGenerator(ReshapeGenerator): dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict dim_partition_dict_for_output = {} - transpose_dims = self.op_data['transpose_dims'].data + transpose_dims = self.op_data["transpose_dims"].data dim_0 = transpose_dims[0] dim_1 = transpose_dims[1] for dim, sharded_dims in dim_partition_dict_for_input.items(): @@ -221,9 +228,11 @@ class TransposeGenerator(ReshapeGenerator): # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -242,7 +251,7 @@ class SplitGenerator(ReshapeGenerator): communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) - split_size, split_dim = self.op_data['split_info'].data + split_size, split_dim = self.op_data["split_info"].data if split_dim in dim_partition_dict_for_input: recover_dims = dim_partition_dict_for_input.pop(split_dim) @@ -271,7 +280,8 @@ class SplitGenerator(ReshapeGenerator): communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, logical_process_axis=recover_dims, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) # it will gather the input through gather_dim during forward phase. input_comm_action.comm_spec.gather_dim = split_dim # it will split the input activation grad through split_dim during backward phase. @@ -282,7 +292,7 @@ class SplitGenerator(ReshapeGenerator): source_spec = input_sharding_spec # target sharding spec target_spec = sharding_spec_mapping["input"] - comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec} input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) else: @@ -291,9 +301,11 @@ class SplitGenerator(ReshapeGenerator): if input_comm_action is not None: communication_action_mapping["input"] = input_comm_action - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -341,16 +353,17 @@ class DefaultReshapeGenerator(ReshapeGenerator): communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, logical_process_axis=total_mesh_dim_list, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) input_comm_action.comm_spec.gather_dim = total_mesh_dim_list input_comm_action.comm_spec.shard_dim = total_mesh_dim_list elif len(total_mesh_dim_list) >= 2: source_spec = sharding_spec_mapping["input"] - target_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=source_spec.entire_shape, - dim_partition_dict={}) - comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + target_spec = ShardingSpec( + device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={} + ) + comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec} input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) else: @@ -358,9 +371,11 @@ class DefaultReshapeGenerator(ReshapeGenerator): if input_comm_action is not None: communication_action_mapping["input"] = input_comm_action - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py index a1ebadd043e2c2e563fcdc611567bca3ededfa51..d4382f9941d2b21475ef7d077a0af92054aae9b0 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py @@ -4,21 +4,9 @@ from functools import reduce from typing import List from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) -from colossalai.auto_parallel.tensor_shard.utils import ( - check_keep_sharding_status, - detect_reshape_mapping, - infer_output_dim_partition_dict, -) -from colossalai.tensor.shape_consistency import CollectiveCommPattern - -__all__ = ['SoftmaxGenerator'] +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem + +__all__ = ["SoftmaxGenerator"] class SoftmaxGenerator(FollowingStrategyGenerator): @@ -30,11 +18,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator): return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. - ''' - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + """ + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() input_size_product = reduce(operator.mul, sharded_input_shape) output_size_product = reduce(operator.mul, sharded_output_shape) @@ -45,12 +33,12 @@ class SoftmaxGenerator(FollowingStrategyGenerator): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -68,8 +56,9 @@ class SoftmaxGenerator(FollowingStrategyGenerator): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -80,10 +69,10 @@ class SoftmaxGenerator(FollowingStrategyGenerator): communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) - softmax_dim = self.op_data['softmax_dim'].data + softmax_dim = self.op_data["softmax_dim"].data if softmax_dim in dim_partition_dict_for_input: - recover_dims = dim_partition_dict_for_input.pop(softmax_dim) + dim_partition_dict_for_input.pop(softmax_dim) dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) dim_partition_dict_mapping = { @@ -96,9 +85,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator): # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py index 6d68521aaea7989c085c24f32a8dc92f4b1b71fc..7bf2c8cc12a39730ea79dccb2ef810b32e0b95a9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -39,7 +39,7 @@ class StrategyGenerator(ABC): """ A utility method to check for the existence of bias operand for convenience. """ - return 'bias' in self.op_data + return "bias" in self.op_data def is_param(self, op_data_name): other_data = self.op_data[op_data_name] @@ -49,8 +49,12 @@ class StrategyGenerator(ABC): other_data = self.op_data[op_data_name] return other_data.type == OperationDataType.BUFFER - def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec], - communication_action_mapping: Dict[str, CommSpec]): + def get_sharding_strategy( + self, + name: str, + sharding_spec_mapping: Dict[str, ShardingSpec], + communication_action_mapping: Dict[str, CommSpec], + ): """ A factory method to produce a ShardingStrategy object. @@ -80,24 +84,28 @@ class StrategyGenerator(ABC): op_data = self.op_data[op_data_name] def _to_sharding_spec( - data: any, logical_shape: any, - dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]: + data: any, logical_shape: any, dim_partition_dict: Dict[int, List[int]] + ) -> Union[ShardingSpec, List[ShardingSpec], None]: """ This is a recursive function to convert the dim partition dict to a ShardingSpec object. """ if isinstance(data, torch.Tensor): dim_size = len(logical_shape) dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict) - sharding_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=logical_shape, - dim_partition_dict=dim_partition_dict) + sharding_spec = ShardingSpec( + device_mesh=self.device_mesh, + entire_shape=logical_shape, + dim_partition_dict=dim_partition_dict, + ) return sharding_spec elif isinstance(data, (list, tuple)): sharding_spec = [] for data_element, logical_shape_element, dim_partition_dict_element in zip( - data, logical_shape, dim_partition_dict): + data, logical_shape, dim_partition_dict + ): sharding_spec.append( - _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)) + _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element) + ) return sharding_spec else: return None @@ -116,31 +124,41 @@ class StrategyGenerator(ABC): results[op_data] = v return results - def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern, - logical_process_axis: Union[int, List[int]]): + def get_communication_spec( + self, + sharding_spec: ShardingSpec, + communication_pattern: CollectiveCommPattern, + logical_process_axis: Union[int, List[int]], + ): """ A factory method to produce a CommSpec object. """ - return CommSpec(comm_pattern=communication_pattern, - sharding_spec=sharding_spec, - logical_process_axis=logical_process_axis) - - def get_communication_action(self, - sharding_spec: ShardingSpec, - communication_pattern: CollectiveCommPattern, - logical_process_axis: Union[int, List[int]], - comm_type: CommType, - arg_index: int = -1, - key_for_kwarg: any = None) -> CommAction: + return CommSpec( + comm_pattern=communication_pattern, sharding_spec=sharding_spec, logical_process_axis=logical_process_axis + ) + + def get_communication_action( + self, + sharding_spec: ShardingSpec, + communication_pattern: CollectiveCommPattern, + logical_process_axis: Union[int, List[int]], + comm_type: CommType, + arg_index: int = -1, + key_for_kwarg: any = None, + ) -> CommAction: """ A factory method to produce a CommAction object. """ - return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec, - communication_pattern=communication_pattern, - logical_process_axis=logical_process_axis), - comm_type=comm_type, - arg_index=arg_index, - key_for_kwarg=key_for_kwarg) + return CommAction( + comm_spec=self.get_communication_spec( + sharding_spec=sharding_spec, + communication_pattern=communication_pattern, + logical_process_axis=logical_process_axis, + ), + comm_type=comm_type, + arg_index=arg_index, + key_for_kwarg=key_for_kwarg, + ) def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ @@ -155,9 +173,9 @@ class StrategyGenerator(ABC): size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() for phase, cost in num_ele_in_comm.items(): num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes - comm_cost.fwd += num_ele_in_comm['forward'] - comm_cost.bwd += num_ele_in_comm['backward'] - comm_cost.total += num_ele_in_comm['total'] + comm_cost.fwd += num_ele_in_comm["forward"] + comm_cost.bwd += num_ele_in_comm["backward"] + comm_cost.total += num_ele_in_comm["total"] # check if communication action exists # if so, loop over each action and compute the cost of each action @@ -169,8 +187,8 @@ class StrategyGenerator(ABC): # this condition branch will be removed after all the handler updated. comm_spec = comm_action if isinstance(comm_spec, dict): - src_spec = comm_spec['src_spec'] - tgt_spec = comm_spec['tgt_spec'] + src_spec = comm_spec["src_spec"] + tgt_spec = comm_spec["tgt_spec"] shape_consistency_manager = ShapeConsistencyManager() _, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec) for comm_spec_ in comm_action_sequence: @@ -187,14 +205,12 @@ class StrategyGenerator(ABC): """ Customize this method to compute the computation flops. """ - pass @abstractmethod def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ Customize this method to compute the memory cost in bytes. """ - pass def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str): """ @@ -212,20 +228,21 @@ class StrategyGenerator(ABC): num_elements = 1 else: num_elements = reduce(operator.mul, sharded_shape) - dtype = getattr(meta_data, 'dtype') + dtype = getattr(meta_data, "dtype") size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() return num_elements * size_per_elem_bytes if isinstance(op_data.data, tuple): - assert isinstance(strategy.sharding_specs[op_data], list), \ - 'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.' + assert isinstance( + strategy.sharding_specs[op_data], list + ), "sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple." total_bytes = 0 for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]): meta_data = op_data.data[index] if isinstance(meta_data, torch.Tensor): element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data) else: - # if meta_data is not a tensor, we count the memroy as 0 + # if meta_data is not a tensor, we count the memory as 0 element_bytes = 0 total_bytes += element_bytes @@ -233,7 +250,7 @@ class StrategyGenerator(ABC): if isinstance(op_data.data, torch.Tensor): total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data) else: - # if op_data.data is not a tensor, we count the memroy as 0 + # if op_data.data is not a tensor, we count the memory as 0 total_bytes = 0 return total_bytes @@ -270,7 +287,6 @@ class StrategyGenerator(ABC): Validate if the operands are of desired shape. If True, means this generator can be used for the current operation. """ - pass class FollowingStrategyGenerator(StrategyGenerator): @@ -280,8 +296,9 @@ class FollowingStrategyGenerator(StrategyGenerator): TODO: remove the original strategy_generator.py after refactoring """ - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - predecessor_node: Node): + def __init__( + self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_node: Node + ): self.op_data = operation_data_mapping self.device_mesh = device_mesh self.predecessor_node = predecessor_node @@ -292,7 +309,8 @@ class OutputStrategyGenerator(StrategyGenerator): OutputStrategyGenerator is used to generate the sharding strategies for Output Node. """ - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - predecessor_nodes: List[Node]): + def __init__( + self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_nodes: List[Node] + ): super().__init__(operation_data_mapping, device_mesh) self.predecessor_nodes = predecessor_nodes diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py index a0fbc58d70c0feba2c78305fb14d9bcb38a82e41..dcbf34cfd65b07cafc07f74b10bfa1d66057a3a2 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py @@ -4,22 +4,9 @@ from functools import reduce from typing import List from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) -from colossalai.auto_parallel.tensor_shard.utils import ( - check_keep_sharding_status, - detect_reshape_mapping, - infer_output_dim_partition_dict, -) -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from colossalai.tensor.sharding_spec import ShardingSpec - -__all__ = ['SumGenerator'] +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem + +__all__ = ["SumGenerator"] class SumGenerator(FollowingStrategyGenerator): @@ -31,24 +18,24 @@ class SumGenerator(FollowingStrategyGenerator): return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() input_size_product = reduce(operator.mul, sharded_input_shape) output_size_product = reduce(operator.mul, sharded_output_shape) - compute_cost = TrainCycleItem(fwd=input_size_product, - bwd=output_size_product, - total=input_size_product + output_size_product) + compute_cost = TrainCycleItem( + fwd=input_size_product, bwd=output_size_product, total=input_size_product + output_size_product + ) strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -66,8 +53,9 @@ class SumGenerator(FollowingStrategyGenerator): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -78,7 +66,7 @@ class SumGenerator(FollowingStrategyGenerator): communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) - sum_dims, sum_mapping_dict = self.op_data['sum_info'].data + sum_dims, sum_mapping_dict = self.op_data["sum_info"].data # TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce # among all the shard groups @@ -90,7 +78,7 @@ class SumGenerator(FollowingStrategyGenerator): elif dim in sum_mapping_dict: dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim] else: - raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims') + raise RuntimeError(f"dim {dim} is not in sum_mapping_dict or sum_dims") for dim in recover_dims: dim_partition_dict_for_input.pop(dim) @@ -105,9 +93,11 @@ class SumGenerator(FollowingStrategyGenerator): # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py index 93cfc9eeea532ac4383f0821008deeccb13951d0..eea00c2fa064093f404f267baa3d24f3fc4c909c 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py @@ -1,19 +1,10 @@ -import copy from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem from .strategy_generator import StrategyGenerator -__all__ = ['TensorConstructorGenerator'] +__all__ = ["TensorConstructorGenerator"] class TensorConstructorGenerator(StrategyGenerator): @@ -30,10 +21,10 @@ class TensorConstructorGenerator(StrategyGenerator): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' - forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + """ + forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")} # compute fwd cost incurred # fwd_cost = input + output @@ -57,11 +48,13 @@ class TensorConstructorGenerator(StrategyGenerator): communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Replica Tensor Constructor' + name = "Replica Tensor Constructor" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py index b867a30686eb97a55096895d344dcc28b51f347a..943cf3f1f50db8e37af9e02539b3269565e03bfd 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py @@ -1,11 +1,11 @@ import copy from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem from .strategy_generator import FollowingStrategyGenerator -__all__ = ['UnaryElementwiseGenerator'] +__all__ = ["UnaryElementwiseGenerator"] class UnaryElementwiseGenerator(FollowingStrategyGenerator): @@ -21,12 +21,12 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -44,8 +44,9 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -69,9 +70,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator): # we keep same strategies with different name for node merging, and it will not increase the searching space, # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py index fa941f2cc51dc4d817bfc8f49c54bbaf7a8a5407..b27b4f3d40569ccc0e3e3d2a95f0146706145976 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py @@ -10,7 +10,7 @@ from colossalai.auto_parallel.tensor_shard.utils import ( from .strategy_generator import StrategyGenerator -__all__ = ['WhereGenerator'] +__all__ = ["WhereGenerator"] class WhereGenerator(StrategyGenerator): @@ -26,14 +26,14 @@ class WhereGenerator(StrategyGenerator): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'condition': self._compute_size_in_bytes(strategy, "condition"), - 'x': self._compute_size_in_bytes(strategy, "x"), - 'y': self._compute_size_in_bytes(strategy, "y"), - 'output': self._compute_size_in_bytes(strategy, "output") + "condition": self._compute_size_in_bytes(strategy, "condition"), + "x": self._compute_size_in_bytes(strategy, "x"), + "y": self._compute_size_in_bytes(strategy, "y"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -59,7 +59,7 @@ class WhereGenerator(StrategyGenerator): "condition": dim_partition, "x": dim_partition, "y": dim_partition, - "output": dim_partition + "output": dim_partition, } sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) @@ -67,9 +67,11 @@ class WhereGenerator(StrategyGenerator): name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}' communication_action_mapping = {} - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy @@ -84,9 +86,9 @@ class WhereGenerator(StrategyGenerator): return dim_partition_list def collate_strategies(self) -> List[ShardingStrategy]: - ''' + """ Generate every possible strategies for a where node, and record all strategies into the strategies_vector. - ''' + """ strategy_list = [] dimension_length = len(self.op_data["output"].logical_shape) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py index 86f90694e0604f72e9564020ccab455cfdee29a0..5b4ea0afe5f8157a5d63e7ac3bd8621c626fd6a1 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py @@ -7,7 +7,7 @@ from .node_handler import NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, SumGenerator -__all__ = ['SumHandler'] +__all__ = ["SumHandler"] @operator_registry.register(torch.Tensor.sum) @@ -55,7 +55,7 @@ class SumHandler(NodeHandler): # sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input # sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input sum_mapping_dict = {} - if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']: + if "keepdim" in self.node.kwargs and self.node.kwargs["keepdim"]: for i in range(num_dims): sum_mapping_dict.update({i: i}) else: @@ -67,7 +67,7 @@ class SumHandler(NodeHandler): assert output_index == self.node._meta_data.dim() sum_info = (sum_dims, sum_mapping_dict) - physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info) + physical_shape_operand = OperationData(name="sum_info", type=OperationDataType.ARG, data=sum_info) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -75,7 +75,7 @@ class SumHandler(NodeHandler): mapping = { "input": physical_input_operand, "sum_info": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py index 855a2e7612af0cb59cae9bc8574197fad098f983..c2aa120e8a282ff346c4bf8dc6214d9f608b470c 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py @@ -8,7 +8,7 @@ from .registry import operator_registry from .strategy import StrategyGenerator from .strategy.tensor_constructor_generator import TensorConstructorGenerator -__all__ = ['TensorConstructorHandler'] +__all__ = ["TensorConstructorHandler"] @operator_registry.register(torch.arange) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py index 7a9d377264905a650fd991cb10e98f8c3f16f871..b72d9812f4062b5b64c04726d2970c9eb4f30f13 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py @@ -7,7 +7,7 @@ from .node_handler import NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, TransposeGenerator -__all__ = ['TransposeHandler'] +__all__ = ["TransposeHandler"] @operator_registry.register(torch.Tensor.transpose) @@ -48,9 +48,9 @@ class TransposeHandler(NodeHandler): if transpose_dims[i] < 0: transpose_dims[i] += num_dims - physical_shape_operand = OperationData(name='transpose_dims', - type=OperationDataType.ARG, - data=list(transpose_dims)) + physical_shape_operand = OperationData( + name="transpose_dims", type=OperationDataType.ARG, data=list(transpose_dims) + ) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -58,7 +58,7 @@ class TransposeHandler(NodeHandler): mapping = { "input": physical_input_operand, "transpose_dims": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index 0362de780d7af0fa9569a575a122f84ffb42b0db..cbc873de822380df426c806d0ac8c25ca7b1df53 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -3,11 +3,11 @@ from typing import Dict, List import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, UnaryElementwiseGenerator -__all__ = ['UnaryElementwiseHandler'] +__all__ = ["UnaryElementwiseHandler"] @operator_registry.register(torch.Tensor.to) @@ -33,9 +33,9 @@ class UnaryElementwiseHandler(MetaInfoNodeHandler): def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) mapping = {"input": physical_input_operand, "output": physical_output} diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py index 7dff89d1d7a39a6e4fe73514bfb16abe2e3e7bea..56c1d10a167e6bd4e295fc29805530bc321e5933 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py @@ -7,7 +7,7 @@ from .node_handler import NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, ViewGenerator -__all__ = ['ViewHandler'] +__all__ = ["ViewHandler"] @operator_registry.register(torch.Tensor.reshape) @@ -38,7 +38,7 @@ class ViewHandler(NodeHandler): physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) target_shape = self.node._meta_data.shape - physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape) + physical_shape_operand = OperationData(name="tgt_shape", type=OperationDataType.ARG, data=target_shape) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -46,7 +46,7 @@ class ViewHandler(NodeHandler): mapping = { "input": physical_input_operand, "tgt_shape": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py index 6de2aaafdd018f08195563ef882f07eb39d8d20a..1856a11100b07944d99d498e19f5cd8dda25d9d5 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py @@ -1,16 +1,15 @@ import copy -import operator from typing import Dict, List import torch -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, WhereGenerator -__all__ = ['WhereHandler'] +__all__ = ["WhereHandler"] @operator_registry.register(torch.where) @@ -28,27 +27,28 @@ class WhereHandler(NodeHandler): def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_condition_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - physical_x_operand = OperationData(name=str(self.node.args[1]), - type=OperationDataType.ARG, - data=self.node.args[1]._meta_data) - physical_y_operand = OperationData(name=str(self.node.args[2]), - type=OperationDataType.ARG, - data=self.node.args[2]._meta_data) + physical_condition_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) + physical_x_operand = OperationData( + name=str(self.node.args[1]), type=OperationDataType.ARG, data=self.node.args[1]._meta_data + ) + physical_y_operand = OperationData( + name=str(self.node.args[2]), type=OperationDataType.ARG, data=self.node.args[2]._meta_data + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_mapping = { "condition": physical_condition_operand, "x": physical_x_operand, "y": physical_y_operand, - "output": physical_output + "output": physical_output, } logical_shape_for_all = self.node._meta_data.shape logical_mapping = {} for key, physical_operand in physical_mapping.items(): - logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand, - logical_shape_for_all) + logical_mapping[key] = self.convert_physical_operand_to_logical_operand( + physical_operand, logical_shape_for_all + ) return logical_mapping, physical_mapping @@ -64,7 +64,8 @@ class WhereHandler(NodeHandler): logical_shape = logical_op_data_mapping[key].logical_shape physical_shape = physical_op_data_mapping[key].logical_shape physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - logical_sharding_spec, logical_shape, physical_shape) + logical_sharding_spec, logical_shape, physical_shape + ) strategy.sharding_specs.pop(logical_op_data_mapping[key]) strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}" diff --git a/colossalai/auto_parallel/tensor_shard/options.py b/colossalai/auto_parallel/tensor_shard/options.py index f0ea502a6f0e2c4412ac333f9465aec6873e9791..e87872f39c1044f159f474644fe02f3b822dcfdd 100644 --- a/colossalai/auto_parallel/tensor_shard/options.py +++ b/colossalai/auto_parallel/tensor_shard/options.py @@ -1,13 +1,14 @@ from dataclasses import dataclass from enum import Enum -__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption'] +__all__ = ["SolverOptions", "SolverPerference", "DataloaderOption", "ShardOption"] class SolverPerference(Enum): """ This enum class is to define the solver preference. """ + STANDARD = 0 DP = 1 TP = 2 @@ -25,6 +26,7 @@ class ShardOption(Enum): TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis. TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes. """ + STANDARD = 0 SHARD = 1 SHARD_LAST_AXIS = 2 @@ -35,6 +37,7 @@ class DataloaderOption(Enum): """ This enum class is to define the dataloader option. """ + REPLICATED = 0 DISTRIBUTED = 1 @@ -44,6 +47,7 @@ class SolverOptions: """ SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. """ + solver_perference: SolverPerference = SolverPerference.STANDARD dataloader_option: DataloaderOption = DataloaderOption.REPLICATED shard_option: ShardOption = ShardOption.STANDARD diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py index 6af92727243759bf2d0e0e1b8f472e1e59308ca3..8e22df64d86846cc0914fe81f2dc4dca09d4269c 100644 --- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -10,7 +10,6 @@ from colossalai.tensor.comm_spec import CommSpec from colossalai.tensor.sharding_spec import ShardingSpec from .constants import ( - BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_METHOD_OP, ELEMENTWISE_MODULE_OP, @@ -18,13 +17,14 @@ from .constants import ( RESHAPE_METHOD_OP, ) -__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector'] +__all__ = ["OperationDataType", "OperationData", "TrainCycleItem", "MemoryCost", "ShardingStrategy", "StrategiesVector"] class OperationDataType(Enum): """ An operation can come from the argument list of an operator or the parameter list of a module. """ + INPUT = 0 ARG = 1 PARAM = 2 @@ -43,6 +43,7 @@ class OperationData: data (Any): the value for this data, usually it is a meta tensor. logical_shape (Tuple[int]): the logical shape of the data, it can be different from the its actual shape in memory. """ + name: str type: OperationDataType data: Any @@ -69,13 +70,13 @@ class OperationData: self.logical_shape = _infer_logical_shape(self.data) def __repr__(self) -> str: - return f'OperationData(name={self.name}, type={self.type})' + return f"OperationData(name={self.name}, type={self.type})" def __eq__(self, other) -> bool: return other.name == self.name def __hash__(self) -> int: - return hash(f'{self.name}') + return hash(f"{self.name}") @dataclass @@ -88,6 +89,7 @@ class TrainCycleItem: fwd (float): the item for the forward pass bwd (float): the item for the backward pass """ + fwd: Any bwd: Any total: Any @@ -104,6 +106,7 @@ class MemoryCost: temp (int): the memory cost incurred by the temporary tensors in bytes. buffer (int): the memory cost incurred by the module buffer in bytes. """ + activation: int = 0 parameter: int = 0 temp: int = 0 @@ -120,6 +123,7 @@ class CommType(Enum): HOOK: the communication action is used to do the grad all reduce. IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm """ + BEFORE = 0 AFTER = 1 HOOK = 2 @@ -137,6 +141,7 @@ class CommAction: arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data at runtime, because the args of node may be changed by graph transform passes. """ + comm_spec: CommSpec = None comm_type: CommType = None arg_index: int = -1 @@ -156,6 +161,7 @@ class ShardingStrategy: memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None) input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes. """ + name: str sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None compute_cost: TrainCycleItem = None @@ -200,7 +206,6 @@ class ShardingStrategy: raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}") def clone(self): - def _deepcopy_dict_vals(data: Dict): return {k: deepcopy(v) for k, v in data.items()} @@ -209,31 +214,34 @@ class ShardingStrategy: # Consider the examples below: # If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False. # In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items. - communication_actions = _deepcopy_dict_vals( - self.communication_actions) if self.communication_actions is not None else None + communication_actions = ( + _deepcopy_dict_vals(self.communication_actions) if self.communication_actions is not None else None + ) # same reason as communication_actions resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None compute_cost = deepcopy(self.compute_cost) communication_cost = deepcopy(self.communication_cost) memory_cost = deepcopy(self.memory_cost) - return ShardingStrategy(name=self.name, - sharding_specs=sharding_specs, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - communication_actions=communication_actions, - resharding_costs=resharding_costs) + return ShardingStrategy( + name=self.name, + sharding_specs=sharding_specs, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + communication_actions=communication_actions, + resharding_costs=resharding_costs, + ) class StrategiesVector(list): - ''' + """ Each node in fx graph will have a corresponding StrategiesVector, to store all the possible strategies of the node. Argument: node (Node): node for which the list of sharding strategies are generated. - ''' + """ def __init__(self, node: Node): super().__init__() @@ -245,7 +253,7 @@ class StrategiesVector(list): def check_merge(self): merge_label = False - if self.node.op == 'call_module': + if self.node.op == "call_module": target = self.node.target root_module = self.node.graph.owning_module submod = root_module.get_submodule(target) @@ -255,7 +263,7 @@ class StrategiesVector(list): if submod_type in ELEMENTWISE_MODULE_OP: merge_label = True - if self.node.op == 'call_function': + if self.node.op == "call_function": # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. if self.node.target in ELEMENTWISE_FUNC_OP: merge_label = True @@ -267,7 +275,7 @@ class StrategiesVector(list): if self.node.target in RESHAPE_FUNC_OP: merge_label = True - if self.node.op == 'call_method': + if self.node.op == "call_method": # we could merge reshape op, because their computation costs are negligible. method = getattr(self.node.args[0]._meta_data.__class__, self.node.target) if method in RESHAPE_METHOD_OP: diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py index f9e6bd9239214c4def03b1b419a2845581fa083a..b930ce80a9b932e6c54680fd7c988428b29afa0e 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py @@ -3,4 +3,4 @@ from .graph_analysis import GraphAnalyser from .solver import Solver from .strategies_constructor import StrategiesConstructor -__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph'] +__all__ = ["GraphAnalyser", "Solver", "StrategiesConstructor", "CostGraph"] diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py index 74290453ca0c2dd40008ec51584a134c8f278410..4415d429b0c22b82631984223df43f751a5fe1f6 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py +++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py @@ -4,18 +4,18 @@ from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST class CostGraph: - ''' + """ A graph data structure to simplify the edge cost graph. It has two main functions: 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. 2. To reduce the searching space, we merge computationally-trivial operators, such as - element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will + element-wise operators, transpose, and reduction, into their following nodes. The merging information will be given by the StrategiesVector depending on the type of target node and following nodes. Argument: leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph. simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) - ''' + """ def __init__(self, leaf_strategies, simplify=True, forward_only=False): self.leaf_strategies = leaf_strategies @@ -39,10 +39,10 @@ class CostGraph: target_node_list.remove(element) def _build_cost_graph(self): - ''' + """ This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be set to node. - ''' + """ self.edge_costs = {} if self.simplify: self.merge_pair = [] @@ -84,13 +84,13 @@ class CostGraph: if _check_tensor_in_node(node._meta_data): children_nodes.append(node) - setattr(dst_node, 'parents', parent_nodes) - setattr(dst_node, 'children', children_nodes) + setattr(dst_node, "parents", parent_nodes) + setattr(dst_node, "children", children_nodes) if self.simplify and strategies_vector.check_merge(): for followed_node in strategies_vector.predecessor_nodes: # we only merge node pairs which src node has a tensor element inside. - # This is necessay because the node without a tensor element inside will not + # This is necessary because the node without a tensor element inside will not # be assigned any strategy. if _check_tensor_in_node(followed_node._meta_data): self.merge_pair.append((followed_node, dst_node)) @@ -99,7 +99,7 @@ class CostGraph: return self.edge_costs[(src_node, dst_node)] def merge_node(self, src_node, dst_node): - ''' + """ To merge dst_node into src_node, we need to do it in following steps: 1. For each strategy in dst_node, we need to pick an appropriate strategy @@ -119,7 +119,7 @@ class CostGraph: Argument: src_node(Node): The node will be merged into dst_node. dst_node(Node): The node to integrate src_node. - ''' + """ # build merge_map merge_map = {} for src_index, _ in enumerate(src_node.strategies_vector): @@ -196,7 +196,7 @@ class CostGraph: if not self.simplify: return self.merge_pair.reverse() - for (src_node, dst_node) in self.merge_pair: + for src_node, dst_node in self.merge_pair: self.merge_node(src_node, dst_node) self.merge_pair.reverse() reindexing_following_dict = {} diff --git a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py index be39a74cb23755f9ff2b83cf1123a7a7f9708ffa..678965d663e4d243f0d63771f82455cc89f032ec 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py +++ b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py @@ -7,7 +7,7 @@ from torch.fx.node import Node from colossalai.fx.passes.utils import get_node_module -__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser'] +__all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"] @dataclass @@ -15,6 +15,7 @@ class LiveVariable: """ LiveVariable is a data structure to store the meta information of a variable for liveness analysis. """ + name: str node: Node is_inplace: bool @@ -55,6 +56,7 @@ class LiveStage: """ LiveStage is a data structure to record the living variables at this current node. """ + name: str node: Node all_live_vars: LiveVariableVector @@ -62,7 +64,6 @@ class LiveStage: class GraphAnalyser: - def __init__(self, gm: GraphModule): self._gm = gm self._graph = gm.graph @@ -83,7 +84,7 @@ class GraphAnalyser: def liveness_analysis(self) -> List[LiveStage]: """ - Analyse the graph to obtain the variable liveness information. This function returns + Analyses the graph to obtain the variable liveness information. This function returns an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object. """ compute_nodes = self.graph.nodes @@ -91,7 +92,7 @@ class GraphAnalyser: # checked: record all variables created since the first stage # all: record the live variables only exist until the current stage. - # this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage. + # this can be different from the `checked list`` as some variables may be destroyed prior to this stage. # unique: record the unique live variables only exist until the current stage. # this is different from `all list` as some variables are duplicated. checked_variables = LiveVariableVector() @@ -103,20 +104,20 @@ class GraphAnalyser: # find new living variables # ############################# # detect whether the current op is an in-place op - # if it is an in-place op, we would deem it as a duplciate var + # if it is an in-place op, we would deem it as a duplicate var is_inplace = False - if node.op == 'call_function': + if node.op == "call_function": # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True) - if node.kwargs.get('inplace', False): + if node.kwargs.get("inplace", False): is_inplace = True - elif node.op == 'call_module': + elif node.op == "call_module": # to check if this is an inplace op such as torch.nn.Relu(inplace=True) module = get_node_module(node) - if getattr(module, 'inplace', False): + if getattr(module, "inplace", False): is_inplace = True # add the output var - meta = getattr(node, '_meta_data', None) + getattr(node, "_meta_data", None) live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace) if not is_inplace: unique_live_vars.append(live_var) @@ -138,10 +139,12 @@ class GraphAnalyser: # this should be completed if we are able to trace the backward compute graph # add this stage to liveness dict - stage = LiveStage(name=node.name, - node=node, - all_live_vars=all_live_variables.copy(), - unique_live_vars=unique_live_vars.copy()) + stage = LiveStage( + name=node.name, + node=node, + all_live_vars=all_live_variables.copy(), + unique_live_vars=unique_live_vars.copy(), + ) # if a LiveStage is covered by another LiveStage, we just keep the larger one. replace = False for index, prev_stage in enumerate(liveness_list): diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py index f5c6663dce80671199dcaf235380178622813310..088d1acb51778acf4bd220ae777811f54cb1eb50 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/solver.py +++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py @@ -21,34 +21,35 @@ try: import pulp from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum except: - warnings.warn(f'please install the pulp') + warnings.warn(f"please install the pulp") -__all___ = ['Solver'] +__all___ = ["Solver"] class Solver: - - def __init__(self, - graph: Graph, - strategies_constructor: StrategiesConstructor, - cost_graph: CostGraph, - graph_analyser: GraphAnalyser = None, - memory_budget: float = -1.0, - solution_numbers: int = 1, - forward_only: bool = False, - memory_increasing_coefficient: float = 1.3, - verbose=False): - ''' + def __init__( + self, + graph: Graph, + strategies_constructor: StrategiesConstructor, + cost_graph: CostGraph, + graph_analyser: GraphAnalyser = None, + memory_budget: float = -1.0, + solution_numbers: int = 1, + forward_only: bool = False, + memory_increasing_coefficient: float = 1.3, + verbose=False, + ): + """ Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. Argument: graph: The computing graph to be optimized. strategies_constructor: It will provide all the possible strategies for each node in the computing graph. cost_graph: A graph data structure to simplify the edge cost graph. - graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints. + graph_analyser: graph_analyser will analyses the graph to obtain the variable liveness information, which will be used to generate memory constraints. memory_budget: Memory constraint for the solution. solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget. memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. - ''' + """ self.graph = graph self.strategies_constructor = strategies_constructor self.cost_graph = cost_graph @@ -75,11 +76,11 @@ class Solver: self.verbose = verbose def _recover_merged_node_strategy(self): - ''' + """ During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged node. - ''' + """ for node_index, node in enumerate(self.nodes): if node.strategies_vector.check_merge(): # the merged node has only one input, and its strategies follow the input sharding strategy @@ -98,9 +99,9 @@ class Solver: return node_index_dict def _prepare_data_for_solver(self): - ''' + """ Extract information from components for solver. - ''' + """ node_nums = len(self.leaf_strategies) memory_budget = self.memory_budget @@ -190,23 +191,40 @@ class Solver: # omit initial value for nodes s_init_np = None - return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose - - def _call_solver_serialized_args(self, - node_nums, - memory_budget, - strategies_len, - following_nodes, - edge_pairs, - alias_set, - liveness_set, - compute_costs, - communication_costs, - memory_costs, - resharding_costs, - alias_convert_costs, - s_init_np=None, - verbose=True): + return ( + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np, + self.verbose, + ) + + def _call_solver_serialized_args( + self, + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np=None, + verbose=True, + ): """ Call the solver with serialized arguments. """ @@ -235,18 +253,18 @@ class Solver: s_follow = following_nodes s_alias = alias_set - E = edge_pairs.reshape((-1, 2)) # noqa + E = edge_pairs.reshape((-1, 2)) # noqa r = [] pt = 0 edge_set = set() - for (i, j) in E: + for i, j in E: prod_length = strategies_len[i] * strategies_len[j] if (i, j) in edge_set: raise ValueError(f"Duplicated edges: {(i, j)}") edge_set.add((i, j)) - r.append(resharding_costs[pt:pt + prod_length]) + r.append(resharding_costs[pt : pt + prod_length]) pt += prod_length assert pt == len(resharding_costs) @@ -268,7 +286,6 @@ class Solver: # L.append(liveness_set[pt:pt + length]) # pt += length # assert pt == len(liveness_set) - v = [] pt = 0 c = [] @@ -277,9 +294,9 @@ class Solver: pt = 0 for i in range(node_nums): length = strategies_len[i] - c.append(compute_costs[pt:pt + length]) - d.append(communication_costs[pt:pt + length]) - m.append(memory_costs[pt:pt + length]) + c.append(compute_costs[pt : pt + length]) + d.append(communication_costs[pt : pt + length]) + m.append(memory_costs[pt : pt + length]) pt += length assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" @@ -319,7 +336,7 @@ class Solver: e = [] num_edges = 0 map_edge_to_idx = {} - for (idx, (i, j)) in enumerate(E): + for idx, (i, j) in enumerate(E): if len(s[i]) == 1: e.append(s[j]) elif len(s[j]) == 1: @@ -340,7 +357,7 @@ class Solver: ###################################### if s_init_np is not None: s_init = s_init_np.reshape((-1, 3)) - for (idx, value, fix) in s_init: + for idx, value, fix in s_init: for i in range(len(s[idx])): s[idx][i].setInitialValue(i == value) if fix: @@ -393,7 +410,7 @@ class Solver: # (d). specified by `cat="Binary"` - for (idx, (i, j)) in enumerate(E): + for idx, (i, j) in enumerate(E): if strategies_len[i] == 1 or strategies_len[j] == 1: continue @@ -402,13 +419,13 @@ class Solver: # (f) for row in range(len(s[i])): - C = len(s[j]) # noqa + C = len(s[j]) # noqa prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] # (g) for col in range(len(s[j])): - R = len(s[i]) # noqa - C = len(s[j]) # noqa + R = len(s[i]) # noqa + C = len(s[j]) # noqa prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] # (h) @@ -434,7 +451,8 @@ class Solver: msg = verbose time_limit = 600 assert "COIN_CMD" in pulp.listSolvers( - onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") + onlyAvailable=True + ), "Please install ILP solvers by 'sudo apt install coinor-cbc'" solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) @@ -444,13 +462,13 @@ class Solver: objective = pulp.value(prob.objective) objective = float(objective) if objective is not None else -1.0 if verbose: - print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" - f"Time: {time.time() - tic}") + print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}") print(f"#nodes: {num_nodes}, #edges: {num_edges}") if prob.status in [pulp.LpStatusInfeasible]: - raise RuntimeError("Cannot run the function under the given memory budget. " - "Please increase the memory budget.") + raise RuntimeError( + "Cannot run the function under the given memory budget. " "Please increase the memory budget." + ) # Get and check results s_val = np.full((node_nums,), -1, dtype=np.int32) @@ -458,7 +476,7 @@ class Solver: s_val[i] = get_non_zero_index(s[i]) e_val = np.full((len(E),), -1, dtype=np.int32) - for (idx, (i, j)) in enumerate(E): + for idx, (i, j) in enumerate(E): e_val[idx] = get_non_zero_index(e[idx]) i_spec_index = e_val[idx] // len(s[j]) j_spec_index = e_val[idx] % len(s[j]) diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index 044a8ac847ead4b6b7d9f05c3d19a43a8fc2346c..aa87ee9bf3db1fc52d5d9b8289c9d7ae7c01ca93 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -1,11 +1,5 @@ -import builtins -import math -import operator -from copy import deepcopy -from typing import Dict, List - import torch -from torch.fx import Graph, Node +from torch.fx import Graph from colossalai.auto_parallel.tensor_shard.node_handler import ( GetattrHandler, @@ -14,13 +8,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler import ( operator_registry, ) from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector -from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks from colossalai.device.device_mesh import DeviceMesh from ..options import DataloaderOption, SolverOptions -__all__ = ['StrategiesConstructor'] +__all__ = ["StrategiesConstructor"] class StrategiesConstructor: @@ -35,7 +28,7 @@ class StrategiesConstructor: def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): self.graph = graph - assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + assert graph.owning_module is not None, "The given graph is not associated with a owning_module" self.root_module = self.graph.owning_module self.nodes = list(graph.nodes) self.device_mesh = device_mesh @@ -46,11 +39,11 @@ class StrategiesConstructor: self.alias_set = None def remove_duplicated_strategy(self, strategies_vector): - ''' + """ In build_strategies_and_cost method, we may produce some duplicated strategies. In this method, we will remove the duplicated strategies depending on the strategies name. Note that this operation is in-place. - ''' + """ name_checklist = [] remove_list = [] for strategy in strategies_vector: @@ -62,7 +55,6 @@ class StrategiesConstructor: strategies_vector.remove(strategy) def generate_alias_set(self): - node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies] common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10) @@ -83,7 +75,7 @@ class StrategiesConstructor: """ def _check_no_strategy_for_node(node): - if node.op in ('placeholder', 'get_attr', 'output'): + if node.op in ("placeholder", "get_attr", "output"): return False def _check_no_strategy_for_data(data): @@ -102,83 +94,93 @@ class StrategiesConstructor: if _check_no_strategy_for_node(node): self.no_strategy_nodes.append(node) - pass # placeholder node - elif node.op == 'placeholder': + elif node.op == "placeholder": if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: - placeholder_option = 'distributed' + placeholder_option = "distributed" else: - assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' - placeholder_option = 'replicated' - placeholder_handler = PlaceholderHandler(node, - self.device_mesh, - strategies_vector, - placeholder_option=placeholder_option) + assert ( + self.solver_options.dataloader_option == DataloaderOption.REPLICATED + ), f"placeholder_option {self.solver_options.dataloader_option} is not supported" + placeholder_option = "replicated" + placeholder_handler = PlaceholderHandler( + node, self.device_mesh, strategies_vector, placeholder_option=placeholder_option + ) placeholder_handler.register_strategy() # get_attr node - elif node.op == 'get_attr': - getattr_handler = GetattrHandler(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + elif node.op == "get_attr": + getattr_handler = GetattrHandler( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) getattr_handler.register_strategy() # call_module node - elif node.op == 'call_module': + elif node.op == "call_module": target = node.target submod = self.root_module.get_submodule(target) submod_type = type(submod) - handler = operator_registry.get(submod_type)(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + handler = operator_registry.get(submod_type)( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) handler.register_strategy() # attach strategies_info to node - if hasattr(handler, 'strategies_info'): - setattr(node, 'strategies_info', handler.strategies_info) + if hasattr(handler, "strategies_info"): + setattr(node, "strategies_info", handler.strategies_info) # call_function node - elif node.op == 'call_function': + elif node.op == "call_function": target = node.target - handler = operator_registry.get(target)(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + handler = operator_registry.get(target)( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) handler.register_strategy() # attach strategies_info to node - if hasattr(handler, 'strategies_info'): - setattr(node, 'strategies_info', handler.strategies_info) + if hasattr(handler, "strategies_info"): + setattr(node, "strategies_info", handler.strategies_info) # call_method node - elif node.op == 'call_method': + elif node.op == "call_method": method = getattr(node.args[0]._meta_data.__class__, node.target) - handler = operator_registry.get(method)(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + handler = operator_registry.get(method)( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) handler.register_strategy() # attach strategies_info to node - if hasattr(handler, 'strategies_info'): - setattr(node, 'strategies_info', handler.strategies_info) + if hasattr(handler, "strategies_info"): + setattr(node, "strategies_info", handler.strategies_info) # output node - elif node.op == 'output': + elif node.op == "output": if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: - output_option = 'distributed' + output_option = "distributed" else: - assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' - output_option = 'replicated' + assert ( + self.solver_options.dataloader_option == DataloaderOption.REPLICATED + ), f"placeholder_option {self.solver_options.dataloader_option} is not supported" + output_option = "replicated" output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option) output_handler.register_strategy() self.remove_duplicated_strategy(strategies_vector) - setattr(node, 'strategies_vector', strategies_vector) + setattr(node, "strategies_vector", strategies_vector) self.leaf_strategies.append(strategies_vector) self.strategy_map[node] = strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py index b7fe5430bf136b08d93706b33cf4dbf82e342013..d61cfd2add15ba096270723446ffe52768a31f31 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -17,9 +17,21 @@ from .sharding import ( ) __all__ = [ - 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', - 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity' - 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', - 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map', - 'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict' + "BroadcastType", + "get_broadcast_shape", + "is_broadcastable", + "recover_sharding_spec_for_broadcast_shape", + "generate_resharding_costs", + "generate_sharding_spec", + "ignore_sharding_exception", + "check_sharding_spec_validity" "transpose_partition_dim", + "update_partition_dim", + "enumerate_all_possible_1d_sharding", + "enumerate_all_possible_2d_sharding", + "generate_sharding_size", + "comm_actions_for_oprands", + "pytree_map", + "detect_reshape_mapping", + "check_keep_sharding_status", + "infer_output_dim_partition_dict", ] diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py index 28aa551328d7a6d5f283338fe55c90eb102d253c..99d5a0f2a9420d50afa47485b9a4b20ed8c60a2a 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py +++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py @@ -14,14 +14,17 @@ from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec from colossalai.tensor.sharding_spec import ShardingSpec __all__ = [ - 'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape', - 'comm_actions_for_oprands' + "BroadcastType", + "is_broadcastable", + "get_broadcast_shape", + "recover_sharding_spec_for_broadcast_shape", + "comm_actions_for_oprands", ] class BroadcastType(Enum): EQUAL = auto() - PADDDING = auto() + PADDING = auto() MULTIPLE = auto() @@ -41,7 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]: """ Compute the broadcast shape given two shapes. """ - assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable' + assert is_broadcastable(shape1, shape2), f"{shape1} and {shape2} are not broadcastable" shape1_reverse = shape1[::-1] shape2_reverse = shape2[::-1] min_common_dim = min(len(shape1), len(shape2)) @@ -60,8 +63,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape): logical_num_dims = len(logical_shape) physical_num_dims = len(physical_shape) - assert logical_num_dims >= physical_num_dims, \ - 'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!' + assert ( + logical_num_dims >= physical_num_dims + ), "The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!" # track the dim and its broadcasting type logical_dim_broadcast_info = {} @@ -69,24 +73,25 @@ def get_broadcast_dim_info(logical_shape, physical_shape): for i in range(logical_num_dims): # get the trailing dim size logical_dim_idx = logical_num_dims - i - 1 - phyiscal_dim_idx = physical_num_dims - i - 1 + physical_dim_idx = physical_num_dims - i - 1 logical_dim_size = logical_shape[logical_dim_idx] - if phyiscal_dim_idx >= 0: - physical_dim_size = physical_shape[phyiscal_dim_idx] + if physical_dim_idx >= 0: + physical_dim_size = physical_shape[physical_dim_idx] if physical_dim_size == logical_dim_size: logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL elif physical_dim_size == 1 and physical_dim_size != logical_dim_size: logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE else: - logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING + logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDING return logical_dim_broadcast_info -def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, - physical_shape: torch.Size) -> ShardingSpec: +def recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, physical_shape: torch.Size +) -> ShardingSpec: """ This function computes the sharding spec for the physical shape of a broadcast tensor. @@ -117,22 +122,25 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe for shape_dim, mesh_dim in logical_dim_partition.items(): logical_broadcast_type = logical_dim_broadcast_info[shape_dim] - if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE: + if logical_broadcast_type == BroadcastType.PADDING or logical_broadcast_type == BroadcastType.MULTIPLE: removed_dims.extend(mesh_dim) else: # get the corresponding physical dim physical_dim = physical_num_dims - (logical_num_dims - shape_dim) physical_dim_partition[physical_dim] = mesh_dim - physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh, - entire_shape=physical_shape, - dim_partition_dict=physical_dim_partition) + physical_sharding_spec = ShardingSpec( + device_mesh=logical_sharding_spec.device_mesh, + entire_shape=physical_shape, + dim_partition_dict=physical_dim_partition, + ) return physical_sharding_spec, removed_dims -def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData, - sharding_spec: ShardingSpec) -> CommAction: +def comm_actions_for_oprands( + node: Node, removed_dims: List[int], op_data: OperationData, sharding_spec: ShardingSpec +) -> CommAction: """ This method is used to generate communication actions for oprands which lose information during convert logical shape to physical shape. @@ -140,9 +148,11 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera if len(removed_dims) == 1: # if list length is 1, extract element from list to avoid using flatten device mesh removed_dims = removed_dims[0] - comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - sharding_spec=sharding_spec, - logical_process_axis=removed_dims) + comm_spec = CommSpec( + comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + sharding_spec=sharding_spec, + logical_process_axis=removed_dims, + ) if op_data.type == OperationDataType.PARAM: comm_type = CommType.HOOK else: @@ -151,7 +161,7 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera for index, arg in enumerate(node.args): if op_data.name == str(arg): arg_index = index - assert arg_index >= 0, f'op_data should be an argument of node.' + assert arg_index >= 0, f"op_data should be an argument of node." comm_action = CommAction( comm_spec=comm_spec, comm_type=comm_type, diff --git a/colossalai/auto_parallel/tensor_shard/utils/factory.py b/colossalai/auto_parallel/tensor_shard/utils/factory.py index 05331e56000110a982cc776a24eb81d45fceb825..aaca923a5eeee28b241b3bc7877ace321b312761 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/factory.py +++ b/colossalai/auto_parallel/tensor_shard/utils/factory.py @@ -14,11 +14,12 @@ from colossalai.tensor.sharding_spec import ShardingSpec from ..constants import INFINITY_COST -__all__ = ['generate_sharding_spec', 'generate_resharding_costs'] +__all__ = ["generate_sharding_spec", "generate_resharding_costs"] -def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, - dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: +def generate_sharding_spec( + input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, dim_partition_dict: Dict[int, List[int]] +) -> ShardingSpec: """ Generate the sharding spec of the tensor based on the given dim_partition_dict. @@ -30,7 +31,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic """ if isinstance(input_, Node): - assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data' + assert hasattr(input_, "_meta_data"), f"The given node has no attribute _meta_data" meta_tensor = input_._meta_data assert meta_tensor is not None, "The given node's _meta_data attribute is None" shape = meta_tensor.shape @@ -38,24 +39,27 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic shape = input_.shape else: raise TypeError( - f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.' + f"We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected." ) for dim_index, sharding_index_list in dim_partition_dict.items(): sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] sharding_size = reduce(operator.mul, sharding_list, 1) - assert shape[ - dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' + assert ( + shape[dim_index] % sharding_size == 0 + ), f"we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions." sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) return sharding_spec -def generate_resharding_costs(nodes: List[Node], - sharding_specs: List[ShardingSpec], - count_backward: Optional[bool] = True, - dtype: Optional[torch.dtype] = None, - index=None): - ''' +def generate_resharding_costs( + nodes: List[Node], + sharding_specs: List[ShardingSpec], + count_backward: Optional[bool] = True, + dtype: Optional[torch.dtype] = None, + index=None, +): + """ Compute the resharding costs with this specific strategy. Argument: @@ -63,7 +67,7 @@ def generate_resharding_costs(nodes: List[Node], sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes. count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference. dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. - ''' + """ # The resharding_cost of weight is counted due to sharing weight cases. resharding_costs = {} size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() @@ -76,38 +80,39 @@ def generate_resharding_costs(nodes: List[Node], for strategy in input_node.strategies_vector: input_sharding_spec = strategy.output_sharding_spec if not isinstance(input_sharding_spec, ShardingSpec): - assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.' + assert isinstance(input_sharding_spec, list), "only ShardingSpec or List[ShardingSpec] is expected." input_sharding_spec = input_sharding_spec[index] - assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + assert isinstance(input_sharding_spec, ShardingSpec), f"The input node should NOT be a tuple of tensor." try: # compute the resharding cost _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( - input_sharding_spec, input_spec) + input_sharding_spec, input_spec + ) # we need multiply the size of elem dtype to get correct communication cost resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes except AssertionError as e: - warnings.warn(f'{e}') + warnings.warn(f"{e}") resharding_cost = INFINITY_COST resharding_costs[input_node].append(resharding_cost) return resharding_costs def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20): - ''' + """ Find the largest repeat blocks in the graph, whose length is larger than the threshold. Args: gm (GraphModule): the graph module to be analyzed. common_length_threshold (int): the threshold of the repeat block length. - ''' + """ # graph = gm.graph def _process_args(args): new_args = [] for arg in args: - if hasattr(arg, '_meta_data'): + if hasattr(arg, "_meta_data"): meta_data = arg._meta_data else: meta_data = arg @@ -145,7 +150,7 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt return False for index, node in enumerate(node_list): - if node.op == 'call_module': + if node.op == "call_module": target = node.target submod = root_module.get_submodule(target) submod_type = type(submod) @@ -155,12 +160,12 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt new_args = _process_args(node.args) - if node.op != 'get_attr': + if node.op != "get_attr": hash_key = (node.op, target, *new_args) else: hash_key = (node.op,) - setattr(node, 'hash_key', hash_key) + setattr(node, "hash_key", hash_key) hash_value_to_node_dict = {} @@ -179,7 +184,7 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt # the comparison will be triggered if a common node appears if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2: start_index_list = hash_value_to_node_dict[hash(node.hash_key)] - check_block_list = [node_list[start:start + max_common_length] for start in start_index_list] + check_block_list = [node_list[start : start + max_common_length] for start in start_index_list] common_label = True if not _all_equal(check_block_list, _check_node_list_equal): @@ -201,6 +206,6 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt # recover common subgraph from the index common_blocks = [] for start in common_blocks_index: - common_blocks.append(node_list[start:start + max_common_length]) + common_blocks.append(node_list[start : start + max_common_length]) return common_blocks diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py index 9e402dab757820c5d76ee6d1166de473c040784b..42ec2a8ee428880a6a86e50e2a8c0ae21a2a4b74 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/misc.py +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -1,12 +1,12 @@ import functools -from typing import Any, Callable, Dict, List, Tuple, Type, Union +from typing import Any, Callable, Tuple, Type, Union import torch from colossalai.logging import get_dist_logger from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException -__all__ = ['ignore_sharding_exception', 'pytree_map'] +__all__ = ["ignore_sharding_exception", "pytree_map"] def ignore_sharding_exception(func): @@ -46,31 +46,34 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens # make sure all dims are covered in sharding spec sharding_len = len(sharding_spec.sharding_sequence) tensor_num_dim = tensor.dim() - num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0] - num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1] - assert sharding_len == tensor_num_dim, \ - f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' + num_devices_in_col = sharding_spec.device_mesh.shape[0] + num_devices_in_row = sharding_spec.device_mesh.shape[1] + assert ( + sharding_len == tensor_num_dim + ), f"The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape})." # make sure the sharding is valid for each dim for i in range(tensor_num_dim): dim_size = tensor.shape[i] dim_spec = sharding_spec.sharding_sequence[i] - if str(dim_spec).startswith('S'): - devices_str = str(dim_spec).lstrip('S') + if str(dim_spec).startswith("S"): + devices_str = str(dim_spec).lstrip("S") num_devices = 1 - if '0' in devices_str: + if "0" in devices_str: num_devices *= num_devices_in_col - if '1' in devices_str: + if "1" in devices_str: num_devices *= num_devices_in_row - assert dim_size >= num_devices and dim_size % num_devices == 0, \ - f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.' + assert ( + dim_size >= num_devices and dim_size % num_devices == 0 + ), f"The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices." # make sure the entire shape matches the physical tensor shape - assert sharding_spec.entire_shape == tensor.shape, \ - f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}' + assert ( + sharding_spec.entire_shape == tensor.shape + ), f"The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}" def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py index a32a14bf7d577713ae2cb986ffbb42d87b0cabc1..329312ef797fe1afa649bc675425082faf495c62 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/reshape.py +++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py @@ -6,12 +6,13 @@ import torch class PreviousStatus(Enum): """ - This class shows the status of previous comparision. + This class shows the status of previous comparison. """ + RESET = 0 - # ORIGIN means the dimension size of original tensor is larger in the previous comparision. + # ORIGIN means the dimension size of original tensor is larger in the previous comparison. ORIGIN = 1 - # TGT means the dimension size of target tensor is larger in the previous comparision. + # TGT means the dimension size of target tensor is larger in the previous comparison. TGT = 2 @@ -91,7 +92,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D tgt_index += 1 if previous_label == PreviousStatus.TGT: - # if the target dimension size is larger in the previous comparision, which means + # if the target dimension size is larger in the previous comparison, which means # the origin dimension size has already accumulated larger than target dimension size, so # we need to offload the origin dims and tgt dims into the reshape_mapping_dict. reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) @@ -111,7 +112,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D origin_index += 1 if previous_label == PreviousStatus.ORIGIN: - # if the origin element is larger in the previous comparision, which means + # if the origin element is larger in the previous comparison, which means # the target element has already accumulated larger than origin element, so # we need to offload the origin dims and tgt dims into the reshape_mapping_dict. reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) @@ -130,8 +131,9 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D return reshape_mapping_dict -def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], - reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool: +def check_keep_sharding_status( + input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]] +) -> bool: """ This method is used to check whether the reshape operation could implement without converting the input to fully replicated status. @@ -139,7 +141,7 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], Rule: For a sharded dimension of input tensor, if it is not the minimum element of the input tuple, the function will return false. - To illustrate this issue, there are two cases to analyse: + To illustrate this issue, there are two cases to analyze: 1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal operation without distributed tensor. 2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape @@ -172,14 +174,16 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], return True -def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]], - reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]: +def infer_output_dim_partition_dict( + input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]] +) -> Dict[Tuple[int], Tuple[int]]: """ This method is used to infer the output dim partition dict for a reshape operation, given the input dim partition dict and reshape mapping dict. """ - assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \ - 'we only infer output dim partition dict for the reshape operation could keep sharding spec.' + assert check_keep_sharding_status( + input_dim_partition_dict, reshape_mapping_dict + ), "we only infer output dim partition dict for the reshape operation could keep sharding spec." sharded_dims = list(input_dim_partition_dict.keys()) output_dim_partition_dict = {} for input_dims, output_dims in reshape_mapping_dict.items(): diff --git a/colossalai/auto_parallel/tensor_shard/utils/sharding.py b/colossalai/auto_parallel/tensor_shard/utils/sharding.py index e2ce59e0b5772679be11e960322e3110c500d6aa..b5386d599be4a62672c83c6f3daf6085e6062349 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/sharding.py +++ b/colossalai/auto_parallel/tensor_shard/utils/sharding.py @@ -8,8 +8,11 @@ import torch from colossalai.tensor.sharding_spec import ShardingSpec __all__ = [ - 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', - 'enumerate_all_possible_2d_sharding', 'generate_sharding_size' + "transpose_partition_dim", + "update_partition_dim", + "enumerate_all_possible_1d_sharding", + "enumerate_all_possible_2d_sharding", + "generate_sharding_size", ] @@ -22,8 +25,7 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) - dim1 (int): the tensor dimension to switch dim2 (int): the tensor dimension to switch """ - assert len(sharding_spec.entire_shape) >= 2, \ - 'The entire_shape of the sharding spec must have at least 2 dimensions' + assert len(sharding_spec.entire_shape) >= 2, "The entire_shape of the sharding spec must have at least 2 dimensions" dim_partition_dict = sharding_spec.dim_partition_dict # transpose the dim partition @@ -45,10 +47,9 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) - return sharding_spec -def update_partition_dim(sharding_spec: ShardingSpec, - dim_mapping: Dict[int, int], - physical_shape: torch.Size, - inplace: bool = False): +def update_partition_dim( + sharding_spec: ShardingSpec, dim_mapping: Dict[int, int], physical_shape: torch.Size, inplace: bool = False +): """ This method is used to update the partition dim dict from the logical one to the physical one. @@ -78,9 +79,9 @@ def update_partition_dim(sharding_spec: ShardingSpec, new_dim_partition_dict[tensor_dim] = mesh_dims # update sharding spec - current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh, - entire_shape=physical_shape, - dim_partition_dict=new_dim_partition_dict) + current_sharding_spec.__init__( + device_mesh=sharding_spec.device_mesh, entire_shape=physical_shape, dim_partition_dict=new_dim_partition_dict + ) return current_sharding_spec diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index d0a467254d7279c37031a755fa62a53fc4e0d9b9..9571fa2c17f074adfd4bff244c8ae8bcfcadb3c0 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -9,7 +9,18 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta() if AUTOCHUNK_AVAILABLE: - from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods + from torch.fx.graph import ( + CodeGen, + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + inplace_methods, + magic_methods, + ) from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg @@ -40,7 +51,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> return new_shape -def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str: +def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_output_dim: int, chunk_size=2) -> str: """ Generate chunk loop start @@ -52,7 +63,7 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_oup Args: chunk_input (List[Node]): chunk input node chunk_output (Node): chunk output node - chunk_ouput_dim (int): chunk output node chunk dim + chunk_output_dim (int): chunk output node chunk dim chunk_size (int): chunk size. Defaults to 2. Returns: @@ -64,23 +75,36 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_oup for i in range(len(chunk_output)): shape_str = str(list(get_node_shape(chunk_output[i]))) if get_node_name(chunk_output[i]) in ["split", "unbind"]: - tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name, - input_node.name) - tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta']) + tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % ( + shape_str, + input_node.name, + input_node.name, + ) + tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"]) tensor_str = "[" + tensor_str[:-2] + "]" context += "%s = %s; " % (chunk_output[i].name, tensor_str) else: - context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str, - input_node.name, input_node.name) + context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % ( + chunk_output[i].name, + shape_str, + input_node.name, + input_node.name, + ) out_shape = get_node_shape(chunk_output[0]) - chunk_shape = out_shape[chunk_ouput_dim[0]] + chunk_shape = out_shape[chunk_output_dim[0]] context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape) return context -def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node], - chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str: +def _gen_loop_end( + chunk_inputs: List[Node], + chunk_non_compute_inputs: List[Node], + node_list: List[Node], + chunk_outputs_idx: int, + chunk_outputs_non_tensor: List[Node], + search_chunk: SearchChunk, +) -> str: """ Generate chunk loop end @@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape( chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] if get_node_shape(meta_node)[chunk_dim] != 1: source_node = meta_node.args[0].args[0] - if (source_node not in chunk_infos[region_idx]["node_chunk_dim"] - or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None): + if ( + source_node not in chunk_infos[region_idx]["node_chunk_dim"] + or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None + ): chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node)) body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice) return body @@ -203,11 +229,12 @@ def _add_node_slice( # outputs node else: if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]): - chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", - get_node_shape(chunk_node)) + chunk_slice = _gen_chunk_slice_dim( + chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node) + ) if get_node_name(chunk_node) in ["split", "unbind"]: split_chunk_slice = "" - for i in range(len(chunk_node.meta['tensor_meta'])): + for i in range(len(chunk_node.meta["tensor_meta"])): split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice) split_chunk_slice = split_chunk_slice[:-2] body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice) @@ -216,13 +243,15 @@ def _add_node_slice( return body -def emit_code_with_chunk(body: List[str], - nodes: Iterable[Node], - emit_node_func: Callable, - delete_unused_value_func: Callable, - search_chunk: SearchChunk, - chunk_infos: List, - eval_mem: bool = False): +def emit_code_with_chunk( + body: List[str], + nodes: Iterable[Node], + emit_node_func: Callable, + delete_unused_value_func: Callable, + search_chunk: SearchChunk, + chunk_infos: List, + eval_mem: bool = False, +): """ Emit code with chunk according to chunk_infos. @@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str], chunk_ends = [i["region"][1] for i in chunk_infos] # chunk inputs - chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk - chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim + chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] # chunk outputs @@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str], chunk_outputs[region_idx], chunk_outputs_dim[region_idx], chunk_infos[region_idx]["chunk_size"], - )) + ) + ) if within_chunk_region: emit_node_func(node, body) @@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str], if eval_mem: body.append( " if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" - % (node.name)) + % (node.name) + ) else: emit_node_func(node, body) if node_idx not in chunk_inputs: @@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str], if eval_mem: body.append( "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" - % (node.name)) + % (node.name) + ) # generate chunk region end if node_idx in chunk_ends: body.append( - _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list, - chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk)) + _gen_loop_end( + chunk_inputs[region_idx], + chunk_inputs_non_chunk[region_idx], + node_list, + chunk_ends[region_idx], + chunk_outputs_non_tensor[region_idx], + search_chunk, + ) + ) within_chunk_region = False node_idx += 1 @@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str], if AUTOCHUNK_AVAILABLE: class AutoChunkCodeGen(CodeGen): - - def __init__(self, - meta_graph, - max_memory: int = None, - print_mem: bool = False, - print_progress: bool = False, - eval_mem: bool = False) -> None: + def __init__( + self, + meta_graph, + max_memory: int = None, + print_mem: bool = False, + print_progress: bool = False, + eval_mem: bool = False, + ) -> None: super().__init__() self.eval_mem = eval_mem # find the chunk regions @@ -349,7 +389,7 @@ if AUTOCHUNK_AVAILABLE: Returns: the global name that should be used to reference 'obj' in generated source. """ - if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -402,7 +442,6 @@ if AUTOCHUNK_AVAILABLE: return add_global(typename, o) def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. if isinstance(arg, tuple) and hasattr(arg, "_fields"): @@ -457,10 +496,10 @@ if AUTOCHUNK_AVAILABLE: # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}") + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}") + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") raw_name = node.target.replace("*", "") if raw_name != repr(node): @@ -470,42 +509,56 @@ if AUTOCHUNK_AVAILABLE: assert isinstance(node.target, str) body.append( f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" - f"({_format_args(node.args[1:], node.kwargs)})") + f"({_format_args(node.args[1:], node.kwargs)})" + ) return elif node.op == "call_function": assert callable(node.target) # pretty print operators - if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods): + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f"{repr(node)}{maybe_type_annotation} = " - f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}") + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods): - body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " - f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}") + if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods: + body.append( + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str) - and node.args[1].isidentifier() and len(node.args) == 2): + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}") + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})") + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f"{repr(node)}{maybe_type_annotation} = " - f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})") + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return elif node.op == "get_attr": assert isinstance(node.target, str) @@ -523,8 +576,9 @@ if AUTOCHUNK_AVAILABLE: # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, - self.eval_mem) + emit_code_with_chunk( + body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem + ) if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py index 77bc2ef17bc3bc5faca5903ddd4dfc5a7653275a..a85ad429e261fa2d9de968ac34af53ea0a44dfe5 100644 --- a/colossalai/autochunk/estimate_memory.py +++ b/colossalai/autochunk/estimate_memory.py @@ -1,11 +1,8 @@ -import copy -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Dict, List import torch from torch.fx.node import Node -from colossalai.fx.profiler import activation_size, parameter_size - from .utils import NodeMgr, get_node_shape, is_non_memory_node @@ -62,12 +59,9 @@ class EstimateMemory(object): delete_node_dict[node] = max(node_user_idx) return delete_node_dict - def _remove_deactive_node(self, - user_idx: int, - user: Node, - active_nodes: List, - delete_node_dict: List, - kept_nodes: List = None) -> None: + def _remove_deactive_node( + self, user_idx: int, user: Node, active_nodes: List, delete_node_dict: List, kept_nodes: List = None + ) -> None: """ remove deactivate nodes from active nodes """ @@ -169,7 +163,7 @@ class EstimateMemory(object): use_chunk = True if chunk_infos is not None else False chunk_within = False chunk_region_idx = None - chunk_ratio = 1 # use it to estimate chunk mem + chunk_ratio = 1 # use it to estimate chunk mem chunk_inputs_all = [] if use_chunk: @@ -184,7 +178,6 @@ class EstimateMemory(object): chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos] for idx, node in enumerate(node_mgr.get_node_list()): - # if node in chunk start nodes, change chunk ratio and add chunk_tensor if use_chunk and idx in chunk_starts: chunk_within = True @@ -193,8 +186,9 @@ class EstimateMemory(object): # determine chunk ratio for current node if chunk_within: - chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx], - chunk_sizes[chunk_region_idx]) + chunk_ratio = self._get_chunk_ratio( + node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx] + ) # add current node as active node self._add_active_node(node, active_nodes, chunk_ratio) @@ -222,7 +216,7 @@ class EstimateMemory(object): # if node in chunk end nodes, restore chunk settings if use_chunk and idx in chunk_ends: - self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now + self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now chunk_within = False chunk_ratio = 1 chunk_region_idx = None diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 59645c80e8089d63a53abd8a27c7784f2f90cd8d..1c599049d9ebfceab3f2c5342b0f25cace31beaf 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph from .select_chunk import SelectChunk from .trace_flow import TraceFlow from .trace_indice import TraceIndice -from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder +from .utils import NodeMgr, get_logger, is_non_compute_node, is_non_compute_node_except_placeholder class SearchChunk(object): @@ -121,8 +121,10 @@ class SearchChunk(object): # check if peak node already in chunk info if chunk_regions is not None: for i in chunk_regions: - if i["region"][0] < peak_region[0] <= i["region"][1] or \ - i["region"][0] < peak_region[1] <= i["region"][1]: + if ( + i["region"][0] < peak_region[0] <= i["region"][1] + or i["region"][0] < peak_region[1] <= i["region"][1] + ): return None active_node_num = [len(i) for i in active_node] @@ -146,9 +148,9 @@ class SearchChunk(object): region = i["region"] if chunk_region_start >= region[0] and chunk_region_end <= region[1]: return None - elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]): + elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]: chunk_region_start = region[1] + 1 - elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]): + elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]: chunk_region_end = region[0] - 1 return chunk_region_start, chunk_region_end @@ -171,7 +173,7 @@ class SearchChunk(object): chunk_infos: possible regions found """ start_traces = input_trace[start_idx] - if len(start_traces) > 1: # TODO need to be removed + if len(start_traces) > 1: # TODO need to be removed return [] end_trace = output_trace[end_idx] end_node = self.node_mgr.get_node_by_idx(end_idx) @@ -180,8 +182,9 @@ class SearchChunk(object): for end_dim, _ in enumerate(end_trace["indice"]): for start_node, start_trace in start_traces.items(): for start_dim, _ in enumerate(start_trace["indice"]): - if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, - end_idx): + if not self.trace_flow.check_region_start_end( + start_node, start_dim, start_idx, end_node, end_dim, end_idx + ): continue # flow search chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim) @@ -203,7 +206,7 @@ class SearchChunk(object): """ possible_chunk_region = [] output_trace = copy.deepcopy(self.trace_indice.indice_trace_list) - input_trace = [] # trace of a node's input nodes + input_trace = [] # trace of a node's input nodes for _, n in enumerate(self.node_mgr.get_node_list()): cur_trace = {} for arg in n.args: @@ -215,7 +218,8 @@ class SearchChunk(object): for end_idx in range(peak_region[1], max_chunk_region[1] + 1): # skip non compute nodes if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node( - self.node_mgr.get_node_by_idx(end_idx)): + self.node_mgr.get_node_by_idx(end_idx) + ): continue # select free dim chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx) @@ -279,15 +283,18 @@ class SearchChunk(object): chunk_infos.append(chunk_info) mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem( - self.node_mgr.get_node_list(), chunk_infos) + self.node_mgr.get_node_list(), chunk_infos + ) if self.print_progress: - get_logger().info("AutoChunk find chunk region %d = (%d, %d)" % - (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])) + get_logger().info( + "AutoChunk find chunk region %d = (%d, %d)" + % (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]) + ) if self.print_mem: self.print_mem = False - self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), - chunk_infos, - print_mem=True) + self.estimate_memory.estimate_chunk_inference_mem( + self.node_mgr.get_node_list(), chunk_infos, print_mem=True + ) return chunk_infos diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index 94a29bfd56911eb9df749f284ae64fbd1b5d7a18..8a60ba681f70192a21e2f03ffdc656f6916dbd3d 100644 --- a/colossalai/autochunk/select_chunk.py +++ b/colossalai/autochunk/select_chunk.py @@ -5,7 +5,6 @@ from .utils import NodeMgr, is_non_compute_node class SelectChunk(object): - def __init__( self, trace_indice: TraceIndice, @@ -20,7 +19,7 @@ class SelectChunk(object): self.node_mgr = node_mgr if max_memory is not None: self.stratge = "fit_memory" - self.max_memory = max_memory # MB + self.max_memory = max_memory # MB else: self.stratge = "min_memory" @@ -57,16 +56,18 @@ class SelectChunk(object): cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) cur_chunk_infos = chunk_infos + [cur_region] cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] - cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1] + cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1] cur_chunk_region_max_peak = max(cur_chunk_region_peak) if cur_chunk_region_max_peak < self.max_memory: - regions_dict.append({ - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - }) + regions_dict.append( + { + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + } + ) # no region found if len(regions_dict) == 0: raise RuntimeError("Search failed. Try a larger memory threshold.") @@ -90,13 +91,15 @@ class SelectChunk(object): chunk_size *= 2 reorder_chunk_info["chunk_size"] = chunk_size cur_chunk_infos = chunk_infos + [reorder_chunk_info] - cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], - cur_chunk_infos)[0] - cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1]) + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( + chunk_region_dict["reorder_node_list"], cur_chunk_infos + )[0] + cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1]) # search exact size chunk_info = chunk_region_dict["chunk_info"] - chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict, - chunk_infos) + chunk_info["chunk_size"] = self._chunk_size_binary_search( + chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos + ) return chunk_info def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos): @@ -109,9 +112,10 @@ class SelectChunk(object): mid = int((left + right) / 2 + 0.5) chunk_info["chunk_size"] = mid cur_chunk_infos = chunk_infos + [chunk_info] - cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], - cur_chunk_infos)[0] - cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1]) + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( + chunk_region_dict["reorder_node_list"], cur_chunk_infos + )[0] + cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]) if cur_chunk_max_mem >= self.max_memory: right = mid - gap else: @@ -139,8 +143,10 @@ class SelectChunk(object): return None # get max possible chunk region - max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]), - max([i["region"][1] for i in possible_chunk_regions])) + max_possible_chunk_region = ( + min([i["region"][0] for i in possible_chunk_regions]), + max([i["region"][1] for i in possible_chunk_regions]), + ) # get mem for chunk region regions_dict_list = [] @@ -149,15 +155,17 @@ class SelectChunk(object): cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) cur_chunk_infos = chunk_infos + [cur_region] cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] - cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] + cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1] cur_chunk_region_max_peak = max(cur_chunk_region_peak) - regions_dict_list.append({ - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - }) + regions_dict_list.append( + { + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + } + ) # select the min mem chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list] @@ -175,7 +183,9 @@ class SelectChunk(object): return False for i in chunk_infos: region = i["region"] - if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or - (chunk_region_start < region[0] and chunk_region_end < region[0])): + if not ( + (chunk_region_start > region[1] and chunk_region_end > region[1]) + or (chunk_region_start < region[0] and chunk_region_end < region[0]) + ): return False return True diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index db25267e9b4200414c3cb89afaacd529a1534e06..8b36c99bbadd0f53e7148d2880e3016473cdebea 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -16,7 +16,6 @@ from .utils import ( class TraceFlow(object): - def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None: self.trace_indice = trace_indice self.node_mgr = node_mgr @@ -64,7 +63,7 @@ class TraceFlow(object): return False return True - def _assgin_single_node_flow( + def _assign_single_node_flow( self, arg_node: Node, start_idx: int, @@ -151,7 +150,7 @@ class TraceFlow(object): return True def _get_all_node_info(self, end_dim, start_idx, end_idx): - cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node + cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} while len(cur_node_list) > 0: @@ -177,7 +176,7 @@ class TraceFlow(object): if get_node_shape(arg) is None: continue arg_list.append(arg) - flow_flag = self._assgin_single_node_flow( + flow_flag = self._assign_single_node_flow( arg, start_idx, end_idx, @@ -266,7 +265,7 @@ class TraceFlow(object): maybe_prepose_nodes.sort( key=lambda x: self.node_mgr.find_node_idx(x), reverse=True, - ) # from last node to first node + ) # from last node to first node prepose_nodes = [] # set every node as root, search its args, if all legal, turn root and args as prepose nodes while len(maybe_prepose_nodes) > 0: @@ -315,7 +314,7 @@ class TraceFlow(object): chunk_info["args"]["prepose_nodes"] = prepose_nodes def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): - # we need to log input nodes to avoid deleteing them in the loop + # we need to log input nodes to avoid deleting them in the loop chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1) # also need to get some prepose node's arg out of non_chunk_inputs for n in chunk_info["args"]["prepose_nodes"]: @@ -328,7 +327,8 @@ class TraceFlow(object): def flow_search(self, start_idx, start_dim, end_idx, end_dim): inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)) + self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1) + ) # get every node's chunk dim and fix dim all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) @@ -366,13 +366,14 @@ class TraceFlow(object): # find non chunk inputs chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) - # reassgin reshape size, some size may have changed due to chunk - chunk_info = self._reassgin_reshape_size(chunk_info) + # reassign reshape size, some size may have changed due to chunk + chunk_info = self._reassign_reshape_size(chunk_info) return chunk_info - def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, - chunk_info: Dict): + def _get_other_output_info( + self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict + ): start_node = self.node_mgr.get_node_by_idx(start_idx) # loop all outputs for output in outputs: @@ -384,8 +385,8 @@ class TraceFlow(object): # skip non tensor if get_node_shape(output) is None: # log shape tensor - if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int): - chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out']) + if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int): + chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"]) continue # loop every dim of outputs, try to find a legal one for output_dim in range(len(get_node_shape(output))): @@ -421,17 +422,18 @@ class TraceFlow(object): for k, v in new_all_node_info.items(): if k in chunk_info["node_chunk_dim"]: chunk_info["node_chunk_dim"][k]["fix_dim"] = list( - set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])) + set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]) + ) else: chunk_info["node_chunk_dim"][k] = v chunk_info["outputs"].append(output) chunk_info["outputs_dim"].append(output_dim) return True - def _reassgin_reshape_size(self, chunk_info): + def _reassign_reshape_size(self, chunk_info): """ Some shape args in reshape may have changed due to chunk - reassgin those changed shape + reassign those changed shape """ chunk_region = chunk_info["region"] reshape_size = {} @@ -443,8 +445,11 @@ class TraceFlow(object): if node.args[0] in chunk_info["inputs_non_chunk"]: continue reshape_args = flat_list(node.args[1:]) - if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len( - reshape_args[0].meta['fwd_out']) > 1: + if ( + len(reshape_args) == 1 + and get_node_shape(reshape_args[0]) is None + and len(reshape_args[0].meta["fwd_out"]) > 1 + ): continue chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] new_shape = "" @@ -462,16 +467,17 @@ class TraceFlow(object): chunk_info["reshape_size"] = reshape_size return chunk_info - def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, - end_idx: int) -> bool: + def check_region_start_end( + self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int + ) -> bool: """ check if region start and end is legal """ # dim cannot be None - if (get_node_shape(end_node) is None or get_node_shape(start_node) is None): + if get_node_shape(end_node) is None or get_node_shape(start_node) is None: return False # dim size cannot be 1 - if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1): + if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1: return False # must have users if len(end_node.users) == 0: diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index c7fce4c8bee1f2886d63c9ca6bb35c332cb293d8..378c54acf782090f11bbbe7991e4bab317c1a6f4 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, List, Tuple +from typing import Dict, List from torch.fx.node import Node @@ -18,7 +18,7 @@ class TraceIndice(object): dim(x1)=dim(x2)=dim(x3)=[a, b, c] This class will record every node's dims' indice, compute and source. - Attibutes: + Attributes: node_list (List) indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}] indice_view_list (Dict): not used for now @@ -397,7 +397,7 @@ class TraceIndice(object): input_node = node.args[0] assert len(get_node_shape(input_node)) == 4 - # assgin index + # assign index self._assign_indice_as_input(node, node_idx, input_node) self._del_dim(node_idx, 1) self._add_dim(node_idx, 1) @@ -412,10 +412,10 @@ class TraceIndice(object): node_idx (int) """ # get conv input - assert node.kwargs['size'] is None + assert node.kwargs["size"] is None assert len(get_node_shape(node)) == 4 - # assgin index + # assign index self._assign_indice_as_input(node, node_idx) self._mark_computation(node, node_idx, [-1, -2]) @@ -461,7 +461,7 @@ class TraceIndice(object): nodes_in.append(node_in) self._inherit_more_indice_from_node_with_exclude(node_in, node) - def _assgin_no_change_indice(self, node, idx): + def _assign_no_change_indice(self, node, idx): self._assign_indice_as_input(node, idx) for node_in in node.args: if type(node_in) == type(node): @@ -792,7 +792,7 @@ class TraceIndice(object): self._add_dim(node_idx, i) dim_from.reverse() - # inheirt indice from current node + # inherit indice from current node if len(dim_from) != 0 and len(dim_to) != 0: if dim_diff == 1: if origin_shape[dim_from[0]] == 1: @@ -826,7 +826,7 @@ class TraceIndice(object): # clear compute for dim_compute in trace["compute"]: for i in range(len(dim_compute) - 1, -1, -1): - if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes): + if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes: dim_compute.pop(i) continue # clear source @@ -852,7 +852,7 @@ class TraceIndice(object): elif "split" == node_name: self._assign_split_indice(node, idx) elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]): - self._assgin_no_change_indice(node, idx) + self._assign_no_change_indice(node, idx) elif "new_ones" == node_name: self._assign_all_indice(node, idx) elif "flatten" == node_name: @@ -876,10 +876,24 @@ class TraceIndice(object): self._assign_matmul_indice(node, idx) elif "softmax" == node_name: self._assign_softmax_indice(node, idx) - elif any(n == node_name for n in [ - "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp", - "sin", "cos" - ]): + elif any( + n == node_name + for n in [ + "mul", + "add", + "sigmoid", + "relu", + "sub", + "truediv", + "pow", + "dropout", + "where", + "tanh", + "exp", + "sin", + "cos", + ] + ): self._assign_elementwise_indice(node, idx) elif "einsum" == node_name: self._assign_einsum_indice(node, idx) @@ -914,13 +928,13 @@ class TraceIndice(object): elif "conv2d" == node_name: self._assign_conv2d_indice(node, idx) elif "identity" == node_name: - self._assgin_no_change_indice(node, idx) + self._assign_no_change_indice(node, idx) elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]): self._assign_elementwise_indice(node, idx) else: raise NotImplementedError(node_name, "module not implemented yet!") elif node.op == "get_attr": - self._assign_all_indice(node, idx) # get param + self._assign_all_indice(node, idx) # get param elif node.op == "output": continue else: diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py index 064baa047155ac399b14ae0fcdb044db1125d70b..f6f803a5ce0a80f337443faaa7ffdea741688800 100644 --- a/colossalai/autochunk/utils.py +++ b/colossalai/autochunk/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +from typing import Any, Dict, List, Union from torch.fx.node import Node @@ -10,7 +10,6 @@ logger = get_dist_logger() class NodeMgr(object): - def __init__(self, nodes_list: List[Node]) -> None: self._node_list = nodes_list self._node_dict = {} @@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List, # we treat that input node as the input of the checkpoint function for node in nodes: for input_node in node._input_nodes.keys(): - if (input_node not in nodes and input_node not in input_nodes - and not is_non_compute_node_except_placeholder(input_node)): + if ( + input_node not in nodes + and input_node not in input_nodes + and not is_non_compute_node_except_placeholder(input_node) + ): input_nodes.append(input_node) # if a node has a user node which is not in the node list # we treat that user node as the node receiving the current node output for node in nodes: for output_node in node.users.keys(): - if (output_node not in nodes and node not in output_nodes - and not is_non_compute_node_except_placeholder_output(output_node)): + if ( + output_node not in nodes + and node not in output_nodes + and not is_non_compute_node_except_placeholder_output(output_node) + ): output_nodes.append(node) return input_nodes, output_nodes @@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]: for node in node_list: if get_node_shape(node) is not None: out.append(node) - elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance( - node.meta['fwd_out'][0], int): + elif ( + len(node.meta["fwd_out"]) > 0 + and isinstance(node.meta["fwd_out"], list) + and isinstance(node.meta["fwd_out"][0], int) + ): out.append(node) return out diff --git a/colossalai/booster/accelerator.py b/colossalai/booster/accelerator.py index fc2c4a40068b50cb49db8a3f9c33f22b8f307966..92990907bc2e5d24c67c3864682195ce22adca89 100644 --- a/colossalai/booster/accelerator.py +++ b/colossalai/booster/accelerator.py @@ -1,12 +1,11 @@ import torch import torch.nn as nn -__all__ = ['Accelerator'] +__all__ = ["Accelerator"] _supported_devices = [ - 'cpu', - 'cuda', - + "cpu", + "cuda", # To be supported # 'xpu', # 'npu', @@ -25,21 +24,22 @@ class Accelerator: def __init__(self, device: str): self.device = device - assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}" + assert ( + self.device in _supported_devices + ), f"Device {self.device} is not supported yet, supported devices include {_supported_devices}" def bind(self): """ Set the default device for the current process. """ - if self.device == 'cpu': + if self.device == "cpu": pass - elif self.device == 'cuda': + elif self.device == "cuda": # TODO(FrankLeeeee): use global environment to check if it is a dist job # if is_distributed: # local_rank = EnvTable().get_local_rank() # torch.cuda.set_device(torch.device(f'cuda:{local_rank}')) - torch.cuda.set_device(torch.device('cuda')) - pass + torch.cuda.set_device(torch.device("cuda")) else: raise ValueError(f"Device {self.device} is not supported yet") diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index c14e602deaf5ce60808a87bbbd238c9419b9d502..d73bc5babd8000689576e6f291bb6db82b4287d7 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -1,6 +1,6 @@ import warnings from contextlib import contextmanager -from typing import Callable, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Union import torch import torch.nn as nn @@ -8,13 +8,16 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader +import colossalai.interface.pretrained as pretrained_utils from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.interface import ModelWrapper, OptimizerWrapper from .accelerator import Accelerator from .mixed_precision import MixedPrecision, mixed_precision_factory from .plugin import Plugin +from .plugin.pp_plugin_base import PipelinePluginBase -__all__ = ['Booster'] +__all__ = ["Booster"] class Booster: @@ -22,56 +25,67 @@ class Booster: Booster is a high-level API for training neural networks. It provides a unified interface for training with different precision, accelerator, and plugin. - Examples: - >>> colossalai.launch(...) - >>> plugin = GeminiPlugin(stage=3, ...) - >>> booster = Booster(precision='fp16', plugin=plugin) - >>> - >>> model = GPT2() - >>> optimizer = Adam(model.parameters()) - >>> dataloader = Dataloader(Dataset) - >>> lr_scheduler = LinearWarmupScheduler() - >>> criterion = GPTLMLoss() - >>> - >>> model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader) - >>> - >>> for epoch in range(max_epochs): - >>> for input_ids, attention_mask in dataloader: - >>> outputs = model(input_ids, attention_mask) - >>> loss = criterion(outputs.logits, input_ids) - >>> booster.backward(loss, optimizer) - >>> optimizer.step() - >>> lr_scheduler.step() - >>> optimizer.zero_grad() + ```python + # Following is pseudocode + + colossalai.launch(...) + plugin = GeminiPlugin(...) + booster = Booster(precision='fp16', plugin=plugin) + + model = GPT2() + optimizer = HybridAdam(model.parameters()) + dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + lr_scheduler = LinearWarmupScheduler() + criterion = GPTLMLoss() + + model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler) + + for epoch in range(max_epochs): + for input_ids, attention_mask in dataloader: + outputs = model(input_ids.cuda(), attention_mask.cuda()) + loss = criterion(outputs.logits, input_ids) + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + ``` Args: - device (str or torch.device): The device to run the training. Default: 'cuda'. + device (str or torch.device): The device to run the training. Default: None. + If plugin is not used or plugin doesn't control the device, + this argument will be set as training device ('cuda' will be used if argument is None). mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None. If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'. 'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex. plugin (Plugin): The plugin to run the training. Default: None. """ - def __init__(self, - device: str = 'cuda', - mixed_precision: Union[MixedPrecision, str] = None, - plugin: Optional[Plugin] = None) -> None: + def __init__( + self, + device: Optional[str] = None, + mixed_precision: Optional[Union[MixedPrecision, str]] = None, + plugin: Optional[Plugin] = None, + ) -> None: if plugin is not None: assert isinstance( - plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.' + plugin, Plugin + ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}." self.plugin = plugin # set accelerator if self.plugin and self.plugin.control_device(): self.accelerator = None - warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.') + if device is not None: + warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.") else: + device = device or "cuda" self.accelerator = Accelerator(device) # set precision if self.plugin and self.plugin.control_precision(): - warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.') + if mixed_precision is not None: + warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.") self.mixed_precision = None elif mixed_precision is None: self.mixed_precision = None @@ -85,7 +99,7 @@ class Booster: self.mixed_precision = mixed_precision else: raise ValueError( - f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.' + f"Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}." ) if self.plugin is not None and self.plugin.control_checkpoint_io(): @@ -96,79 +110,216 @@ class Booster: def boost( self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None, - dataloader: DataLoader = None, - lr_scheduler: LRScheduler = None, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: """ - Boost the model, optimizer, criterion, lr_scheduler, and dataloader. + Wrap and inject features to the passed in model, optimizer, criterion, lr_scheduler, and dataloader. Args: - model (nn.Module): The model to be boosted. - optimizer (Optimizer): The optimizer to be boosted. - criterion (Callable): The criterion to be boosted. - dataloader (DataLoader): The dataloader to be boosted. - lr_scheduler (LRScheduler): The lr_scheduler to be boosted. + model (nn.Module): Convert model into a wrapped model for distributive training. + The model might be decorated or partitioned by plugin's strategy after execution of this method. + optimizer (Optimizer, optional): Convert optimizer into a wrapped optimizer for distributive training. + The optimizer's param groups or states might be decorated or partitioned by plugin's strategy after execution of this method. Defaults to None. + criterion (Callable, optional): The function that calculates loss. Defaults to None. + dataloader (DataLoader, optional): The prepared dataloader for training. Defaults to None. + lr_scheduler (LRScheduler, optional): The learning scheduler for training. Defaults to None. + + Returns: + List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments. """ # TODO(FrankLeeeee): consider multi-model and multi-optimizer case # TODO(FrankLeeeee): consider multi-dataloader case + pretrained_path = pretrained_utils.get_pretrained_path(model) # transform model for mixed precision if self.plugin: model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure( - model, optimizer, criterion, dataloader, lr_scheduler) + model, optimizer, criterion, dataloader, lr_scheduler + ) if self.plugin and not self.plugin.control_device(): # transform model for accelerator - model = self.accelerator.configure(model) + model = self.accelerator.configure_model(model) if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()): # transform model for mixed precision # when mixed_precision is specified and the plugin is not given or does not control the precision model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion) + if pretrained_path: + self.load_model(model, pretrained_path) + # clear pretrained path attr + orig_model = model.unwrap() if isinstance(model, ModelWrapper) else model + pretrained_utils.set_pretrained_path(orig_model, None) + return model, optimizer, criterion, dataloader, lr_scheduler def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None: - # TODO: implement this method with plugin + """Execution of backward during training step. + + Args: + loss (torch.Tensor): The loss for backpropagation. + optimizer (Optimizer): The optimizer to be updated. + """ + # TODO(frank lee): implement this method with plugin optimizer.backward(loss) - def execute_pipeline(self, - data_iter: Iterator, - model: nn.Module, - criterion: Callable[[torch.Tensor], torch.Tensor], - optimizer: Optimizer, - return_loss: bool = True, - return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]: - # TODO: implement this method - # run pipeline forward backward pass - # return loss or outputs if needed - pass - - def no_sync(self, model: nn.Module) -> contextmanager: - assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.' - assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' - return self.plugin.no_sync(model) - - def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + def execute_pipeline( + self, + data_iter: Iterator, + model: nn.Module, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[Optimizer] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> Dict[str, Any]: + """ + Execute forward & backward when utilizing pipeline parallel. + Return loss or Huggingface style model outputs if needed. + + Warning: This function is tailored for the scenario of pipeline parallel. + As a result, please don't do the forward/backward pass in the conventional way (model(input)/loss.backward()) + when doing pipeline parallel training with booster, which will cause unexpected errors. + + Args: + data_iter(Iterator): The iterator for getting the next batch of data. Usually there are two ways to obtain this argument: + 1. wrap the dataloader to iterator through: iter(dataloader) + 2. get the next batch from dataloader, and wrap this batch to iterator: iter([batch]) + model (nn.Module): The model to execute forward/backward, it should be a model wrapped by a plugin that supports pipeline. + criterion: (Callable[[Any, Any], torch.Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. + 'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here. + optimizer (Optimizer, optional): The optimizer for execution of backward. Can be None when only doing forward (i.e. evaluation). Defaults to None. + return_loss (bool, optional): Whether to return loss in the dict returned by this method. Defaults to True. + return_output (bool, optional): Whether to return Huggingface style model outputs in the dict returned by this method. Defaults to False. + + Returns: + Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}. + ret_dict['loss'] is the loss of forward if return_loss is set to True, else None. + ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None. + """ + assert isinstance( + self.plugin, PipelinePluginBase + ), f"The plugin {self.plugin.__class__.__name__} does not support pipeline." + return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs) + + def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager: + """Context manager to disable gradient synchronization across DP process groups. + Support torch DDP and Low Level ZeRO-1 for now. + + Args: + model (nn.Module): The model to be disabled gradient synchronization, for DDP + optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO1-1 + + Returns: + contextmanager: Context to disable gradient synchronization. + """ + assert ( + self.plugin is not None + ), f"no_sync is only enabled when a plugin is provided and the plugin supports no_sync." + assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync." + return self.plugin.no_sync(model, optimizer) + + def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None: + """Load model from checkpoint. + + Args: + model (nn.Module or ModelWrapper): A model boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local path. + It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Defaults to True. + """ self.checkpoint_io.load_model(model, checkpoint, strict) - def save_model(self, - model: nn.Module, - checkpoint: str, - prefix: str = None, - shard: bool = False, - size_per_shard: int = 1024): - self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard) + def save_model( + self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: + """Save model to checkpoint. + + Args: + model (nn.Module or ModelWrapper): A model boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local path. + It is a file path if ``shard=False``. Otherwise, it is a directory path. + shard (bool, optional): Whether to save checkpoint a sharded way. + If true, the checkpoint will be a folder with the same format as Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False. + gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True. + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. + size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. + use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved. + """ + self.checkpoint_io.save_model( + model, + checkpoint=checkpoint, + shard=shard, + gather_dtensor=gather_dtensor, + prefix=prefix, + size_per_shard=size_per_shard, + use_safetensors=use_safetensors, + ) + + def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None: + """Load optimizer from checkpoint. - def load_optimizer(self, optimizer: Optimizer, checkpoint: str): + Args: + optimizer (Optimizer): An optimizer boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local path. + It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path. + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. + size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. + """ self.checkpoint_io.load_optimizer(optimizer, checkpoint) - def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024): - self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard) + def save_optimizer( + self, + optimizer: Optimizer, + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ) -> None: + """ + Save optimizer to checkpoint. + + Args: + optimizer (Optimizer): An optimizer boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local path. + It is a file path if ``shard=False``. Otherwise, it is a directory path. + shard (bool, optional): Whether to save checkpoint a sharded way. + If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. + size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. + """ + self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard) + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None: + """Save lr scheduler to checkpoint. - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + Args: + lr_scheduler (LRScheduler): A lr scheduler boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local file path. + """ self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint) - def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None: + """Load lr scheduler from checkpoint. + + Args: + lr_scheduler (LRScheduler): A lr scheduler boosted by Booster. + checkpoint (str): Path to the checkpoint. It must be a local file path. + """ self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint) diff --git a/colossalai/booster/mixed_precision/__init__.py b/colossalai/booster/mixed_precision/__init__.py index 3cf0ad28cdbe78b24e1c8fd18fd6bd473e095bd6..68c6221ec809865284826c189f98b5ddf90a3975 100644 --- a/colossalai/booster/mixed_precision/__init__.py +++ b/colossalai/booster/mixed_precision/__init__.py @@ -1,19 +1,27 @@ from .bf16 import BF16MixedPrecision from .fp8 import FP8MixedPrecision from .fp16_apex import FP16ApexMixedPrecision +from .fp16_naive import FP16NaiveMixedPrecision from .fp16_torch import FP16TorchMixedPrecision from .mixed_precision_base import MixedPrecision __all__ = [ - 'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision', - 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision' + "MixedPrecision", + "mixed_precision_factory", + "FP16_Apex_MixedPrecision", + "FP16_Torch_MixedPrecision", + "FP32_MixedPrecision", + "BF16_MixedPrecision", + "FP8_MixedPrecision", + "FP16NaiveMixedPrecision", ] _mixed_precision_mapping = { - 'fp16': FP16TorchMixedPrecision, - 'fp16_apex': FP16ApexMixedPrecision, - 'bf16': BF16MixedPrecision, - 'fp8': FP8MixedPrecision + "fp16": FP16TorchMixedPrecision, + "fp16_apex": FP16ApexMixedPrecision, + "fp16_naive": FP16NaiveMixedPrecision, + "bf16": BF16MixedPrecision, + "fp8": FP8MixedPrecision, } @@ -29,5 +37,5 @@ def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision: return _mixed_precision_mapping[mixed_precision_type]() else: raise ValueError( - f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}' + f"Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}" ) diff --git a/colossalai/booster/mixed_precision/fp16_apex.py b/colossalai/booster/mixed_precision/fp16_apex.py index 266a750734b14ade3c5f53fc71ef276d09e1ec83..2fa7b54cdd3094fc2ded6fabd77e35ad77d56aa3 100644 --- a/colossalai/booster/mixed_precision/fp16_apex.py +++ b/colossalai/booster/mixed_precision/fp16_apex.py @@ -1,5 +1,40 @@ +from typing import Any, Optional, Union + +import torch + from .mixed_precision_base import MixedPrecision class FP16ApexMixedPrecision(MixedPrecision): - pass + """ + Precision for mixed precision training in FP16 using apex AMP. + + Args: + opt_level(str, optional, default="O1" ): Pure or mixed precision optimization level. Accepted values are “O0”, “O1”, “O2”, and “O3”, explained in detail above Apex AMP Documentation. + cast_model_type (torch.dtype, optional, default=None): Casts your model’s parameters and buffers to the desired type. + patch_torch_functions (bool, optional, default=None): Patch all Torch functions and Tensor methods to perform Tensor Core-friendly ops like GEMMs and convolutions in FP16, and any ops that benefit from FP32 precision in FP32. + keep_batchnorm_fp32 (bool or str, optional, default=None): To enhance precision and enable cudnn batchnorm (which improves performance), it’s often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16. + master_weights (bool, optional, default=None): Maintain FP32 master weights to accompany any FP16 model weights. FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients. + loss_scale (float or str, optional, default=None): If loss_scale is a float value, use this value as the static (fixed) loss scale. If loss_scale is the string "dynamic", adaptively adjust the loss scale over time. Dynamic loss scale adjustments are performed by Amp automatically. + cast_model_outputs (torch.dpython:type, optional, default=None): Option to ensure that the outputs of your model(s) are always cast to a particular type regardless of opt_level. + num_losses(int, optional, default=1): Option to tell AMP in advance how many losses/backward passes you plan to use. When used in conjunction with the loss_id argument to `amp.scale_loss`, enables Amp to use a different loss scale per loss/backward pass, which can improve stability. If num_losses is left to 1, Amp will still support multiple losses/backward passes, but use a single global loss scale for all of them. + verbosity(int, default=1): Set to 0 to suppress Amp-related output. + min_loss_scale(float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored. + max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored. + """ + + def __init__( + self, + opt_level: Optional[str] = "O1", + cast_model_type: torch.dtype = None, + patch_torch_functions: bool = None, + keep_batchnorm_fp32: Union[bool, str] = None, + master_weights: bool = None, + loss_scale: Union[float, str] = None, + cast_model_outputs: Any = None, + num_losses: Optional[int] = 1, + verbosity: int = 1, + min_loss_scale: float = None, + max_loss_scale: float = 2.0**24, + ) -> None: + pass diff --git a/colossalai/booster/mixed_precision/fp16_naive.py b/colossalai/booster/mixed_precision/fp16_naive.py new file mode 100644 index 0000000000000000000000000000000000000000..e5624a9d7477a407246e72d3adf6a66400e46e9e --- /dev/null +++ b/colossalai/booster/mixed_precision/fp16_naive.py @@ -0,0 +1,28 @@ +from .mixed_precision_base import MixedPrecision + + +class FP16NaiveMixedPrecision(MixedPrecision): + """ + Precision for mixed precision training in FP16 using naive AMP. + + Args: + log_num_zeros_in_grad(bool): return number of zeros in the gradients. + initial_scale(int): initial scale of gradient scaler. + growth_factor(int): the growth rate of loss scale. + backoff_factor(float): the decrease rate of loss scale. + hysteresis(int): delay shift in dynamic loss scaling. + max_scale(int): maximum loss scale allowed. + verbose(bool): if set to `True`, will print debug info. + """ + + def __init__( + self, + log_num_zeros_in_grad: bool, + initial_scale: int, + growth_factor: int, + backoff_factor: float, + hysteresis: int, + max_scale: int, + verbose: bool = None, + ) -> None: + pass diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index 9999aa5e0eb475b8303b76382d9287b3ac876696..7dce6e6da33e8e4d9b0ac23452afb8fa24299884 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -9,7 +9,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper from .mixed_precision_base import MixedPrecision -__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule'] +__all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"] class TorchAMPOptimizer(OptimizerWrapper): @@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper): calls that may cause the scale to increase. Default: 2000. """ - def __init__(self, - optim: Optimizer, - init_scale: float = 2.**16, - growth_factor: float = 2.0, - backoff_factor: float = 0.5, - growth_interval: int = 2000) -> None: + def __init__( + self, + optim: Optimizer, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + ) -> None: super().__init__(optim) - self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval) + self.scaler = torch.cuda.amp.GradScaler( + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + ) def backward(self, loss: Tensor, *args, **kwargs) -> None: scaled_loss = self.scale_loss(loss) @@ -60,12 +64,14 @@ class TorchAMPOptimizer(OptimizerWrapper): self.unscale_grad() super().clip_grad_by_value(clip_value, *args, **kwargs) - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> None: + def clip_grad_by_norm( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = False, + *args, + **kwargs, + ) -> None: self.unscale_grad() super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs) @@ -102,23 +108,30 @@ class FP16TorchMixedPrecision(MixedPrecision): calls that may cause the scale to increase. Default: 2000. """ - def __init__(self, - init_scale: float = 2.**16, - growth_factor: float = 2.0, - backoff_factor: float = 0.5, - growth_interval: int = 2000) -> None: + def __init__( + self, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + ) -> None: super().__init__() - self.torch_amp_kwargs = dict(init_scale=init_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval) - - def configure(self, - model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]: + self.torch_amp_kwargs = dict( + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + ) + + def configure( + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable]: model = TorchAMPModule(model) - optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs) + if optimizer is not None: + optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs) if criterion is not None: criterion = TorchAMPModule(criterion) return model, optimizer, criterion diff --git a/colossalai/booster/mixed_precision/mixed_precision_base.py b/colossalai/booster/mixed_precision/mixed_precision_base.py index 2490e9811ccf3ef71a1dcb90d30ddb794fc82d04..a86fdfc17eaf4df45de733e03c6384be6ccd2779 100644 --- a/colossalai/booster/mixed_precision/mixed_precision_base.py +++ b/colossalai/booster/mixed_precision/mixed_precision_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Tuple +from typing import Callable, Optional, Tuple import torch.nn as nn from torch.optim import Optimizer @@ -13,9 +13,11 @@ class MixedPrecision(ABC): """ @abstractmethod - def configure(self, - model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]: + def configure( + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable]: # TODO: implement this method pass diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py index aa45bcb59ad7b9b25a06bea5fa1c0f776e6870ae..62f3708fc62972051c05d8ffdfbd0ff0f8ca410e 100644 --- a/colossalai/booster/plugin/__init__.py +++ b/colossalai/booster/plugin/__init__.py @@ -1,6 +1,15 @@ from .gemini_plugin import GeminiPlugin +from .hybrid_parallel_plugin import HybridParallelPlugin from .low_level_zero_plugin import LowLevelZeroPlugin from .plugin_base import Plugin from .torch_ddp_plugin import TorchDDPPlugin -__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin'] +__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"] + +import torch +from packaging import version + +if version.parse(torch.__version__) >= version.parse("1.12.0"): + from .torch_fsdp_plugin import TorchFSDPPlugin + + __all__.append("TorchFSDPPlugin") diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py new file mode 100644 index 0000000000000000000000000000000000000000..d2dd00453e32592b980b0a34c967829fc94e2cde --- /dev/null +++ b/colossalai/booster/plugin/dp_plugin_base.py @@ -0,0 +1,66 @@ +import random + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from .plugin_base import Plugin + + +class DPPluginBase(Plugin): + """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation.""" + + def __init__(self) -> None: + super().__init__() + assert ( + dist.is_initialized() + ), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment" + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + + def prepare_dataloader( + self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + ): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index deda00d8a7b3b6849a1cf5b1db0b796a1a1c7d89..ca722a0768dc0b6dd5b3f5085980a7a2c4b6a57b 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -1,143 +1,278 @@ -import random -import warnings -from typing import Callable, List, Optional, Tuple, Union +import gc +import logging +import os +from pathlib import Path +from typing import Callable, Iterator, List, Optional, Tuple -import numpy as np import torch -import torch.distributed as dist import torch.nn as nn -from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO -from colossalai.checkpoint_io.utils import save_state_dict +from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO +from colossalai.checkpoint_io.utils import ( + get_model_base_filenames, + get_optimizer_base_filenames, + load_shard_state_dict, + save_config_file, + save_state_dict, + save_state_dict_shards, +) from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device -from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper +from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats -from .plugin_base import Plugin +from .dp_plugin_base import DPPluginBase -__all__ = ['GeminiPlugin'] +__all__ = ["GeminiPlugin"] +SUPPORTED_PRECISION = ["fp16", "bf16"] +PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16} -class GeminiCheckpointIO(GeneralCheckpointIO): +class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() + def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + """ + Save sharded model to checkpoint but only on master process. + The model should be unwrapped in self.load_model via ModelWrapper.unwrap. + As there is communication when getting state dict, model.state_dict() must be called on all processes. + """ + assert isinstance(model, GeminiDDP), "Please boost the model before saving!" + state_dict = model.state_dict(only_rank_0=True) + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors) + def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): """ Load model from checkpoint with automatic unwrapping. + The model should be unwrapped in self.load_model via ModelWrapper.unwrap. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap - return super().load_unsharded_model(model, checkpoint, strict=strict) + assert isinstance(model, GeminiDDP), "Please boost the model before loading!" + super().load_unsharded_model(model, checkpoint, strict=strict) - def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool): """ - Save model to checkpoint but only on master process. + Save unsharded optimizer state dict to checkpoint. + After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank. + As there is communication when getting state dict, optimizer.state_dict() must be called on all processes. + The saving process will only be executed by master rank. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap - # as there is communication when get state dict, this must be called on all processes - state_dict = model.state_dict(only_rank_0=True) + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" + state_dict = optimizer.state_dict() if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors) + save_state_dict(state_dict, checkpoint, use_safetensors=False) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str): """ - Save optimizer to checkpoint but only on master process. + Loading unsharded optimizer from checkpoint file. + For each process, only loading optimizer states of parameters it controls. """ - # TODO(ver217): optimizer state dict is sharded - super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" + super().load_unsharded_optimizer(optimizer, checkpoint) - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + def save_sharded_model( + self, + model: GeminiDDP, + checkpoint_path: str, + gather_dtensor: bool = False, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): """ - Save model to checkpoint but only on master process. + Save sharded model. + As there is communication when getting state dict, model.state_dict() must be called on all processes. """ - if self.coordinator.is_master(): - super().save_lr_scheduler(lr_scheduler, checkpoint) + assert isinstance(model, GeminiDDP), "Please boost the model before saving!" + if os.path.isfile(checkpoint_path): + logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + return + + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint_path) + + # Save shards of optimizer states. + is_master = self.coordinator.is_master() + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=is_master, + use_safetensors=use_safetensors, + ) + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model.unwrap(), checkpoint_path) + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def load_sharded_model( + self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False + ): + """ + Load shard model, load model from multiple files. + """ + assert isinstance(model, GeminiDDP), "Please boost the model before loading!" + return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) -class GeminiModel(ModelWrapper): + def save_sharded_optimizer( + self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + ): + """ + Save sharded optimizer state dict to checkpoint folder. + As there is communication when getting state dict, this must be called on all processes. + """ + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Preparing file paths and index file. + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + + # Store the information of param groups to param_group_file. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + param_groups = optimizer.get_param_groups_for_saving() + torch.save(param_groups, group_file_path) + + # States are broken into shards within max_shard_size. + state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True) + + # Save shards of optimizer states. + is_master = self.coordinator.is_master() + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=is_master, + use_safetensors=False, + ) - def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None: - super().__init__(module) - self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose) + # Wrap up index file. Only save it on master rank. + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str): + """ + Loading sharded optimizer from checkpoint folder, with index file given. + For each process, only loading optimizer states of parameters it controls. + """ + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" + if not os.path.isfile(checkpoint_index_file): + logging.error(f"Provided path ({checkpoint_index_file}) should be a file") - def unwrap(self): - # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model - return self.module + assert isinstance(optimizer, GeminiOptimizer) + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) -class GeminiOptimizer(OptimizerWrapper): + # Load param_groups. + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory." + ) + saved_param_groups = torch.load(param_group_path) + optimizer.load_param_groups(saved_param_groups) - def __init__(self, - module: GeminiDDP, - optimizer: Optimizer, - zero_optim_config: dict, - optim_kwargs: dict, - verbose: bool = False) -> None: - optimizer = zero_optim_wrapper(module, - optimizer, - optim_config=zero_optim_config, - **optim_kwargs, - verbose=verbose) - super().__init__(optimizer) + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() - def backward(self, loss: Tensor, *args, **kwargs): - self.optim.backward(loss) + # Load optimizer states from shard files under checkpoint path. + # For each file, only load the states managed by current process. + for shard_file in checkpoint_files: + state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False) + optimizer.load_param_states(state_dict_shard) + del state_dict_shard + gc.collect() - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> Tensor: - warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm') + optimizer.optimizer_loading_epilogue() - def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: - raise NotImplementedError('Gemini does not support clip_grad_by_value') + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save model to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) -class GeminiPlugin(Plugin): +class GeminiPlugin(DPPluginBase): """ Plugin for Gemini. - Example: - >>> from colossalai.booster import Booster - >>> from colossalai.booster.plugin import GeminiPlugin - >>> - >>> model, train_dataset, optimizer, criterion = ... - >>> plugin = GeminiPlugin() + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import GeminiPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = GeminiPlugin() - >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) - >>> booster = Booster(plugin=plugin) - >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ``` Args: - device (torch.device): device to place the model. - placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". + chunk_config_dict (dict, optional): chunk configuration dictionary. + chunk_init_device (torch.device, optional): device to initialize the chunk. + placement_policy (str, optional): "static" and "auto". Defaults to "static". + shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement. + If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0. + offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement. + If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0. + offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement. + For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0. + If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement. + When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`. + Defaults to 0.0. + warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8. + steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9. + precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'. pin_memory (bool, optional): use pin memory on CPU. Defaults to False. force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False. - search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. + search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32. hidden_dim (int, optional): the hidden dimension of DNN. Users can provide this argument to speed up searching. If users do not know this argument before training, it is ok. We will use a default value 1024. - min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. - If the aggregate size of parameters is still samller than the minimum chunk size, + min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20. + If the aggregate size of parameters is still smaller than the minimum chunk size, all parameters will be compacted into one small chunk. memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) which will be used when using hybrid CPU optimizer. This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". Defaults to 0.0. - initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16. min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. @@ -152,17 +287,24 @@ class GeminiPlugin(Plugin): def __init__( self, - device: Optional[torch.device] = None, - placement_policy: str = "cpu", + chunk_config_dict: Optional[dict] = None, + chunk_init_device: Optional[torch.device] = None, + placement_policy: str = "static", + shard_param_frac: float = 1.0, # only for static placement + offload_optim_frac: float = 0.0, # only for static placement + offload_param_frac: float = 0.0, # only for static placement + warmup_non_model_data_ratio: float = 0.8, # only for auto placement + steady_cuda_cap_ratio: float = 0.9, # only for auto placement + precision: str = "fp16", pin_memory: bool = False, force_outputs_fp32: bool = False, strict_ddp_mode: bool = False, - search_range_mb: int = 32, + search_range_m: int = 32, hidden_dim: Optional[int] = None, - min_chunk_size_mb: float = 32, + min_chunk_size_m: float = 32, memstats: Optional[MemStats] = None, gpu_margin_mem_ratio: float = 0.0, - initial_scale: float = 2**32, + initial_scale: float = 2**16, min_scale: float = 1, growth_factor: float = 2, backoff_factor: float = 0.5, @@ -173,32 +315,40 @@ class GeminiPlugin(Plugin): norm_type: float = 2.0, verbose: bool = False, ) -> None: - - assert dist.is_initialized( - ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() + super().__init__() + assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" self.gemini_config = dict( - device=(device or get_current_device()), + chunk_config_dict=chunk_config_dict, + chunk_init_device=(chunk_init_device or get_current_device()), placement_policy=placement_policy, + shard_param_frac=shard_param_frac, + offload_optim_frac=offload_optim_frac, + offload_param_frac=offload_param_frac, + warmup_non_model_data_ratio=warmup_non_model_data_ratio, + steady_cuda_cap_ratio=steady_cuda_cap_ratio, pin_memory=pin_memory, force_outputs_fp32=force_outputs_fp32, strict_ddp_mode=strict_ddp_mode, - search_range_mb=search_range_mb, + search_range_m=search_range_m, hidden_dim=hidden_dim, - min_chunk_size_mb=min_chunk_size_mb, + min_chunk_size_m=min_chunk_size_m, memstats=memstats, + mixed_precision=PRECISION_STR_TO_DTYPE[precision], + ) + self.zero_optim_config = dict( + gpu_margin_mem_ratio=gpu_margin_mem_ratio, + ) + self.optim_kwargs = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type, ) - self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,) - self.optim_kwargs = dict(initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - max_norm=max_norm, - norm_type=norm_type) self.verbose = verbose def support_no_sync(self) -> bool: @@ -208,74 +358,22 @@ class GeminiPlugin(Plugin): return True def supported_precisions(self) -> List[str]: - return ['fp16'] + return SUPPORTED_PRECISION def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ['cuda'] - - def prepare_train_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - Note: - 1. Evaluation datasets should not be passed to this function. - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return ["cuda"] def configure( self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None, - dataloader: DataLoader = None, - lr_scheduler: LRScheduler = None, - ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: - + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: if not isinstance(model, ModelWrapper): # convert model to sync bn # FIXME(ver217): gemini does not support sync bn @@ -287,11 +385,12 @@ class GeminiPlugin(Plugin): # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) # wrap the model with Gemini - model = GeminiModel(model, self.gemini_config, self.verbose) + model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose) - if not isinstance(optimizer, OptimizerWrapper): - optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, - self.verbose) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + optimizer = GeminiOptimizer( + optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose + ) return model, optimizer, criterion, dataloader, lr_scheduler @@ -300,3 +399,6 @@ class GeminiPlugin(Plugin): def get_checkpoint_io(self) -> CheckpointIO: return GeminiCheckpointIO() + + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + raise NotImplementedError diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..479ccc3eb36e4ae9bb9df430313e912258717cd4 --- /dev/null +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -0,0 +1,579 @@ +import random +from contextlib import nullcontext +from functools import partial +from types import MethodType +from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import Module, SyncBatchNorm +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer +from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.zero.low_level import LowLevelZeroOptimizer + +from .pp_plugin_base import PipelinePluginBase + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + + +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + +class HybridParallelModule(ModelWrapper): + def __init__( + self, + module: Module, + precision: str, + shard_config: ShardConfig, + dp_group: ProcessGroup, + use_ddp: bool, + ddp_config: dict, + custom_policy: Policy, + ) -> None: + self.stage_manager = shard_config.pipeline_stage_manager + self.dp_group = dp_group + + shardformer = ShardFormer(shard_config) + if custom_policy is not None: + assert isinstance(custom_policy, object) + module, self.shared_params = shardformer.optimize(module, policy=custom_policy) + + # setting process groups for shared parameters + self.shared_param_process_groups = [] + for shared_param in self.shared_params: + if len(shared_param) > 0: + self.shared_param_process_groups.append( + self.stage_manager.init_process_group_by_stages(list(shared_param.keys())) + ) + + # setting mixed_precision + self.mixed_precision = None + if precision == "fp16": + self.mixed_precision = torch.float16 + elif precision == "bf16": + self.mixed_precision = torch.bfloat16 + if self.mixed_precision is not None: + module = module.to(self.mixed_precision) + module = module.cuda() + + # setting input type cast when using mixed precision + self.convert_fn = None + if self.mixed_precision is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision) + + # setting ddp configs + if use_ddp: + # convert model to sync bn + module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) + # wrap the model with PyTorch DDP + module = DDP(module, process_group=dp_group, **ddp_config) + + super().__init__(module) + + def sync_shared_params(self): + for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): + if self.stage_manager.stage in shared_param: + param = shared_param[self.stage_manager.stage] + dist.all_reduce(param.grad, group=group) + dist.barrier() + + def no_sync(self) -> Iterator[None]: + # no sync grads across data parallel + return nullcontext() + + def sync_grads(self): + # sync grad across data parallel + if self.dp_group.size() == 1: + return + for p in self.module.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, group=self.dp_group) + p.grad.div_(self.dp_group.size()) + + def forward(self, *args, **kwargs): + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) + return super().forward(*args, **kwargs) + + def unwrap(self): + module = super().unwrap() + if isinstance(module, DDP): + module = module.module + return module + + +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A complete param_group, with params in the form of param_id + # 2. A mapping from param address (obtained using id(param)) to integer param_id + # 3. A mapping from integer param_id to param address. + # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. + # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. + + if optim is None: + return {} + param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} + start_index = 0 + for group in optim.param_groups: + packed_group = {k: v for k, v in group.items() if k != "params"} + packed_group["params"] = [] + + for param_id, param in enumerate(group["params"], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + packed_group["params"].append(param_id) + param_info["param2id"][id(param)] = param_id + param_info["id2param"][param_id] = id(param) + param_info["param2shape"][id(param)] = original_shape + + param_info["param_groups"].append(packed_group) + start_index += len(group["params"]) + + return param_info + + +def init_pipeline_optimizer(optim: Optimizer, model: Module): + model_params = set(model.parameters()) + new_param_groups = [] + for group in optim.param_groups: + params = [p for p in group["params"] if p in model_params] + new_param_groups.append({**group, "params": params}) + optim.__setstate__({"param_groups": new_param_groups}) + + +class HybridParallelNaiveOptimizer(OptimizerWrapper): + def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): + self.param_info = param_info + if use_pipeline: + init_pipeline_optimizer(optim, model) + super().__init__(optim) + + def update_master_params(self, model: Module): + pass + + def get_working_to_master_map(self): + return None + + def get_master_to_working_map(self): + return None + + +class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): + def __init__( + self, + optim: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + precision: str = "fp16", + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + ): + self.param_info = param_info + if use_pipeline: + init_pipeline_optimizer(optim, model) + super().__init__( + optim, + precision, + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, + max_norm, + ) + + +class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): + def __init__( + self, + optimizer: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None, + ): + self.param_info = param_info + if use_pipeline: + init_pipeline_optimizer(optimizer, model) + super().__init__( + optimizer, + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, + clip_grad_norm, + verbose, + reduce_bucket_size, + communication_dtype, + overlap_communication, + partition_grad, + cpu_offload, + dp_process_group, + tp_process_group, + forced_dtype, + ) + + +class HybridParallelPlugin(PipelinePluginBase): + """ + Plugin for Hybrid Parallel Training. + Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. + The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). + + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import HybridParallelPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = HybridParallelPlugin(tp_size=2, pp_size=2) + + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + ``` + + Args: + tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. + pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + precision (str, optional): Specifies the precision of parameters during training. + Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. + Defaults to 'fp16'. + zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. + When set to 0, ZeRO will not be used. Defaults to 0. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. + num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. + microbatch_size (int, optional): Microbatch size when using pipeline parallelism. + Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. + If ``num_microbatches`` is provided, this will be ignored. Defaults to None. + initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. + min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. + growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. + backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5. + growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000. + hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. + max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. + max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True. + ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False. + check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False. + static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False. + zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. + overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. + custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. + """ + + def __init__( + self, + tp_size: int, + pp_size: int, + precision: str = "fp16", + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + enable_sequence_overlap: bool = False, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + custom_policy: Policy = None, + ) -> None: + super().__init__() + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + + if enable_sequence_parallelism: + assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" + + self.tp_size = tp_size + self.pp_size = pp_size + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + self.precision = precision + self.zero_stage = zero_stage + self.cpu_offload = cpu_offload + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) + self.stage_manager = None + self.schedule = None + self.custom_policy = custom_policy + assert zero_stage in (0, 1, 2) + if self.pp_size > 1: + assert ( + num_microbatches is not None or microbatch_size is not None + ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" + assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" + self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) + self.schedule = OneForwardOneBackwardSchedule( + self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + ) + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + enable_sequence_overlap=enable_sequence_overlap, + ) + self.amp_config = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + ) + + self.ddp_config = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) + + self.zero_config = dict( + reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2), + ) + + self.max_norm = max_norm + + @property + def enable_pipeline_parallelism(self) -> bool: + return self.pp_size > 1 + + def supported_devices(self) -> List[str]: + return ["cuda"] + + def supported_precisions(self) -> List[str]: + return ["fp16", "bf16", "fp32"] + + def control_device(self) -> bool: + return True + + def control_precision(self) -> bool: + return True + + def support_no_sync(self) -> bool: + return False + + def control_checkpoint_io(self) -> bool: + return True + + def configure( + self, + model: Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) + if not isinstance(model, ModelWrapper): + use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + model = HybridParallelModule( + model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy + ) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if self.zero_stage == 0: + if self.precision in ["fp16", "bf16"]: + optimizer = HybridParallelAMPOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config, + ) + else: + optimizer = HybridParallelNaiveOptimizer( + optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info + ) + else: + assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." + assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." + optimizer = HybridParallelZeroOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=self.dp_group, + tp_process_group=self.tp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **self.zero_config, + **self.amp_config, + ) + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, model) + return model, optimizer, criterion, dataloader, lr_scheduler + + def execute_pipeline( + self, + data_iter: Iterator, + model: HybridParallelModule, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[ + Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer] + ] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> dict: + assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" + # return loss or outputs if needed + ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + with ctx: + outputs = self.schedule.forward_backward_step( + model, data_iter, criterion, optimizer, return_loss, return_outputs + ) + model.sync_shared_params() + if isinstance(optimizer, HybridParallelZeroOptimizer): + optimizer.sync_grad() + else: + model.sync_grads() + return outputs + + def prepare_dataloader( + self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + ): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + sampler = DistributedSampler( + dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle + ) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + + def get_checkpoint_io(self) -> CheckpointIO: + return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + + def no_sync(self, model: Module) -> Iterator[None]: + raise NotImplementedError diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 969c430bd317600091e69a68276a5f14006d5420..dffa4ce164efe2ca0abb6106e76ea416dbf9ce3d 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,111 +1,233 @@ -import random -import warnings -from typing import Callable, List, Optional, Tuple, Union +import logging +import os +from functools import partial +from pathlib import Path +from types import MethodType +from typing import Callable, Iterator, List, Optional, Tuple -import numpy as np import torch -import torch.distributed as dist import torch.nn as nn -from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils._pytree import tree_map from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from colossalai.checkpoint_io import CheckpointIO -from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO +from colossalai.checkpoint_io.utils import ( + get_optimizer_base_filenames, + get_shard_filename, + load_param_groups_into_optimizer, + load_shard_state_dict, + load_states_into_optimizer, + save_param_groups, + save_state_dict, + sharded_optimizer_loading_epilogue, +) +from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device -from colossalai.zero import zero_model_wrapper, zero_optim_wrapper +from colossalai.zero import LowLevelZeroOptimizer -from .plugin_base import Plugin +from .dp_plugin_base import DPPluginBase from .torch_ddp_plugin import TorchDDPCheckpointIO -__all__ = ['LowLevelZeroPlugin'] +__all__ = ["LowLevelZeroPlugin"] -def _convert_to_fp16(x): +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): if isinstance(x, torch.Tensor) and torch.is_floating_point(x): - return x.half() + return x.to(dtype) return x -class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): - - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): - """ - Save optimizer to checkpoint but only on master process. - """ - # TODO(ver217): optimizer state dict is sharded - super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) +SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"] -class LowLevelZeroModel(ModelWrapper): - - def __init__(self, module: nn.Module, stage: int, precision: str) -> None: +class LowLevelZeroModel(ModelWrapper, AMPModelMixin): + def __init__(self, module: nn.Module, precision: str) -> None: super().__init__(module) - self.convert_inputs = (precision == 'fp16') - module = zero_model_wrapper(module, zero_stage=stage) - if precision == 'fp16': - module = module.half() + self.dtype = None + if precision == "fp16": + self.dtype = torch.float16 + elif precision == "bf16": + self.dtype = torch.bfloat16 + if self.dtype is not None: + module = module.to(self.dtype) module = module.to(get_current_device()) self.module = module + self.convert_fn = None + if self.dtype is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) def forward(self, *args, **kwargs): - if self.convert_inputs: - args = tree_map(_convert_to_fp16, args) - kwargs = tree_map(_convert_to_fp16, kwargs) + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) return super().forward(*args, **kwargs) -class LowLevelZeroOptimizer(OptimizerWrapper): - - def __init__(self, - module: nn.Module, - optimizer: Optimizer, - zero_optim_config: dict, - optim_kwargs: dict, - verbose: bool = False) -> None: - optimizer = zero_optim_wrapper(module, - optimizer, - optim_config=zero_optim_config, - **optim_kwargs, - verbose=verbose) - super().__init__(optimizer) - - def backward(self, loss: Tensor, *args, **kwargs): - self.optim.backward(loss) - - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> Tensor: - warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm') +class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): + """Save optimizer to checkpoint but only on master process. - def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: - raise NotImplementedError('LowLevelZero does not support clip_grad_by_value') + Args: + optimizer (OptimizerWrapper): Optimizer to save state_dict + checkpoint (str): Path to save checkpoint + gather_dtensor (bool): Whether to gather_dtensor, not used + """ + assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" + # the `state_dict` in LowLevelZeroOptimizer has communication + # if only the master rank collect state_dict and save, + # the communication on each rank would not match + state_dict = optimizer.state_dict() + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors=False) + + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = False, + prefix: str = None, + size_per_shard: int = 1024, + ): + """ + Save sharded Zero-optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file that store state tensors + """ + assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # state_dict only provide only 'param_groups' + state_dict = optimizer.optim.state_dict() + # state shard would be handled by the low-level zero optimizer + sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard) + + # Preparing file paths and index file. + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + + # Store the information of param groups to param_group_file. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(state_dict, group_file_path) + + # Save shards of optimizer states. + total_size = 0 + for idx, shard_pair in enumerate(sharded_state): + shard, current_size = shard_pair + shard_file = get_shard_filename(states_name, idx) + total_size = total_size + current_size + for param_id in shard.keys(): + index_file.append_weight_map(str(param_id), shard_file) + + checkpoint_file_path = os.path.join(checkpoint, shard_file) + if self.coordinator.is_master(): + save_state_dict(shard, checkpoint_file_path, use_safetensors=False) + + # Wrap up index file. + index_file.append_meta_data("total_size", total_size) + if self.coordinator.is_master(): + index_file.write_index_file(save_index_file) + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str): + """Load sharded optimizer with the given path to index file. -class LowLevelZeroPlugin(Plugin): + Args: + optimizer (OptimizerWrapper): Optimizer to load state_dict + index_file_path (str): Path to the index file + prefix (str): Not used. + """ + assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before Loading!" + optimizer = optimizer.unwrap() + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {index_file_path} for an optimizer. \ + Lacking param group file under current directory." + ) + id_map = load_param_groups_into_optimizer(optimizer, param_group_path) + + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + + for shard_file in checkpoint_files: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + # shard state dict + for param_idx, state in state_dict.items(): + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != "step": + padding_size = ( + self.coordinator.world_size - v.numel() % self.coordinator.world_size + ) % self.coordinator.world_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + v_list = v.split(v.numel() // self.coordinator.world_size) + state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone() + load_states_into_optimizer(optimizer, state_dict, id_map) + sharded_optimizer_loading_epilogue(optimizer) + + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + super().load_unsharded_model(model, checkpoint, strict) + model.update_master_params() + + def load_sharded_model( + self, + model: ModelWrapper, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) + model.update_master_params() + + +class LowLevelZeroPlugin(DPPluginBase): """ Plugin for low level zero. - Example: - >>> from colossalai.booster import Booster - >>> from colossalai.booster.plugin import LowLevelZeroPlugin - >>> - >>> model, train_dataset, optimizer, criterion = ... - >>> plugin = LowLevelZeroPlugin() + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import LowLevelZeroPlugin - >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) - >>> booster = Booster(plugin=plugin) - >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + model, train_dataset, optimizer, criterion = ... + plugin = LowLevelZeroPlugin() + + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ``` Args: - strage (int, optional): ZeRO stage. Defaults to 1. - precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'. + stage (int, optional): ZeRO stage. Defaults to 1. + precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'. initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. @@ -126,7 +248,7 @@ class LowLevelZeroPlugin(Plugin): def __init__( self, stage: int = 1, - precision: str = 'fp16', + precision: str = "fp16", initial_scale: float = 2**32, min_scale: float = 1, growth_factor: float = 2, @@ -142,113 +264,64 @@ class LowLevelZeroPlugin(Plugin): cpu_offload: bool = False, verbose: bool = False, ) -> None: - - assert dist.is_initialized( - ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' - assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' - assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training' - - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - + super().__init__() + assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" + assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training" + assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now" self.stage = stage self.precision = precision - self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - cpu_offload=cpu_offload) - self.optim_kwargs = dict(initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - max_norm=max_norm, - norm_type=norm_type) + self.zero_optim_kwargs = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + clip_grad_norm=max_norm, + reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(stage == 2), + ) self.verbose = verbose + # set class name with stage, for better error message + setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") + def support_no_sync(self) -> bool: - return False + return self.stage == 1 def control_precision(self) -> bool: return True def supported_precisions(self) -> List[str]: - return ['fp16', 'fp32'] + return SUPPORTED_PRECISION def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ['cuda'] - - def prepare_train_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - Note: - 1. Evaluation datasets should not be passed to this function. - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return ["cuda"] def configure( self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None, - dataloader: DataLoader = None, - lr_scheduler: LRScheduler = None, - ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: - + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: if not isinstance(model, ModelWrapper): - model = LowLevelZeroModel(model, self.stage, self.precision) + model = LowLevelZeroModel(model, self.precision) - if not isinstance(optimizer, OptimizerWrapper): - optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, - self.verbose) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer( + optimizer, **self.zero_optim_kwargs, verbose=self.verbose + ) + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, model) return model, optimizer, criterion, dataloader, lr_scheduler @@ -257,3 +330,7 @@ class LowLevelZeroPlugin(Plugin): def get_checkpoint_io(self) -> CheckpointIO: return LowLevelZeroCheckpointIO() + + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert isinstance(optimizer, LowLevelZeroOptimizer) + return optimizer.optim.no_sync() diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py index 7a222022c1b264585b03ea4c402e580a909c4af7..4e570cbe8abc613bf88282d4f478f5698da211ff 100644 --- a/colossalai/booster/plugin/plugin_base.py +++ b/colossalai/booster/plugin/plugin_base.py @@ -1,19 +1,18 @@ from abc import ABC, abstractmethod -from typing import Callable, List, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple import torch.nn as nn from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from colossalai.checkpoint_io import CheckpointIO from colossalai.interface import OptimizerWrapper -__all__ = ['Plugin'] +__all__ = ["Plugin"] class Plugin(ABC): - @abstractmethod def supported_devices(self) -> List[str]: pass @@ -38,11 +37,11 @@ class Plugin(ABC): def configure( self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None, - dataloader: DataLoader = None, - lr_scheduler: LRScheduler = None, - ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: # implement this method pass @@ -51,11 +50,31 @@ class Plugin(ABC): """ Whether the plugin controls the checkpoint io """ - pass @abstractmethod def get_checkpoint_io(self) -> CheckpointIO: """ Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True. """ - pass + + @abstractmethod + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + """ + Context manager to disable gradient synchronization. + """ + + @abstractmethod + def prepare_dataloader( + self, + dataset: Dataset, + batch_size: int, + shuffle: bool = False, + seed: int = 1024, + drop_last: bool = False, + pin_memory: bool = False, + num_workers: int = 0, + **kwargs, + ): + """Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` + """ diff --git a/colossalai/booster/plugin/pp_plugin_base.py b/colossalai/booster/plugin/pp_plugin_base.py new file mode 100644 index 0000000000000000000000000000000000000000..3d91eb95b4097265f7e89e0598b6592b2da58848 --- /dev/null +++ b/colossalai/booster/plugin/pp_plugin_base.py @@ -0,0 +1,22 @@ +from abc import abstractmethod +from typing import Any, Callable, Iterator, Optional + +import torch + +from colossalai.interface import ModelWrapper, OptimizerWrapper + +from .plugin_base import Plugin + + +class PipelinePluginBase(Plugin): + @abstractmethod + def execute_pipeline( + self, + data_iter: Iterator, + model: ModelWrapper, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> dict: + pass diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index c5e310c7e7695893266784e9b90018b0d373f6f2..738634473dbc669e377dca614c2e890ee603baa9 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -1,50 +1,52 @@ -import random -from typing import Callable, List, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple -import numpy as np -import torch -import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper -from .plugin_base import Plugin +from .dp_plugin_base import DPPluginBase -__all__ = ['TorchDDPPlugin'] +__all__ = ["TorchDDPPlugin"] class TorchDDPCheckpointIO(GeneralCheckpointIO): - def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): """ - Load model from checkpoint with automatic unwrapping. + Load model from checkpoint. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap - return super().load_unsharded_model(model, checkpoint, strict=strict) + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict) - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ Save model to checkpoint but only on master process. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" if self.coordinator.is_master(): - super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) + super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors) + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + """ + Load optimizer from checkpoint. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + super().load_unsharded_optimizer(optimizer, checkpoint) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ Save optimizer to checkpoint but only on master process. """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if self.coordinator.is_master(): super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) @@ -55,9 +57,67 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): + """ + Save model to checkpoint but only on master process. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + if self.coordinator.is_master(): + super().save_sharded_model( + model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors + ) -class TorchDDPModel(ModelWrapper): + def load_sharded_model( + self, + model: ModelWrapper, + checkpoint_index_file: str, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): + """ + Load model from sharded checkpoint. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module) + + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ): + """ + Save optimizer to sharded checkpoint but only on master process. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + if self.coordinator.is_master(): + super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard) + + def load_sharded_optimizer( + self, + optimizer: Optimizer, + index_file_path: str, + prefix: Optional[str] = None, + ): + """ + Load optimizer from sharded checkpoint. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix) + +class TorchDDPModel(ModelWrapper): def __init__(self, module: nn.Module, *args, **kwargs) -> None: super().__init__(module) self.module = DDP(module, *args, **kwargs) @@ -66,20 +126,21 @@ class TorchDDPModel(ModelWrapper): return self.module.module -class TorchDDPPlugin(Plugin): +class TorchDDPPlugin(DPPluginBase): """ Plugin for PyTorch DDP. - Example: - >>> from colossalai.booster import Booster - >>> from colossalai.booster.plugin import TorchDDPPlugin - >>> - >>> model, train_dataset, optimizer, criterion = ... - >>> plugin = TorchDDPPlugin() + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import TorchDDPPlugin - >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) - >>> booster = Booster(plugin=plugin) - >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + model, train_dataset, optimizer, criterion = ... + plugin = TorchDDPPlugin() + + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ``` Args: broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True. @@ -90,24 +151,24 @@ class TorchDDPPlugin(Plugin): static_graph (bool, optional): Whether to use static graph. Defaults to False. """ - def __init__(self, - broadcast_buffers: bool = True, - bucket_cap_mb: int = 25, - find_unused_parameters: bool = False, - check_reduction: bool = False, - gradient_as_bucket_view: bool = False, - static_graph: bool = False) -> None: - - assert dist.is_initialized( - ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers, - bucket_cap_mb=bucket_cap_mb, - find_unused_parameters=find_unused_parameters, - check_reduction=check_reduction, - gradient_as_bucket_view=gradient_as_bucket_view, - static_graph=static_graph) + def __init__( + self, + broadcast_buffers: bool = True, + bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + ) -> None: + super().__init__() + self.ddp_kwargs = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) def support_no_sync(self) -> bool: return True @@ -116,73 +177,22 @@ class TorchDDPPlugin(Plugin): return False def supported_precisions(self) -> List[str]: - return ['fp16', 'fp16_apex', 'bf16', 'fp8'] + return ["fp16", "fp16_apex", "bf16", "fp8"] def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ['cuda'] - - def prepare_train_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - Note: - 1. Evaluation datasets should not be passed to this function. - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return ["cuda"] def configure( self, model: nn.Module, - optimizer: Optimizer, - criterion: Callable = None, - dataloader: DataLoader = None, - lr_scheduler: LRScheduler = None, - ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: # cast model to cuda model = model.cuda() @@ -192,7 +202,7 @@ class TorchDDPPlugin(Plugin): # wrap the model with PyTorch DDP model = TorchDDPModel(model, **self.ddp_kwargs) - if not isinstance(optimizer, OptimizerWrapper): + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = OptimizerWrapper(optimizer) return model, optimizer, criterion, dataloader, lr_scheduler @@ -202,3 +212,7 @@ class TorchDDPPlugin(Plugin): def get_checkpoint_io(self) -> CheckpointIO: return TorchDDPCheckpointIO() + + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin." + return model.module.no_sync() diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..2ea7593a5cc548be469fddd32ada454cdea1f118 --- /dev/null +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -0,0 +1,237 @@ +import warnings +from pathlib import Path +from typing import Callable, Iterable, Iterator, List, Optional, Tuple + +import torch +import torch.nn as nn +from packaging import version +from torch.distributed import ProcessGroup + +if version.parse(torch.__version__) >= version.parse("1.12.0"): + from torch.distributed.fsdp import FullStateDictConfig + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import StateDictType + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + CPUOffload, + FullStateDictConfig, + MixedPrecision, + ShardingStrategy, + ) +else: + raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader + +from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils +from colossalai.cluster import DistCoordinator +from colossalai.interface import ModelWrapper, OptimizerWrapper + +from .dp_plugin_base import DPPluginBase + +__all__ = ["TorchFSDPPlugin"] + + +class TorchFSDPCheckpointIO(GeneralCheckpointIO): + def __init__(self) -> None: + super().__init__() + self.coordinator = DistCoordinator() + + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool): + assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" + model = model.unwrap() + checkpoint = utils.load_state_dict(checkpoint) + model.load_state_dict(checkpoint) + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path): + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!" + checkpoint = utils.load_state_dict(checkpoint) + fsdp_model = optimizer.unwrap_model() + sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model) + optimizer.load_state_dict(sharded_osd) + + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + """ + Save model to checkpoint but only on master process. + """ + assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" + model = model.unwrap() + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): + full_model_state = model.state_dict() + utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + """ + Save optimizer to checkpoint but only on master process. + """ + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" + fsdp_model = optimizer.unwrap_model() + full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) + utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) + + def save_sharded_model( + self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool, + prefix: Optional[str], + size_per_shard: int, + use_safetensors: bool, + ): + """ + Save model to checkpoint but only on master process. + """ + raise NotImplementedError("Sharded model checkpoint is not supported yet.") + + def load_sharded_model( + self, + model: nn.Module, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): + """ + Load model to checkpoint but only on master process. + """ + raise NotImplementedError("Sharded model checkpoint is not supported yet.") + + def save_sharded_optimizer( + self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int + ): + """ + Save optimizer to checkpoint but only on master process. + """ + raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int): + """ + Load optimizer to checkpoint but only on master process. + """ + raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save model to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + +class TorchFSDPModel(ModelWrapper): + def __init__(self, module: nn.Module, *args, **kwargs) -> None: + super().__init__(module) + self.module = FSDP(module, *args, **kwargs) + + def unwrap(self): + return self.module + + +class FSDPOptimizerWrapper(OptimizerWrapper): + def __init__(self, optimizer: Optimizer, model: nn.Module): + self.model = model + super().__init__(optimizer) + + def unwrap_model(self) -> nn.Module: + return self.model + + +class TorchFSDPPlugin(DPPluginBase): + """ + Plugin for PyTorch FSDP. + + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import TorchFSDPPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = TorchFSDPPlugin() + + train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ``` + + Args: + See https://pytorch.org/docs/stable/fsdp.html for details. + """ + + if version.parse(torch.__version__) >= version.parse("1.12.0"): + + def __init__( + self, + process_group: Optional[ProcessGroup] = None, + sharding_strategy: Optional[ShardingStrategy] = None, + cpu_offload: Optional[CPUOffload] = None, + auto_wrap_policy: Optional[Callable] = None, + backward_prefetch: Optional[BackwardPrefetch] = None, + mixed_precision: Optional[MixedPrecision] = None, + ignored_modules: Optional[Iterable[torch.nn.Module]] = None, + param_init_fn: Optional[Callable[[nn.Module], None]] = None, + sync_module_states: bool = False, + ): + super().__init__() + self.fsdp_kwargs = dict( + process_group=process_group, + sharding_strategy=sharding_strategy, + cpu_offload=cpu_offload, + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=backward_prefetch, + mixed_precision=mixed_precision, + ignored_modules=ignored_modules, + param_init_fn=param_init_fn, + sync_module_states=sync_module_states, + ) + + else: + raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") + + def support_no_sync(self) -> bool: + False + + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + raise NotImplementedError("Torch fsdp no_sync func not supported yet.") + + def control_precision(self) -> bool: + return True + + def supported_precisions(self) -> List[str]: + return ["fp16", "bf16"] + + def control_device(self) -> bool: + return True + + def supported_devices(self) -> List[str]: + return ["cuda"] + + def configure( + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + # wrap the model with PyTorch FSDP + fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs) + + if optimizer is not None: + if len(optimizer.param_groups) > 1: + warnings.warn( + "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used." + ) + optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) + + if not isinstance(optimizer, FSDPOptimizerWrapper): + optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model) + + return fsdp_model, optimizer, criterion, dataloader, lr_scheduler + + def control_checkpoint_io(self) -> bool: + return True + + def get_checkpoint_io(self) -> CheckpointIO: + return TorchFSDPCheckpointIO() diff --git a/colossalai/builder/__init__.py b/colossalai/builder/__init__.py deleted file mode 100644 index cf09e1e7a31a15e979c5358c5da27683a0ccb2f9..0000000000000000000000000000000000000000 --- a/colossalai/builder/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .builder import build_from_config, build_from_registry, build_gradient_handler - -__all__ = ['build_gradient_handler', 'build_from_config', 'build_from_registry'] diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index c25048e25754eb6fc55db1b1de4ac7b21d05bda3..19b61730bded60230d80cbca434f9e06c0fd66c5 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -1,5 +1,6 @@ from .checkpoint_io_base import CheckpointIO from .general_checkpoint_io import GeneralCheckpointIO +from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO from .index_file import CheckpointIndexFile -__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO'] +__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"] diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index cb853559c48c1a6ceb41231e312a2480fd68bd97..780117598e183a98b20baa593b8f80179b63d360 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Union -from typing import Optional +from typing import Optional, Union import torch import torch.nn as nn @@ -12,7 +11,7 @@ from colossalai.interface import ModelWrapper from .utils import has_index_file -__all__ = ['CheckpointIO'] +__all__ = ["CheckpointIO"] class CheckpointIO(ABC): @@ -62,10 +61,9 @@ class CheckpointIO(ABC): # ====================================== # Public methods # ====================================== - def load_model(self, - model: Union[nn.Module, ModelWrapper], - checkpoint: str, - strict: bool = True) -> Union[nn.Module, ModelWrapper]: + def load_model( + self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True + ) -> Union[nn.Module, ModelWrapper]: """ Load model from checkpoint. @@ -84,15 +82,11 @@ class CheckpointIO(ABC): # containing no distributed tensors, dtensor -> full tensor conversion # should be done offline via our CLI # the existence of index file means it is a sharded checkpoint - ckpt_path = Path(checkpoint) index_file_exists, index_file_path = has_index_file(checkpoint) # return the origin model instead of the unwrapped model origin_model = model - if isinstance(model, ModelWrapper): - model = model.unwrap() - if index_file_exists: self.load_sharded_model(model, index_file_path, strict) else: @@ -100,14 +94,16 @@ class CheckpointIO(ABC): return origin_model - def save_model(self, - model: Union[nn.Module, ModelWrapper], - checkpoint: str, - shard: bool = False, - gather_dtensor: bool = True, - variant: str = None, - size_per_shard: int = 1024, - use_safetensors: bool = False): + def save_model( + self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + prefix: str = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ): """ Save model to checkpoint. @@ -130,46 +126,49 @@ class CheckpointIO(ABC): multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure that the checkpoint path is a directory path instead of a file path. gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. - variant (str): If specified, weights are saved in the format pytorch_model..bin. Default: None. + prefix (str): If specified, weights are saved in the format pytorch_model..bin. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True. use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved """ - if isinstance(model, ModelWrapper): - model = model.unwrap() - if shard: - self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors) + self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) else: self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) - def load_optimizer(self, optimizer: Optimizer, checkpoint: str): + def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024): """ Load optimizer from checkpoint. Args: optimizer (Optimizer): optimizer to be loaded. checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the + prefix (str, optional): A prefix added to parameter and buffer + names to compose the keys in state_dict. Defaults to None. + size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. """ + index_file_exists, index_file_path = has_index_file(checkpoint) if Path(checkpoint).is_dir() and not index_file_exists: # if the checkpoint is a directory and there is no index file, raise error - raise ValueError(f'Cannot find index file in {checkpoint}') + raise ValueError(f"Cannot find index file in {checkpoint}") if index_file_exists: # the existence of index file means it is a sharded checkpoint - self.load_sharded_optimizer(optimizer, index_file_path) + self.load_sharded_optimizer(optimizer, index_file_path, prefix) else: self.load_unsharded_optimizer(optimizer, checkpoint) - def save_optimizer(self, - optimizer: Optimizer, - checkpoint: str, - shard: bool = False, - gather_dtensor=True, - prefix: str = None, - size_per_shard: int = 1024): + def save_optimizer( + self, + optimizer: Optimizer, + checkpoint: str, + shard: bool = False, + gather_dtensor=True, + prefix: str = None, + size_per_shard: int = 1024, + ): """ Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors. @@ -185,6 +184,7 @@ class CheckpointIO(ABC): prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True. """ + if shard: self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) else: @@ -204,7 +204,6 @@ class CheckpointIO(ABC): strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. """ - pass @abstractmethod def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): @@ -217,11 +216,17 @@ class CheckpointIO(ABC): strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. """ - pass @abstractmethod - def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str], - size_per_shard: int, use_safetensors: bool): + def save_sharded_model( + self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool, + prefix: Optional[str], + size_per_shard: int, + use_safetensors: bool, + ): """ Save model to sharded checkpoint. @@ -233,7 +238,6 @@ class CheckpointIO(ABC): size_per_shard (int): size per shard in MB. use_safetensors (bool): whether to use safe tensors. """ - pass @abstractmethod def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): @@ -246,14 +250,13 @@ class CheckpointIO(ABC): gather_dtensor (bool): whether to gather the distributed tensor to the first device. use_safetensors (bool): whether to use safe tensors. """ - pass # ======================================================== # Abstract methods for optimizer loading/saving implementation # ======================================================== @abstractmethod - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): """ Load optimizer from sharded checkpoint. @@ -261,9 +264,7 @@ class CheckpointIO(ABC): optimizer (Optimizer): optimizer to be loaded. index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. prefix (str): prefix for the optimizer checkpoint. - size_per_shard (int): size per shard in MB. """ - pass @abstractmethod def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): @@ -274,11 +275,11 @@ class CheckpointIO(ABC): optimizer (Optimizer): optimizer to be loaded. checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. """ - pass @abstractmethod - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, - size_per_shard: int): + def save_sharded_optimizer( + self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + ): """ Save optimizer to sharded checkpoint. @@ -289,7 +290,6 @@ class CheckpointIO(ABC): prefix (str): prefix for the optimizer checkpoint. size_per_shard (int): size per shard in MB. """ - pass @abstractmethod def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): @@ -301,7 +301,6 @@ class CheckpointIO(ABC): checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. gather_dtensor (bool): whether to gather the distributed tensor to the first device. """ - pass # ============================================ # methods for loading and saving lr scheduler diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index bf584f45d045228c5d6d1e02b470c7696f5194db..a652d9b4538ec7b9270db87cf82ff988e1dde1c4 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -1,35 +1,41 @@ +import gc +import logging +import os +from functools import reduce from pathlib import Path +from typing import Optional import torch.nn as nn from torch.optim import Optimizer -import logging -import os -import json -import gc -from typing import Optional from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( - has_index_file, - load_state_dict, - save_state_dict, + get_model_base_filenames, + get_optimizer_base_filenames, is_safetensors_available, - shard_checkpoint, + load_param_groups_into_optimizer, load_shard_state_dict, + load_state_dict, load_state_dict_into_model, - add_variant - ) -from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME - + load_states_into_optimizer, + save_config_file, + save_param_groups, + save_state_dict, + save_state_dict_shards, + shard_model_checkpoint, + shard_optimizer_checkpoint, + sharded_optimizer_loading_epilogue, +) -__all__ = ['GeneralCheckpointIO'] +__all__ = ["GeneralCheckpointIO"] class GeneralCheckpointIO(CheckpointIO): """ Checkpoint IO """ + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): checkpoint = load_state_dict(checkpoint) model.load_state_dict(checkpoint, strict=strict) @@ -44,12 +50,30 @@ class GeneralCheckpointIO(CheckpointIO): # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) - def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): - raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): + """ + Load sharded optimizer with the given path to index file. + """ + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): - checkpoint = load_state_dict(checkpoint) - optimizer.load_state_dict(checkpoint) + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {index_file_path} for an optimizer. \ + Lacking param group file under current directory." + ) + id_map = load_param_groups_into_optimizer(optimizer, param_group_path) + + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + + for shard_file in checkpoint_files: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + load_states_into_optimizer(optimizer, state_dict, id_map) + + sharded_optimizer_loading_epilogue(optimizer) def save_sharded_optimizer( self, @@ -59,7 +83,56 @@ class GeneralCheckpointIO(CheckpointIO): prefix: str, size_per_shard: int, ): - raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way + """ + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Offload optimizer states. States are broken into shards within max_shard_size. + state_dict = optimizer.state_dict() + sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard) + + # Preparing file paths and index file. + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + + # Store the information of param groups to param_group_file. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(state_dict, group_file_path) + + # Save shards of optimizer states. + # In general cases, is_master is set to True to get the right behavior. + total_size = save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=True, + use_safetensors=False, + ) + + # Wrap up index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + checkpoint = load_state_dict(checkpoint) + optimizer.load_state_dict(checkpoint) def save_unsharded_optimizer( self, @@ -70,45 +143,59 @@ class GeneralCheckpointIO(CheckpointIO): # TODO(FrankLeeeee): handle distributed tensors save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) - - def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, - variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): - """ + def save_sharded_model( + self, + model: nn.Module, + checkpoint_path: str, + gather_dtensor: bool = False, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): + """ implement this method as it can be supported by Huggingface model, save shard model, save model to multiple files """ if os.path.isfile(checkpoint_path): logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") return - + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) - + # shard checkpoint state_dict = model.state_dict() - weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME - weights_name = add_variant(weights_name, variant) - shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) - - # Save the model - for shard_file, shard in shards.items(): - checkpoint_file_path = os.path.join(checkpoint_path, shard_file) - save_state_dict(shard, checkpoint_file_path, use_safetensors) - - # save index file - save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME - - save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant)) - with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) + state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint_path) + + # Save shards of optimizer states. + # In general cases, is_master is set to True to get the right behavior. + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=True, + use_safetensors=use_safetensors, + ) + + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint_path, is_master=True) logging.info( - f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " - f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"The model is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." ) - - def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): + def load_sharded_model( + self, + model: nn.Module, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): """ load shard model, load model from multiple files """ @@ -118,21 +205,26 @@ class GeneralCheckpointIO(CheckpointIO): if use_safetensors and not is_safetensors_available(): raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") - + # read checkpoint index file ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) - checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames() - missing_keys = ckpt_index_file.get_all_param_names() + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + missing_keys = [] for shard_file in checkpoint_files: state_dict = load_shard_state_dict(Path(shard_file), use_safetensors) - load_state_dict_into_model(model, state_dict, missing_keys, strict) + load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module) del state_dict gc.collect() - if strict and len(missing_keys) > 0: - error_msgs = 'Missing key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in missing_keys)) - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) - + if strict: + remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) + if len(remain_keys) > 0: + error_msgs = "Missing key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in missing_keys) + ) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + self.__class__.__name__, "\n\t".join(error_msgs) + ) + ) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py new file mode 100644 index 0000000000000000000000000000000000000000..779ff42d75a1a905e2a08d456f73462fa1f2a1b6 --- /dev/null +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -0,0 +1,876 @@ +import copy +import logging +import os +from pathlib import Path +from shutil import rmtree +from typing import Dict, Iterator, Optional, OrderedDict, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + +from colossalai.cluster import DistCoordinator +from colossalai.interface import ModelWrapper, OptimizerWrapper + +from .general_checkpoint_io import GeneralCheckpointIO +from .index_file import CheckpointIndexFile +from .utils import ( + StateDictSharder, + gather_distributed_param, + get_model_base_filenames, + get_optimizer_base_filenames, + is_safetensors_available, + load_shard_state_dict, + load_state_dict, + load_state_dict_into_model, + load_states_into_optimizer, + save_config_file, + save_param_groups, + save_state_dict, + save_state_dict_shards, + search_tp_partition_dim, + sharded_optimizer_loading_epilogue, +) + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" + + +class HybridParallelCheckpointIO(GeneralCheckpointIO): + """ + CheckpointIO for Hybrid Parallel Training. + + Args: + dp_group (ProcessGroup): Process group along data parallel dimension. + pp_group (ProcessGroup): Process group along pipeline parallel dimension. + tp_group (ProcessGroup): Process group along tensor parallel dimension. + zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2]. + verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True. + """ + + def __init__( + self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True, + ) -> None: + super().__init__() + self.dp_group = dp_group + self.pp_group = pp_group + self.tp_group = tp_group + self.dp_rank = dist.get_rank(self.dp_group) + self.tp_rank = dist.get_rank(self.tp_group) + self.pp_rank = dist.get_rank(self.pp_group) + self.dp_size = dist.get_world_size(dp_group) + self.pp_size = dist.get_world_size(pp_group) + self.tp_size = dist.get_world_size(tp_group) + self.use_zero = zero_stage > 0 + self.verbose = verbose + self.coordinator = DistCoordinator() + + @staticmethod + def _model_sharder( + model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024 + ) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + # Gather tensor pieces when using tensor parallel. + param_ = gather_distributed_param(param, keep_vars=False) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) + if block is not None: + yield block, block_size + + # Save buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + extra_state = model.get_extra_state() + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + @staticmethod + def _optimizer_sharder( + optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + size_per_shard: int = 1024, + ): + # An internel method that breaks state_dict of optimizer into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + param_info = optimizer.param_info + master_to_working_map = optimizer.get_master_to_working_map() + + for param, state in optimizer.optim.state.items(): + if param is None: + continue + + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + param_id = param_info["param2id"][id(working_param)] + original_shape = param_info["param2shape"][id(working_param)] + state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state( + state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False, + ) + + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: + """ + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" + + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + """ + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model = model.unwrap() + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of model. + # So only let the device with dp_rank == 0 save the model. + if self.dp_rank != 0: + return + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_rank == 0 are responsible for model saving. + state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.tp_rank == 0 + + if self.pp_size == 1: + # When pipeline is not used, save the model shards as in general checkpointIO + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + use_pp_format=True, + ) + if control_saving: + assert ( + self.dp_rank == 0 and self.tp_rank == 0 + ), "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for weight, weight_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(weight, weight_filename) + + final_index_file.write_index_file(final_index_file_path) + save_config_file(model, checkpoint) + rmtree(tmp_index_file_folder) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) + + def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False): + """ + Load sharded model with the given path to index file of checkpoint folder. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since params on same device might be stored in different files. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + model_before_wrapping = model # backup for model before wrapping + model = model.unwrap() + + # Check whether the checkpoint uses safetensors. + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + strict = False + + # Load params & buffers to model. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + + def _load(name: str): + if name not in weight_map: + raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") + filename = weight_map[name] + + # If this param/buffer has been loaded before, directly return. + if filename in loaded_file: + return + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors) + missing_keys = [] + + load_state_dict_into_model( + model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True + ) + loaded_file.add(filename) + + # Load parameters. + for name, _ in model.named_parameters(): + _load(name) + + # Load buffers. + non_persistent_buffers = set() + for n, m in model.named_modules(): + non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set) + for name, buf in model.named_buffers(): + if buf is not None and name not in non_persistent_buffers: + _load(name) + + # Load extra states. + extra_state_key = _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + _load(extra_state_key) + + # Update master params if mixed-precision training is enabled. + model_before_wrapping.update_master_params() + + if self.verbose and self.coordinator.is_master(): + logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ): + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files that store state tensors of optimizers. + If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file shard that store state tensors + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of states when zero is not used. + # In this case only let the device with dp_rank == 0 save the model. + if not self.use_zero and self.dp_rank != 0: + return + + # Then collect the sharded states along dp_group(if using zero)/tp_group. + # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. + state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder( + optimizer, + use_zero=self.use_zero, + dp_group=self.dp_group, + tp_group=self.tp_group, + size_per_shard=size_per_shard, + ) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.dp_rank == 0 and self.tp_rank == 0 + + if self.pp_size == 1: + # When pipeline is not used, save the optimizer shards as in general checkpointIO + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + ) + + if control_saving: + # Store param groups. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + # Store index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + ) + + if control_saving: + assert ( + self.dp_rank == 0 and self.tp_rank == 0 + ), "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for param_id, state_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(param_id, state_filename) + + # Store param groups. + final_index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + """ + Load sharded optimizer with the given path to index file of checkpoint folder. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + prefix (str): Not used. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info["param2id"][id(working_param)] + + # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. + # When Zero is used, the mapped parameter objects should be fp32 master parameters. + # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. + id_map = {} + master_to_working_map = optimizer.get_master_to_working_map() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + id_map[param_id] = param + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory." + ) + saved_groups = torch.load(param_group_path) + + updated_groups = [] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + # obtain updated param group + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) + + # Load saved states to optimizer. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + if param is None: + continue + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + if param_id not in weight_map: + continue + filename = weight_map[param_id] + + # If this param's states has been loaded before, directly return. + if filename in loaded_file: + continue + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + loaded_file.add(filename) + + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + device = param.device + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state( + state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True + ) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) + if self.verbose and self.coordinator.is_master(): + logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + """ + Save model state dict to a single file with given checkpointing path. + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path. + gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model = model.unwrap() + + if self.dp_rank != 0: + return + + # The logic of collecting parameter shards along tp degree + # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. + state_dict = model.state_dict() + + if self.pp_size == 1: + # When pipeline is not used, let master rank directly save the collected state_dict. + if self.tp_rank == 0: + save_state_dict(state_dict, checkpoint, use_safetensors) + else: + # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. + state_dict_list = [None for _ in range(self.pp_size)] + dist.barrier(self.pp_group) + dist.all_gather_object(state_dict_list, state_dict, self.pp_group) + + # Only the master rank do the saving. + if self.coordinator.is_master(): + complete_state_dict = dict() + for _state_dict in state_dict_list: + complete_state_dict.update(_state_dict) + save_state_dict(complete_state_dict, checkpoint, use_safetensors) + + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False): + """ + Load model from a single file with the given path of checkpoint. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the checkpoint file. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled. + """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + strict = False + model_before_wrapping = model + model = model.unwrap() + + # Load from checkpoint. Since the logic of breaking parameter shards along tp degree + # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer, + # model.load_state_dict can be directly called. + state_dict = load_state_dict(checkpoint) + model.load_state_dict(state_dict, strict=strict) + + # Update master params if mixed-precision training is enabled. + model_before_wrapping.update_master_params() + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + """ + Save optimizer state dict to a file with given path. + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict. + checkpoint (str): Path to save optimizer state_dict. + gather_dtensor (bool): Whether to gather_dtensor, not used. + """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + + # optimizer states of parameters kept by local device('s pipeline stage) + local_states = dict() + + for param, state in optimizer.optim.state.items(): + if param is None: + continue + + # working param is needed for obtaining correct param_id + master_to_working_map = optimizer.get_master_to_working_map() + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + # gather complete state from tp shards & dp shards + param_id = optimizer.param_info["param2id"][id(working_param)] + original_shape = optimizer.param_info["param2shape"][id(working_param)] + local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state( + state, + working_param, + original_shape=original_shape, + dp_group=self.dp_group, + tp_group=self.tp_group, + use_zero=self.use_zero, + inplace=False, + device=torch.device("cuda"), + ) + + if self.pp_size == 1: + # When pipeline is not used, let master rank directly save the collected state_dict. + state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states} + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors=False) + else: + # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. + states_list = [None for _ in range(self.pp_size)] + dist.barrier(self.pp_group) + dist.all_gather_object(states_list, local_states, self.pp_group) + + # Only the master rank do the saving. + if self.coordinator.is_master(): + state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()} + for _states in states_list: + state_dict["state"].update(_states) + save_state_dict(state_dict, checkpoint, use_safetensors=False) + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + """ + Load optimizer from a file with given path. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the checkpoint file. + """ + + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info["param2id"][id(working_param)] + + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + + # Complete optimizer state_dict loaded from checkpoint, need to be processed later. + state_dict = load_state_dict(checkpoint) + + # Load param_groups. + updated_groups = [] + saved_groups = state_dict["param_groups"] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage. + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) + + # Load saved states to optimizer. First discard those states not belonging to current pipeline stage. + master_to_working_map = optimizer.get_master_to_working_map() + id_map = {} + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + id_map[param_id] = param + load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True) + + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + if param is None: + continue + device = param.device + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state( + state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True + ) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save lr scheduler to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + @staticmethod + def gather_from_sharded_optimizer_state( + state: OrderedDict, + param: torch.Tensor, + original_shape: torch.Size, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + use_zero: bool, + inplace: bool, + device: torch.device = torch.device("cpu"), + ) -> OrderedDict: + """ + With given parameter and its optimizer states, gather the complete optimizer state for saving. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. + param (torch.Tensor): The given parameter. It should be working_param when using Zero. + original_shape (torch.Size): The size of parameter before sharding. + dp_group (ProcessGroup): The process group of data parallel. + tp_group (ProcessGroup): The process group of tensor parallel. + use_zero (bool): Whether Zero is used. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu'). + + Returns: + OrderedDict: The complete optimizer state of given parameter. + """ + dp_size = dist.get_world_size(dp_group) + tp_size = dist.get_world_size(tp_group) + current_shape = param.shape + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + # First gather Zero shards. + if use_zero: + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] + dist.all_gather(gather_tensor, v, group=dp_group) + v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) + + # Then gather TP shards. + partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) + if partition_dim is not None: + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_tensor, v, group=tp_group) + v = torch.cat(gather_tensor, dim=partition_dim) + + state_[k] = v.detach().clone().to(device) + + return state_ + + def shard_from_complete_optimizer_state( + self, + state: OrderedDict, + current_shape: torch.Size, + original_shape: torch.Size, + device: torch.device, + inplace: bool, + ) -> OrderedDict: + """ + With complete optimizer states of a specific parameter loaded from checkpoint, + slice out the sharded optimizer states kept by current device. + + Args: + state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. + current_shape (torch.Size): The size of parameter after sharding. + original_shape (torch.Size): The size of parameter before sharding. + device (torch.device): The destination device of loaded optimizer states. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The sharded optimizer state of the given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + # Shard state along tensor parallel group. + partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + if partition_dim is not None: + slice_size = current_shape[partition_dim] + v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + + # Shard state along data parallel group when using Zero. + if self.use_zero: + padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.dp_size + v = v.split(slice_size, dim=0)[self.dp_rank] + + state_[k] = v.detach().clone().to(device) + + return state_ diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 89224787a91b9824f2261aece909f6a1cb094a17..da12c146f2c309315e3e00aa85f2b72fcee2e723 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -1,10 +1,12 @@ import json +import os +from collections import OrderedDict from pathlib import Path -from typing import Any, List, Union +from typing import Any, Dict, List, Union from .utils import is_dtensor_checkpoint -__all__ = ['CheckpointIndexFile'] +__all__ = ["CheckpointIndexFile"] class CheckpointIndexFile: @@ -18,10 +20,12 @@ class CheckpointIndexFile: >>> index.export('new_index.json') """ - def __init__(self) -> None: - self.root_path = None - self.metadata: dict = dict() - self.weight_map: dict = dict() + def __init__(self, root_path=None) -> None: + self.root_path = root_path + + # use ordered dict to preserve the tensor checkpoint order + self.metadata: Dict = OrderedDict() + self.weight_map: Dict = OrderedDict() @staticmethod def from_file(index_path: Union[str, Path]): @@ -46,7 +50,7 @@ class CheckpointIndexFile: json_path (str): path to the json file. """ # load the json file - with open(json_path, 'r') as f: + with open(json_path, "r") as f: index = json.load(f) # assign attributes if exists @@ -71,7 +75,7 @@ class CheckpointIndexFile: index["weight_map"] = self.weight_map # export the index file - with open(json_path, 'w') as f: + with open(json_path, "w") as f: json.dump(index, f, indent=4) def append_weight_map(self, param_name: str, shard_file: str): @@ -107,7 +111,7 @@ class CheckpointIndexFile: return True return False - def get_checkpoint_fileanames(self) -> List[str]: + def get_checkpoint_filenames(self) -> List[str]: """ Get the set of checkpoint filenames in the weight map. @@ -148,9 +152,31 @@ class CheckpointIndexFile: """ ckpt_path = self.weight_map[param_name] return ckpt_path - + def get_all_param_names(self): """ Get all the weight keys. """ return list(self.weight_map.keys()) + + def get_param_group_filename(self) -> Union[str, None]: + """ + Get the file name of param_group file if this is a checkpoint for optimizer. + Returns: + str: param_group file name + """ + filename = self.metadata.get("param_groups", None) + if filename: + return str(self.root_path.joinpath(filename)) + else: + return None + + def write_index_file(self, save_index_file): + """ + Write index file. + """ + save_index_file = os.path.join(self.root_path, save_index_file) + index = {"metadata": self.metadata, "weight_map": self.weight_map} + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2) + "\n" + f.write(content) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 37d22d08df40eaaa209ad32b097ce735a14378dc..06dab1fdb72a4d995ebd7d0900bb52afda0d2898 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,33 +1,51 @@ # coding=utf-8 +import os +import re +from collections import abc as container_abcs +from collections import defaultdict +from itertools import chain from pathlib import Path +from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple + import torch import torch.nn as nn -from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple -from colossalai.tensor.d_tensor.d_tensor import DTensor -import re +from packaging.version import Version +from torch.optim import Optimizer + +from colossalai.tensor.d_tensor import ( + is_customized_distributed_tensor, + is_distributed_tensor, + to_global, + to_global_for_customized_distributed_tensor, +) SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" +STATES_NAME = "pytorch_optim.bin" SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" +STATES_INDEX_NAME = "pytorch_optim.bin.index.json" +GROUP_FILE_NAME = "pytorch_optim_group.bin" # ====================================== # General helper functions # ====================================== + def calculate_tensor_size(tensor: torch.Tensor) -> float: """ Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size. If so, a new shard should be created. Args: - tenosr (torch.Tensor): the tensor to calculate size for. + tensor (torch.Tensor): the tensor to calculate size for. Returns: float: size of the tensor in MB. """ return tensor.numel() * tensor.element_size() / 1024 / 1024 + def is_safetensors_available() -> bool: """ Check whether safetensors is available. @@ -36,7 +54,6 @@ def is_safetensors_available() -> bool: bool: whether safetensors is available. """ try: - import safetensors return True except ImportError: return False @@ -52,7 +69,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool: Returns: bool: whether the checkpoint file is a dtensor checkpoint. """ - if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'): + if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"): return True else: return False @@ -68,136 +85,210 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: Returns: bool: whether the checkpoint file is a safetensor checkpoint. """ - if checkpoint_file_path.endswith('.safetensors'): + if checkpoint_file_path.endswith(".safetensors"): return True else: return False +def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]: + """ + Given the current shape of parameter and the shape of parameter before sharding, + return the dimension along which the parameter is sharded when using tensor parallel. + If tensor parallel is not used, return None. + + Args: + current_shape (torch.Size): The current shape of parameter after sharding. + original_shape (torch.Size): The shape of parameter before sharding. + tp_size (int): The size of tp group. + + Returns: + Optional[int]: The dimension along which parameter is partitioned. + """ + partition_dim = None + for dim, length in enumerate(original_shape): + if length > current_shape[dim]: + partition_dim = dim + break + if partition_dim is not None: + assert ( + original_shape[partition_dim] == tp_size * current_shape[partition_dim] + ), f"The parameter isn't evenly distributed among tensor parallel group: \ + shape before sharding {original_shape}, shape after sharding {current_shape}" + + return partition_dim + + # ====================================== -# Helper functions for saving shard file +# Helper classes and functions for saving shard file # ====================================== -def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME): - + + +class StateDictSharder: + def __init__(self, size_per_shard: int) -> None: + self.max_shard_size = size_per_shard + self.current_block = OrderedDict() + self.current_block_size = 0 + + def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + tensor_size = calculate_tensor_size(tensor) + ret_block = None + ret_block_size = 0 + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[name] = tensor + self.current_block_size += tensor_size + return ret_block, ret_block_size + + def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]: + # A state might contain more than one tensors. + # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' + state_size = 0 + isDTensor = False + for state_tensor in state.values(): + # When state_tensor is not of Tensor class, + # e.g., a SGD optimizer with momentum set to 0 can have None as state + # The calculation of tensor size should be skipped to avoid error. + if not isinstance(state_tensor, torch.Tensor): + continue + + # If the states are stored as DTensors, mark isDTensor as true. + if is_distributed_tensor(state_tensor): + isDTensor = True + state_size += calculate_tensor_size(state_tensor) + + ret_block = None + ret_block_size = 0 + + # directly return if state is stored as distributed tensor + if isDTensor: + return ret_block, ret_block_size + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[param_id] = state + self.current_block_size += state_size + return ret_block, ret_block_size + + +def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor: """ - Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a - given size. + Gather the complete parameter for saving if passed in param is distributed under tp setting. + + Args: + param (torch.Tensor): A model parameter, might be d_tensor. + keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. + + Returns: + torch.Tensor: the complete parameter """ - sharded_state_dicts = [] - current_block = {} - current_block_size = 0 - total_size = 0 + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + return to_global(param_) + elif is_customized_distributed_tensor(param_): + return to_global_for_customized_distributed_tensor(param_) + else: + return param_ + + +def save_state_dict_shards( + sharded_state_dict: Iterator[Tuple[OrderedDict, int]], + checkpoint: str, + index_file: "CheckpointIndexFile", + base_filename: str, + is_master: bool, + use_safetensors: bool = False, + use_pp_format: bool = False, +) -> int: + """ + Save sharded state dict only on master rank, this method can be used by both model and optimizer states. + Args: + sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size. + checkpoint (str): The path of checkpoint directory as string. + index_file (CheckpointIndexFile): The index file object to be updated. + base_filename (str): Decides the prefix of filenames of shards. + is_master (bool): Whether current rank is main process. + use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False. + use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False. - for key, weight in state_dict.items(): - if type(weight) != DTensor: - weight_size = calculate_tensor_size(weight) - - # If this weight is going to tip up over the maximal size, we split. - if current_block_size + weight_size > max_shard_size: - sharded_state_dicts.append(current_block) - current_block = {} - current_block_size = 0 - - current_block[key] = weight - current_block_size += weight_size - total_size += weight_size - - # Add the last block - sharded_state_dicts.append(current_block) - - # If we only have one shard, we return it - if len(sharded_state_dicts) == 1: - return {weights_name: sharded_state_dicts[0]}, None - - # Otherwise, let's build the index - weight_map = {} - shards = {} - - for idx, shard in enumerate(sharded_state_dicts): - shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") - shard_file = shard_file.replace( - ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" - ) - shards[shard_file] = shard + Returns: + int: the total size of shards + """ + + total_size = 0 + shard_filenames = [] + for idx, shard_pair in enumerate(sharded_state_dict): + shard, current_size = shard_pair + if not is_master: + del shard + continue + shard_file = get_shard_filename(base_filename, idx) + total_size = total_size + current_size for key in shard.keys(): - weight_map[key] = shard_file + index_file.append_weight_map(key, shard_file) + checkpoint_file_path = os.path.join(checkpoint, shard_file) - # Add the metadata - metadata = {"total_size": total_size} - index = {"metadata": metadata, "weight_map": weight_map} - return shards, index + # Only save on master rank. + save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors) + shard_filenames.append(shard_file) + del shard -def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): + # Clean folder, deleted unneeded files. + clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format) + + return total_size + + +def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: """ - load shard state dict into model + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. """ - if use_safetensors and not checkpoint_file.suffix == ".safetensors": - raise Exception("load the model using `safetensors`, but no file endwith .safetensors") - if use_safetensors: - from safetensors.torch import safe_open - from safetensors.torch import load_file as safe_load_file - with safe_open(checkpoint_file, framework="pt") as f: - metadata = f.metadata() - if metadata["format"] != "pt": - raise NotImplementedError( - f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." - ) - return safe_load_file(checkpoint_file) - else: - return torch.load(checkpoint_file) - -def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False): - r"""Copies parameters and buffers from :attr:`state_dict` into - this module and its descendants. + state_dict_sharder = StateDictSharder(max_shard_size) - Args: - state_dict (dict): a dict containing parameters and - persistent buffers. - """ - if not isinstance(state_dict, Mapping): - raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) + for key, weight in state_dict.items(): + if not is_distributed_tensor(weight): + block, block_size = state_dict_sharder.append_param(key, weight) - unexpected_keys: List[str] = [] - sub_missing_keys: List[str] = [] - error_msgs: List[str] = [] + if block != None: + yield block, block_size - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = OrderedDict(state_dict) - if metadata is not None: - state_dict._metadata = metadata + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - def load(module: nn.Module, state_dict, prefix=""): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) - # Parameters of module and children will start with prefix. We can exit early if there are none in this - # state_dict - if len([key for key in state_dict if key.startswith(prefix)]) > 0: - module._load_from_state_dict(*args) - for name, child in module._modules.items(): - if child is not None: - load(child, state_dict, prefix + name + ".") +def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: + """ + Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + """ - load(model, state_dict, "") - del load + # Only split state_dict['state']; state_dict['param_group'] is not considered in this function. + states = state_dict["state"] + state_dict_sharder = StateDictSharder(max_shard_size) + + for param_id, state in states.items(): + block, block_size = state_dict_sharder.append_optim_state(param_id, state) + if block != None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - # deal with missing key - if len(missing_keys) > 0: - deleted_keys = [] - for key in missing_keys: - if key not in sub_missing_keys: - deleted_keys.append(key) - for key in deleted_keys: - missing_keys.remove(key) - if strict: - if len(unexpected_keys) > 0: - error_msgs = 'Unexpected key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in unexpected_keys)) - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) - # ====================================== # Helper functions for saving state dict # ====================================== @@ -214,14 +305,99 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors """ if use_safetensors: assert is_safetensors_available(), "safetensors is not available." - assert checkpoint_file_path.endswith('.safetensors'), \ - "safetensors only supports .safetensors suffix for checkpoint file." + assert checkpoint_file_path.endswith( + ".safetensors" + ), "safetensors only supports .safetensors suffix for checkpoint file." from safetensors.torch import save_file as safe_save_file + safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) else: torch.save(state_dict, checkpoint_file_path) +def save_param_groups(state_dict: dict, group_file_path: str) -> None: + """ + Save information of param_groups to given file path. + + Args: + state_dict (dict): state dict. + group_file_path (str): path to the group file. + """ + param_groups = state_dict["param_groups"] + torch.save(param_groups, group_file_path) + + +def clean_folder( + checkpoint_path: str, + weights_name: str, + shard_filenames: List[str], + is_master: bool = True, + use_pp_format: bool = False, +): + """ + Clean the unneeded files in checkpoint directory after shards of state_dict have been saved. + + Args: + checkpoint_path (str): Path to the checkpoint directory. + weights_name (str): Decides the prefix of filenames of weight shards. + shard_filenames (List[str]): The list of saved shard filenames which should not be removed. + is_master (bool, optional): Whether current rank is main process. Defaults to True. + use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False. + + """ + if is_master: + for filename in os.listdir(checkpoint_path): + full_filename = os.path.join(checkpoint_path, filename) + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + if not use_pp_format: + reg = re.compile(r"(.*?)-\d{5}") + else: + # When this checkpoint is created by pipeline parallel process, the pattern is a little different. + reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}") + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in shard_filenames + and reg.fullmatch(filename_no_suffix) is not None + ): + os.remove(full_filename) + + +def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True): + """ + Save config.json/generation_config.json if model is a Huggingface pretrained model. + This method can only be called when a model is saved in a sharded way. + + Args: + model (nn.Module): The model whose config should be saved if it's a huggingface model. + checkpoint_path (str): Path to the checkpoint directory. + is_master (bool): Whether current rank is main process. + """ + try: + from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype + from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model + except ImportError: + return + if not isinstance(model, PreTrainedModel): + return + + model = unwrap_huggingface_model(model) + + # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" + dtype = get_parameter_dtype(model) + model.config.torch_dtype = str(dtype).split(".")[1] + + # Attach architecture to the config + model.config.architectures = [model.__class__.__name__] + + # Save the config + if is_master: + model.config.save_pretrained(checkpoint_path) + if model.can_generate(): + model.generation_config.save_pretrained(checkpoint_path) + + def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: """ Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains @@ -233,7 +409,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi size_per_shard (int): size per shard in MB. """ root_path = index_file.root_path - output_root_path = root_path.joinpath('dtensor') + output_root_path = root_path.joinpath("dtensor") # create directory output_root_path.mkdir(exist_ok=True) @@ -253,7 +429,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi # update the weight map # * means all shards - ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) + ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors) index_file.append_weight_map(name, ckpt_file_name_in_weight_map) @@ -268,15 +444,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str: str: checkpoint file suffix. """ if use_safetensors: - return '.safetensors' + return ".safetensors" else: - return '.bin' + return ".bin" -def generate_checkpoint_shard_file_name(index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None) -> str: +def generate_checkpoint_shard_file_name( + index: int, total_number: int, use_safetensors: bool, prefix: str = None +) -> str: """ Generate checkpoint shard file name. @@ -310,39 +485,190 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo str: dtensor file name. """ suffix = get_checkpoint_file_suffix(use_safetensors) - return f'{param_name}.{index}.{suffix}' + return f"{param_name}.{index}.{suffix}" -def save_state_dict_as_shard( - state_dict: dict, - checkpoint_path: str, - index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None, -) -> None: +# ======================================== +# Helper functions for loading state dict +# ======================================== + + +def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): + """ + load shard state dict into model """ - Save state dict as shard. + if use_safetensors and not checkpoint_file.suffix == ".safetensors": + raise Exception("load the model using `safetensors`, but no file endwith .safetensors") + if use_safetensors: + from safetensors.torch import load_file as safe_load_file + from safetensors.torch import safe_open + + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata["format"] != "pt": + raise NotImplementedError( + f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." + ) + return safe_load_file(checkpoint_file) + else: + return torch.load(checkpoint_file, map_location=torch.device("cpu")) + + +def load_state_dict_into_model( + model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True +): + r"""Copies parameters and buffers from :attr:`state_dict` into + this module and its descendants. Args: - state_dict (dict): state dict. - checkpoint_path (str): path to the checkpoint file. - index (int): index of the shard. - total_number (int): total number of shards. - prefix (str): prefix of the shard file name. - use_safetensors (bool): whether to use safetensors to save the checkpoint. + state_dict (dict): a dict containing parameters and + persistent buffers. """ - # generate the shard name - shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix) - shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute() + if not isinstance(state_dict, Mapping): + raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) + + unexpected_keys: List[str] = [] + sub_missing_keys: List[str] = [] + error_msgs: List[str] = [] - # save the shard - save_state_dict(state_dict, str(shard_file_path), use_safetensors) + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = OrderedDict(state_dict) + if metadata is not None: + state_dict._metadata = metadata + def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs) + # Parameters of module and children will start with prefix. We can exit early if there are none in this + # state_dict + if len([key for key in state_dict if key.startswith(prefix)]) > 0: + module._load_from_state_dict(*args) + if load_sub_module: + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".") -# ======================================== -# Helper functions for loading state dict -# ======================================== + load(model, state_dict, "", load_sub_module) + del load + + missing_keys = missing_keys.append(sub_missing_keys) + + if strict: + if len(unexpected_keys) > 0: + error_msgs = "Unexpected key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in unexpected_keys) + ) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) + ) + + +def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict: + """ + Load information of param_groups into an initialized optimizer. + """ + + # Load list of param_groups from given file path. + # The params in saved_groups are in the form of integer indices. + saved_groups = torch.load(param_group_path, map_location=torch.device("cpu")) + if not isinstance(saved_groups, List): + raise ValueError(f"The param_groups saved at {param_group_path} is not of List type") + + # The params in param_groups are in the form of pytorch tensors. + # For more details, please view source code of Optimizer class in pytorch. + param_groups = optimizer.param_groups + + # Check the compatibility of saved_groups and param_groups. + if len(param_groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of original parameter groups") + param_lens = (len(g["params"]) for g in param_groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" + ) + + # Creating mapping from id to parameters. + id_map = { + old_id: p + for old_id, p in zip( + chain.from_iterable((g["params"] for g in saved_groups)), + chain.from_iterable((g["params"] for g in param_groups)), + ) + } + + # Update parameter groups, setting their 'params' value. + def update_group(group, new_group): + new_group["params"] = group["params"] + return new_group + + updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)] + + optimizer.__dict__.update({"param_groups": updated_groups}) + return id_map + + +def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False): + r"""Copies states from `state_dict` into an Optimizer object. + + Args: + optimizer(Optimizer): An initialized Optimizer object to be loaded + state_dict(dict): A mapping from tensor index (an integer) + to its states to be loaded (a mapping from state name to a tensor). + id_map(dict): A mapping from tensor index (an integer) + to its corresponding parameter (a tensor) whose states will be updated. + strict(bool, optional): If set to True, only load the parameters with its id in id_map. Defaults to False. + """ + + # Ensure that the keys of state_dict are integers. + state_dict = {int(k): v for k, v in state_dict.items()} + + def cast(param, value, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 + if key != "step": + if param.is_floating_point(): + value = value.to(param.dtype) + value = value.to(param.device) + return value + elif isinstance(value, dict): + return {k: cast(param, v, key=k) for k, v in value.items()} + elif isinstance(value, container_abcs.Iterable): + return type(value)(cast(param, v) for v in value) + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + new_states = defaultdict(dict) + for k, v in state_dict.items(): + if k in id_map: + param = id_map[k] + new_states[param] = cast(param, v) + elif not strict: + new_states[k] = v + + optimizer.state.update(new_states) + + +def sharded_optimizer_loading_epilogue(optimizer: Optimizer): + r"""Do the cleaning up work after state_dict has been loaded into optimizer + + Args: + optimizer(Optimizer): An optimizer object whose state has just been loaded. + """ + + # Do the cleaning up as in src code of Pytorch. + if Version(torch.__version__) >= Version("2.0.0"): + optimizer._patch_step_function() # To support multiprocessing pickle/unpickle + else: + optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle. + optimizer.defaults.setdefault("differentiable", False) def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: @@ -365,18 +691,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: return False, None elif checkpoint_path.is_dir(): # check if there is only one a file ending with .index.json in this directory - index_files = list(checkpoint_path.glob('*.index.*json')) + index_files = list(checkpoint_path.glob("*.index.*json")) # if we found a .index.json file, make sure there is only one if len(index_files) > 0: - assert len( - index_files - ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}' + assert ( + len(index_files) == 1 + ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}" if len(index_files) == 1: return True, index_files[0] else: return False, None + else: + raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.") def load_state_dict(checkpoint_file_path: Path): @@ -390,14 +718,17 @@ def load_state_dict(checkpoint_file_path: Path): dict: state dict. """ - assert not is_dtensor_checkpoint(checkpoint_file_path), \ - f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.' + assert not is_dtensor_checkpoint( + checkpoint_file_path + ), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline." if is_safetensor_checkpoint(checkpoint_file_path): - assert is_safetensors_available(), \ - f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.' + assert ( + is_safetensors_available() + ), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors." # load with safetensors from safetensors import safe_open + state_dict = {} with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f: for k in f.keys(): @@ -406,14 +737,51 @@ def load_state_dict(checkpoint_file_path: Path): else: # load with torch - return torch.load(checkpoint_file_path) - + return torch.load(checkpoint_file_path, map_location=torch.device("cpu")) -def add_variant(weights_name: str, variant: Optional[str] = None) -> str: - if variant is not None and len(variant) > 0: +def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str: + if prefix is not None and len(prefix) > 0: splits = weights_name.split(".") - splits = splits[:-1] + [variant] + splits[-1:] + splits = splits[:-1] + [prefix] + splits[-1:] weights_name = ".".join(splits) return weights_name + + +def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False): + """ + generate base model weight filenames + """ + weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + weights_name = add_prefix(weights_name, prefix) + + save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME + save_index_file = add_prefix(save_index_file, prefix) + + return weights_name, save_index_file + + +def get_optimizer_base_filenames(prefix: str = None): + """ + generate base optimizer state filenames + """ + states_name = STATES_NAME + states_name = add_prefix(states_name, prefix) + + save_index_file = STATES_INDEX_NAME + save_index_file = add_prefix(save_index_file, prefix) + + param_group_file = GROUP_FILE_NAME + param_group_file = add_prefix(param_group_file, prefix) + + return states_name, save_index_file, param_group_file + + +def get_shard_filename(weights_name: str, idx: int): + """ + get shard file name + """ + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") + shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors") + return shard_file diff --git a/colossalai/cli/__init__.py b/colossalai/cli/__init__.py index 658e35e4c72e77f7b3161bf86b0c9600a80562e5..c7cb19c193082d54b57c6fdd40815f29636f6b25 100644 --- a/colossalai/cli/__init__.py +++ b/colossalai/cli/__init__.py @@ -1,3 +1,3 @@ from .cli import cli -__all__ = ['cli'] +__all__ = ["cli"] diff --git a/colossalai/cli/benchmark/__init__.py b/colossalai/cli/benchmark/__init__.py deleted file mode 100644 index 618ff8c61dd41ec83d567e7cd9103f2aa9921846..0000000000000000000000000000000000000000 --- a/colossalai/cli/benchmark/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -import click - -from colossalai.context import Config - -from .benchmark import run_benchmark -from .utils import * - -__all__ = ['benchmark'] - - -@click.command() -@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.") -@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.") -@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.") -@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.") -@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.") -@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.") -@click.option("-l", "--layers", type=int, default=2) -@click.option("-m", - "--model", - type=click.Choice(['mlp'], case_sensitive=False), - default='mlp', - help="Select the model to benchmark, currently only supports MLP") -def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int, - layers: int, model: str): - args_dict = locals() - args = Config(args_dict) - run_benchmark(args) diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py deleted file mode 100644 index 97a9f45722dd6a4c1e316d1e91f27439797ae17a..0000000000000000000000000000000000000000 --- a/colossalai/cli/benchmark/benchmark.py +++ /dev/null @@ -1,105 +0,0 @@ -from functools import partial -from typing import Dict, List - -import click -import torch.multiprocessing as mp - -import colossalai -from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model -from colossalai.context import Config -from colossalai.context.random import reset_seeds -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.testing import free_port -from colossalai.utils import MultiTimer - -from .models import MLP - - -def run_benchmark(args: Config) -> None: - """ - Run benchmarking with torch.multiprocessing. - """ - - # sanity checks - if args.gpus is None: - click.echo("Error: --num_gpus is not given") - exit() - if args.gpus <= 1: - click.echo("Warning: tensor parallel will be activated with at least 2 devices.") - - click.echo("=== Benchmarking Parameters ===") - for k, v in args.items(): - click.echo(f'{k}: {v}') - click.echo('') - - config_list = find_all_configs(args.gpus) - - avail_ports = [free_port() for _ in range(len(config_list))] - run_func = partial(run_dist_profiling, - world_size=args.gpus, - port_list=avail_ports, - config_list=config_list, - hyperparams=args) - mp.spawn(run_func, nprocs=args.gpus) - - -def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict], - hyperparams: Config) -> None: - """ - A function executed for profiling, this function should be spawn by torch.multiprocessing. - - Args: - rank (int): rank of the process - world_size (int): the number of processes - port_list (List[int]): a list of free ports for initializing distributed networks - config_list (List[Dict]): a list of configuration - hyperparams (Config): the hyperparameters given by the user - - """ - - # disable logging for clean output - disable_existing_loggers() - logger = get_dist_logger() - logger.set_level('WARNING') - - for config, port in zip(config_list, port_list): - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - timer = MultiTimer() - - # 1D parallel should be skipped if in_features or out_features is not able to be divided exactly by 1D parallel size. - if config.parallel.tensor.mode == '1d' and hyperparams.dimension % config.parallel.tensor.size != 0: - click.echo( - "1D parallel will be skipped because in_features or out_features is not able to be divided exactly by 1D parallel size." - ) - continue - - if hyperparams.model == 'mlp': - model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers) - else: - if gpc.get_global_rank() == 0: - click.echo("Error: Invalid argument for --model") - exit() - - data_func = partial(get_batch_data, - dim=hyperparams.dimension, - batch_size=hyperparams.batch_size, - seq_length=hyperparams.seq_len, - mode=config.parallel.tensor.mode) - - fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model, - warmup_steps=hyperparams.warmup_steps, - profile_steps=hyperparams.profile_steps, - data_func=data_func, - timer=timer) - - gpc.destroy() - reset_seeds() - - if gpc.get_global_rank() == 0: - config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()]) - click.echo(f"=== {config_str} ===") - click.echo(f"Average forward time: {fwd_time}") - click.echo(f"Average backward time: {bwd_time}") - click.echo(f"Max allocated GPU memory: {max_allocated}") - click.echo(f"Max cached GPU memory: {max_cached}\n") diff --git a/colossalai/cli/benchmark/models.py b/colossalai/cli/benchmark/models.py deleted file mode 100644 index f8fd1c41a059806891713340f2ea4931ec9726f2..0000000000000000000000000000000000000000 --- a/colossalai/cli/benchmark/models.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch - -import colossalai.nn as col_nn - - -class MLP(torch.nn.Module): - - def __init__(self, dim: int, layers: int): - super().__init__() - self.layers = torch.nn.ModuleList() - - for _ in range(layers): - self.layers.append(col_nn.Linear(dim, dim)) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py deleted file mode 100644 index 825b795f21f680bcb2e2ea5eee4b328c2e1777db..0000000000000000000000000000000000000000 --- a/colossalai/cli/benchmark/utils.py +++ /dev/null @@ -1,158 +0,0 @@ -import math -import time -import torch - -from colossalai.utils import MultiTimer -from colossalai.context import ParallelMode, Config -from typing import List, Dict, Tuple, Callable - - -def get_time_stamp() -> int: - """ - Return the time stamp for profiling. - - Returns: - time_stamp (int): the time given by time.time() - """ - - torch.cuda.synchronize() - time_stamp = time.time() - return time_stamp - - -def get_memory_states() -> Tuple[float]: - """ - Return the memory statistics. - - Returns: - max_allocated (float): the allocated CUDA memory - max_cached (float): the cached CUDA memory - """ - - max_allocated = torch.cuda.max_memory_allocated() / (1024**3) - max_cached = torch.cuda.max_memory_reserved() / (1024**3) - torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() - return max_allocated, max_cached - - -def find_all_configs(device_cnt: int) -> List[Dict]: - """ - Find all possible configurations for tensor parallelism - - Args: - device_cnt (int): the number of devices - - Returns: - config_list (List[Dict]): a list of configurations - """ - - def _is_square(num): - # 2D parallel should be implemented with at least 2 devices. - if num <= 1: - return False - return math.floor(math.sqrt(num))**2 == num - - def _is_cube(num): - # 3D parallel should be implemented with at least 2 devices. - if num <= 1: - return False - return math.floor(num**(1. / 3.))**3 == num - - config_list = [] - - # add non-parallel config - config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None))) - config_list.append(config) - - # add 1D config - config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d'))) - config_list.append(config) - - # add 2D config only if device_cnt is a square - if _is_square(device_cnt): - config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d'))) - config_list.append(config) - - # check for 2.5D - # iterate over depth - for depth in range(1, device_cnt): - if device_cnt % depth == 0 and _is_square(device_cnt // depth): - config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth))) - config_list.append(config) - - # check for 3D if device_cnt is a cube - if _is_cube(device_cnt): - config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d'))) - config_list.append(config) - - config_list = [Config(cfg) for cfg in config_list] - return config_list - - -def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable, - timer: MultiTimer) -> Tuple[float]: - """ - Profile the forward and backward of a model - - Args: - model (torch.nn.Module): a PyTorch model - warmup_steps (int): the number of steps for warmup - profile_steps (int): the number of steps for profiling - data_func (Callable): a function to generate random data - timer (colossalai.utils.Multitimer): a timer instance for time recording - - Returns: - fwd_time (float): the average forward time taken by forward pass in second - bwd_time (float): the average backward time taken by forward pass in second - max_allocated (float): the maximum GPU memory allocated in GB - max_cached (float): the maximum GPU memory cached in GB - """ - - def _run_step(data): - timer.start('forward') - out = model(data) - timer.stop('forward', keep_in_history=True) - timer.start('backward') - out.mean().backward() - timer.stop('backward', keep_in_history=True) - - data_list = [data_func() for _ in range(warmup_steps)] - for data in data_list: - _run_step(data) - timer.reset('forward') - timer.reset('backward') - - for _ in range(profile_steps): - data = data_func() - _run_step(data) - - max_allocated, max_cached = get_memory_states() - fwd_time = timer.get_timer('forward').get_history_mean() - bwd_time = timer.get_timer('backward').get_history_mean() - return fwd_time, bwd_time, max_allocated, max_cached - - -def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor: - """ - Return a random data of shape (batch_size, seq_length, dim) for profiling. - - Args: - dim (int): hidden size - batch_size (int): the number of data samples - seq_length (int): the number of tokens - mode (ParallelMode): Colossal-AI ParallelMode enum - - Returns: - data (torch.Tensor): random data - """ - - if mode in ['2d', '2.5d']: - batch_size = batch_size // 2 - dim = dim // 2 - elif mode == '3d': - batch_size = batch_size // 4 - dim = dim // 2 - - data = torch.rand(batch_size, seq_length, dim).cuda() - return data diff --git a/colossalai/cli/check/__init__.py b/colossalai/cli/check/__init__.py index a86b32bb6a181a7c75bfa6682c2f7514169d9a7f..7c26ab6ade6cdc7238608232ce07c07fd977888f 100644 --- a/colossalai/cli/check/__init__.py +++ b/colossalai/cli/check/__init__.py @@ -1,11 +1,12 @@ import click + from .check_installation import check_installation -__all__ = ['check'] +__all__ = ["check"] @click.command(help="Check if Colossal-AI is correct based on the given option") -@click.option('-i', '--installation', is_flag=True, help="Check if Colossal-AI is built correctly") +@click.option("-i", "--installation", is_flag=True, help="Check if Colossal-AI is built correctly") def check(installation): if installation: check_installation() diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py index cb3dbbc09301012aa662264f241e1fce89470d39..772c513ffa06b1ffa148f1e95f956964e14a8716 100644 --- a/colossalai/cli/check/check_installation.py +++ b/colossalai/cli/check/check_installation.py @@ -9,7 +9,7 @@ import colossalai def to_click_output(val): # installation check output to understandable symbols for readability - VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'} + VAL_TO_SYMBOL = {True: "\u2713", False: "x", None: "N/A"} if val in VAL_TO_SYMBOL: return VAL_TO_SYMBOL[val] @@ -31,7 +31,7 @@ def check_installation(): found_aot_cuda_ext = _check_aot_built_cuda_extension_installed() cuda_version = _check_cuda_version() torch_version, torch_cuda_version = _check_torch_version() - colossalai_verison, prebuilt_torch_version_required, prebuilt_cuda_version_required = _parse_colossalai_version() + colossalai_version, prebuilt_torch_version_required, prebuilt_cuda_version_required = _parse_colossalai_version() # if cuda_version is None, that means either # CUDA_HOME is not found, thus cannot compare the version compatibility @@ -55,9 +55,9 @@ def check_installation(): else: torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required]) - click.echo(f'#### Installation Report ####') - click.echo(f'\n------------ Environment ------------') - click.echo(f"Colossal-AI version: {to_click_output(colossalai_verison)}") + click.echo(f"#### Installation Report ####") + click.echo(f"\n------------ Environment ------------") + click.echo(f"Colossal-AI version: {to_click_output(colossalai_version)}") click.echo(f"PyTorch version: {to_click_output(torch_version)}") click.echo(f"System CUDA version: {to_click_output(cuda_version)}") click.echo(f"CUDA version required by PyTorch: {to_click_output(torch_cuda_version)}") @@ -69,7 +69,7 @@ def check_installation(): f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version." ) - click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------') + click.echo(f"\n------------ CUDA Extensions AOT Compilation ------------") click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}") click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}") click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}") @@ -81,7 +81,7 @@ def check_installation(): click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime") click.echo(f"\n------------ Compatibility ------------") - click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}') + click.echo(f"PyTorch version match: {to_click_output(torch_compatibility)}") click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}") click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}") click.echo(f"") @@ -106,12 +106,12 @@ def _is_compatible(versions): return False # split version into [major, minor, patch] - versions = [version.split('.') for version in versions] + versions = [version.split(".") for version in versions] for version in versions: if len(version) == 2: # x means unknown - version.append('x') + version.append("x") for idx, version_values in enumerate(zip(*versions)): equal = len(set(version_values)) == 1 @@ -137,15 +137,15 @@ def _parse_colossalai_version(): # 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions) # 2. X.X.X (when colossalai is not installed with CUDA extensions) # where X represents an integer. - colossalai_verison = colossalai.__version__.split('+')[0] + colossalai_version = colossalai.__version__.split("+")[0] try: - torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0] - cuda_version_for_aot_build = colossalai.__version__.split('cu')[1] + torch_version_for_aot_build = colossalai.__version__.split("torch")[1].split("cu")[0] + cuda_version_for_aot_build = colossalai.__version__.split("cu")[1] except: torch_version_for_aot_build = None cuda_version_for_aot_build = None - return colossalai_verison, torch_version_for_aot_build, cuda_version_for_aot_build + return colossalai_version, torch_version_for_aot_build, cuda_version_for_aot_build def _check_aot_built_cuda_extension_installed(): @@ -156,7 +156,6 @@ def _check_aot_built_cuda_extension_installed(): JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` during runtime. """ try: - import colossalai._C.fused_optim found_aot_cuda_ext = True except ImportError: found_aot_cuda_ext = False @@ -175,14 +174,14 @@ def _check_torch_version(): # torch version can be of two formats # - 1.13.1+cu113 # - 1.13.1.devxxx - torch_version = torch.__version__.split('+')[0] - torch_version = '.'.join(torch_version.split('.')[:3]) + torch_version = torch.__version__.split("+")[0] + torch_version = ".".join(torch_version.split(".")[:3]) # get cuda version in pytorch build try: torch_cuda_major = torch.version.cuda.split(".")[0] torch_cuda_minor = torch.version.cuda.split(".")[1] - torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}' + torch_cuda_version = f"{torch_cuda_major}.{torch_cuda_minor}" except: torch_cuda_version = None @@ -208,7 +207,7 @@ def _check_cuda_version(): release = output[release_idx].split(".") bare_metal_major = release[0] bare_metal_minor = release[1][0] - cuda_version = f'{bare_metal_major}.{bare_metal_minor}' + cuda_version = f"{bare_metal_major}.{bare_metal_minor}" except: cuda_version = None return cuda_version diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py index a94e1150e49fc00a210c20b32a9bfc85eda66aa6..0d94fe59f8aef275c4bae22a6ea29fa8ec0c1978 100644 --- a/colossalai/cli/cli.py +++ b/colossalai/cli/cli.py @@ -1,12 +1,10 @@ import click -from .benchmark import benchmark from .check import check from .launcher import run -class Arguments(): - +class Arguments: def __init__(self, arg_dict): for k, v in arg_dict.items(): self.__dict__[k] = v @@ -19,7 +17,6 @@ def cli(): cli.add_command(run) cli.add_command(check) -cli.add_command(benchmark) -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/colossalai/cli/launcher/__init__.py b/colossalai/cli/launcher/__init__.py index 8d9ec147d401a2e5d055852e661e189985b6db6e..0f9ead6495dbe242873d0bfded0a020518bf2d8c 100644 --- a/colossalai/cli/launcher/__init__.py +++ b/colossalai/cli/launcher/__init__.py @@ -5,56 +5,81 @@ from colossalai.context import Config from .run import launch_multi_processes -@click.command(help="Launch distributed training on a single node or multiple nodes", - context_settings=dict(ignore_unknown_options=True)) -@click.option("-H", - "-host", - "--host", - type=str, - default=None, - help="the list of hostnames to launch in the format ,") +@click.command( + help="Launch distributed training on a single node or multiple nodes", + context_settings=dict(ignore_unknown_options=True), +) +@click.option( + "-H", + "-host", + "--host", + type=str, + default=None, + help="the list of hostnames to launch in the format ,", +) @click.option( "--hostfile", type=str, default=None, - help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname") -@click.option("--include", - type=str, - default=None, - help="Specify computing devices to use during execution. String format is ,," - " only effective when used with --hostfile.") + help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname", +) +@click.option( + "--include", + type=str, + default=None, + help="Specify computing devices to use during execution. String format is ,," + " only effective when used with --hostfile.", +) @click.option( "--exclude", type=str, default=None, - help= - "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --includ," - " only effective when used with --hostfile.") -@click.option("--num_nodes", - type=int, - default=-1, - help="Total number of worker nodes to use, only effective when used with --hostfile.") + help="Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include," + " only effective when used with --hostfile.", +) +@click.option( + "--num_nodes", + type=int, + default=-1, + help="Total number of worker nodes to use, only effective when used with --hostfile.", +) @click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.") -@click.option("--master_port", - type=int, - default=29500, - help="(optional) Port used by PyTorch distributed for communication during distributed training.") -@click.option("--master_addr", - type=str, - default="127.0.0.1", - help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.") +@click.option( + "--master_port", + type=int, + default=29500, + help="(optional) Port used by PyTorch distributed for communication during distributed training.", +) +@click.option( + "--master_addr", + type=str, + default="127.0.0.1", + help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.", +) @click.option( "--extra_launch_args", type=str, default=None, - help= - "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. " - "This will be converted to --arg1=1 --arg2=2 during execution") + help="Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. " + "This will be converted to --arg1=1 --arg2=2 during execution", +) @click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection") @click.argument("user_script", type=str) -@click.argument('user_args', nargs=-1) -def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str, - master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None: +@click.argument("user_args", nargs=-1) +def run( + host: str, + hostfile: str, + num_nodes: int, + nproc_per_node: int, + include: str, + exclude: str, + master_addr: str, + master_port: int, + extra_launch_args: str, + ssh_port: int, + user_script: str, + user_args: str, +) -> None: """ To launch multiple processes on a single node or multiple nodes via command line. @@ -77,8 +102,8 @@ def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: # run with hostfile excluding the hosts selected colossalai run --hostfile --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py """ - if not user_script.endswith('.py'): - click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help') + if not user_script.endswith(".py"): + click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help") exit() args_dict = locals() diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py index 065cbc37101f9705319d96120779edcbbbf6dde9..684f64f59d28b945d930602e988a2847670eda6b 100644 --- a/colossalai/cli/launcher/hostinfo.py +++ b/colossalai/cli/launcher/hostinfo.py @@ -1,5 +1,4 @@ import socket -from typing import List class HostInfo: @@ -34,11 +33,11 @@ class HostInfo: """ if port is None: - port = 22 # no port specified, lets just use the ssh port + port = 22 # no port specified, lets just use the ssh port # socket.getfqdn("127.0.0.1") does not return localhost # on some users' machines - # thus, we directly return True if hostname is locahost, 127.0.0.1 or 0.0.0.0 + # thus, we directly return True if hostname is localhost, 127.0.0.1 or 0.0.0.0 if hostname in ("localhost", "127.0.0.1", "0.0.0.0"): return True @@ -46,14 +45,11 @@ class HostInfo: localhost = socket.gethostname() localaddrs = socket.getaddrinfo(localhost, port) targetaddrs = socket.getaddrinfo(hostname, port) - for (family, socktype, proto, canonname, sockaddr) in localaddrs: - for (rfamily, rsocktype, rproto, rcanonname, rsockaddr) in targetaddrs: - if rsockaddr[0] == sockaddr[0]: - return True - return False + + return localaddrs == targetaddrs def __str__(self): - return f'hostname: {self.hostname}, port: {self.port}' + return f"hostname: {self.hostname}, port: {self.port}" def __repr__(self): return self.__str__() diff --git a/colossalai/cli/launcher/multinode_runner.py b/colossalai/cli/launcher/multinode_runner.py index a51e1e371f13df492c16ac5b554d02ba6491d65b..99c4db40684480e19933b6a4da6fca90cee70aff 100644 --- a/colossalai/cli/launcher/multinode_runner.py +++ b/colossalai/cli/launcher/multinode_runner.py @@ -7,8 +7,13 @@ import fabric from .hostinfo import HostInfo, HostInfoList -def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection, - send_conn: mp_connection.Connection, env: dict) -> None: +def run_on_host( + hostinfo: HostInfo, + workdir: str, + recv_conn: mp_connection.Connection, + send_conn: mp_connection.Connection, + env: dict, +) -> None: """ Use fabric connection to execute command on local or remote hosts. @@ -22,14 +27,14 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port) finish = False - env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()]) + env_msg = " ".join([f'{k}="{v}"' for k, v in env.items()]) # keep listening until exit while not finish: # receive cmd cmds = recv_conn.recv() - if cmds == 'exit': + if cmds == "exit": # exit from the loop finish = True break @@ -46,12 +51,12 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne else: # execute on the remote machine fab_conn.run(cmds, hide=False) - send_conn.send('success') + send_conn.send("success") except Exception as e: click.echo( f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}" ) - send_conn.send('failure') + send_conn.send("failure") # shutdown send_conn.send("finish") @@ -96,8 +101,7 @@ class MultiNodeRunner: cmd (str): the command to execute """ - assert hostinfo.hostname in self.master_send_conns, \ - f'{hostinfo} is not found in the current connections' + assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections" conn = self.master_send_conns[hostinfo.hostname] conn.send(cmd) @@ -107,14 +111,14 @@ class MultiNodeRunner: """ for hostname, conn in self.master_send_conns.items(): - conn.send('exit') + conn.send("exit") def recv_from_all(self) -> dict: """ Receive messages from all hosts Returns: - msg_from_node (dict): a dictionry which contains messages from each node + msg_from_node (dict): a dictionary which contains messages from each node """ msg_from_node = dict() diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index 6411b4302e95d25efc5a55da1b523b76ee6ee1e3..88f70f02ec27fb81a8fe5ab21973810638500ee4 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -12,7 +12,7 @@ from .hostinfo import HostInfo, HostInfoList from .multinode_runner import MultiNodeRunner # Constants that define our syntax -NODE_SEP = ',' +NODE_SEP = "," def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: @@ -34,12 +34,12 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}") exit() - with open(hostfile_path, 'r') as fd: + with open(hostfile_path, "r") as fd: device_pool = HostInfoList() for line in fd.readlines(): line = line.strip() - if line == '': + if line == "": # skip empty lines continue @@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList: - '''Parse an inclusion or exclusion string and filter a hostfile dictionary. + """Parse an inclusion or exclusion string and filter a hostfile dictionary. Examples: include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1. @@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str Returns: filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion - ''' + """ # Ensure include/exclude are mutually exclusive if include_str and exclude_str: @@ -136,16 +136,16 @@ def get_launch_command( for k, v in arg_dict.items(): if v: - ret.append(f'--{k}={v}') + ret.append(f"--{k}={v}") else: - ret.append(f'--{k}') + ret.append(f"--{k}") return ret if extra_launch_args: extra_launch_args_dict = dict() - for arg in extra_launch_args.split(','): - if '=' in arg: - k, v = arg.split('=') + for arg in extra_launch_args.split(","): + if "=" in arg: + k, v = arg.split("=") extra_launch_args_dict[k] = v else: extra_launch_args_dict[arg] = None @@ -154,19 +154,23 @@ def get_launch_command( extra_launch_args = dict() torch_version = version.parse(torch.__version__) - assert torch_version.major == 1 + assert torch_version.major >= 1 - if torch_version.minor < 9: + if torch_version.major == 1 and torch_version.minor < 9: + # torch distributed launch cmd with torch < 1.9 cmd = [ - sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}", - f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}", - f"--node_rank={node_rank}" + sys.executable, + "-m", + "torch.distributed.launch", + f"--nproc_per_node={nproc_per_node}", + f"--master_addr={master_addr}", + f"--master_port={master_port}", + f"--nnodes={num_nodes}", + f"--node_rank={node_rank}", ] else: # extra launch args for torch distributed launcher with torch >= 1.9 - default_torchrun_rdzv_args = dict(rdzv_backend="c10d", - rdzv_endpoint=f"{master_addr}:{master_port}", - rdzv_id="colossalai-default-job") + default_torchrun_rdzv_args = dict(master_addr=master_addr, master_port=master_port) # update rdzv arguments for key in default_torchrun_rdzv_args.keys(): @@ -174,19 +178,28 @@ def get_launch_command( value = extra_launch_args.pop(key) default_torchrun_rdzv_args[key] = value - if torch_version.minor < 10: + if torch_version.major == 1 and torch_version.minor == 9: + # torch distributed launch cmd with torch == 1.9 cmd = [ - sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}", - f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + sys.executable, + "-m", + "torch.distributed.run", + f"--nproc_per_node={nproc_per_node}", + f"--nnodes={num_nodes}", + f"--node_rank={node_rank}", ] else: + # torch distributed launch cmd with torch > 1.9 cmd = [ - "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + "torchrun", + f"--nproc_per_node={nproc_per_node}", + f"--nnodes={num_nodes}", + f"--node_rank={node_rank}", ] cmd += _arg_dict_to_list(default_torchrun_rdzv_args) cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args - cmd = ' '.join(cmd) + cmd = " ".join(cmd) return cmd @@ -250,33 +263,39 @@ def launch_multi_processes(args: Config) -> None: # run on local node if not hosts or hostfile is given # add local node to host info list active_device_pool = HostInfoList() - localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port) + localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port) active_device_pool.append(localhost_info) # launch distributed processes runner = MultiNodeRunner() - curr_path = os.path.abspath('.') + curr_path = os.path.abspath(".") # collect current path env env = dict() for k, v in os.environ.items(): # do not support multi-line env var - if v and '\n' not in v: + if v and "\n" not in v: env[k] = v # establish remote connection runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env) + # overwrite master addr when num_nodes > 1 and not specified + if len(active_device_pool) > 1 and args.master_addr == "127.0.0.1": + args.master_addr = active_device_pool.hostinfo_list[0].hostname + # execute distributed launching command for node_id, hostinfo in enumerate(active_device_pool): - cmd = get_launch_command(master_addr=args.master_addr, - master_port=args.master_port, - nproc_per_node=args.nproc_per_node, - user_script=args.user_script, - user_args=args.user_args, - node_rank=node_id, - num_nodes=len(active_device_pool), - extra_launch_args=args.extra_launch_args) + cmd = get_launch_command( + master_addr=args.master_addr, + master_port=args.master_port, + nproc_per_node=args.nproc_per_node, + user_script=args.user_script, + user_args=args.user_args, + node_rank=node_id, + num_nodes=len(active_device_pool), + extra_launch_args=args.extra_launch_args, + ) runner.send(hostinfo=hostinfo, cmd=cmd) # start training @@ -298,7 +317,7 @@ def launch_multi_processes(args: Config) -> None: # receive the stop status msg_from_node = runner.recv_from_all() - # printe node status + # print node status click.echo("\n====== Stopping All Nodes =====") for hostname, msg in msg_from_node.items(): click.echo(f"{hostname}: {msg}") diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py index 2fbdfd3cc9996b1044720ef0c1669f5f67fbe8b3..b8176feb647b87c4c7da366a74f416d246d96fd8 100644 --- a/colossalai/cluster/__init__.py +++ b/colossalai/cluster/__init__.py @@ -1,5 +1,6 @@ from .device_mesh_manager import DeviceMeshManager from .dist_coordinator import DistCoordinator from .process_group_manager import ProcessGroupManager +from .process_group_mesh import ProcessGroupMesh -__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager'] +__all__ = ["DistCoordinator", "ProcessGroupManager", "DeviceMeshManager", "ProcessGroupMesh"] diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py index 8754baa19792adf6c5ec79c115d36fac9a3f3c5d..e35aca5f4d7e827512d7b618f90f6dd0b86afd61 100644 --- a/colossalai/cluster/device_mesh_manager.py +++ b/colossalai/cluster/device_mesh_manager.py @@ -10,13 +10,14 @@ from colossalai.device.device_mesh import DeviceMesh @dataclass class DeviceMeshInfo: - ''' + """ This class is used to store the information used to initialize the device mesh. Args: physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7]. mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2]. - ''' + """ + physical_ids: List[int] mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None @@ -24,16 +25,18 @@ class DeviceMeshInfo: if self.mesh_shape is not None: world_size = len(self.physical_ids) mesh_shape_numel = torch.Size(self.mesh_shape).numel() - assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}' + assert ( + world_size == mesh_shape_numel + ), f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}" def initialize_device_mesh(device_mesh_info: DeviceMeshInfo): - ''' + """ This method is used to initialize the device mesh. Args: device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh. - ''' + """ # parse the device mesh info physical_devices = device_mesh_info.physical_ids physical_mesh = torch.tensor(physical_devices) @@ -67,13 +70,13 @@ class DeviceMeshManager: Args: name (str): name of the device mesh device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh - """ + """ if name not in self.device_mesh_store: device_mesh = initialize_device_mesh(device_mesh_info) self.device_mesh_store[name] = device_mesh return device_mesh else: - raise ValueError(f'Device mesh {name} already exists.') + raise ValueError(f"Device mesh {name} already exists.") def get(self, name: str) -> DeviceMesh: """ @@ -88,7 +91,7 @@ class DeviceMeshManager: if name in self.device_mesh_store: return self.device_mesh_store[name] else: - raise ValueError(f'Device mesh {name} does not exist.') + raise ValueError(f"Device mesh {name} does not exist.") def destroy(self, name: str) -> None: """ @@ -103,7 +106,7 @@ class DeviceMeshManager: dist.destroy_process_group(pg) del self.device_mesh_store[name] else: - raise ValueError(f'Device mesh {name} does not exist.') + raise ValueError(f"Device mesh {name} does not exist.") def destroy_all(self): """ diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py index 99dde810e11251e16235573f6dd88e68b74a64b3..98191747e5b368b68206a41645e9b7d8e9fae4a3 100644 --- a/colossalai/cluster/dist_coordinator.py +++ b/colossalai/cluster/dist_coordinator.py @@ -20,14 +20,16 @@ class DistCoordinator(metaclass=SingletonMeta): - master: the process with rank 0 - node master: the process with local rank 0 on the current node - Example: - >>> from colossalai.cluster.dist_coordinator import DistCoordinator - >>> coordinator = DistCoordinator() - >>> - >>> if coordinator.is_master(): - >>> do_something() - >>> - >>> coordinator.print_on_master('hello world') + + ```python + from colossalai.cluster.dist_coordinator import DistCoordinator + coordinator = DistCoordinator() + + if coordinator.is_master(): + do_something() + + coordinator.print_on_master('hello world') + ``` Attributes: rank (int): the rank of the current process @@ -36,12 +38,13 @@ class DistCoordinator(metaclass=SingletonMeta): """ def __init__(self): - assert dist.is_initialized( - ), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.' + assert ( + dist.is_initialized() + ), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first." self._rank = dist.get_rank() self._world_size = dist.get_world_size() # this is often passed by launchers such as torchrun - self._local_rank = os.environ.get('LOCAL_RANK', -1) + self._local_rank = os.environ.get("LOCAL_RANK", -1) @property def rank(self) -> int: @@ -59,7 +62,9 @@ class DistCoordinator(metaclass=SingletonMeta): """ Assert that the local rank is set. This is often passed by launchers such as torchrun. """ - assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.' + assert ( + self.local_rank >= 0 + ), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process." def is_master(self, process_group: ProcessGroup = None) -> bool: """ @@ -128,11 +133,13 @@ class DistCoordinator(metaclass=SingletonMeta): other processes in the same process group. This is often useful when downloading is required as we only want to download in one process to prevent file corruption. - Example: - >>> from colossalai.cluster import DistCoordinator - >>> dist_coordinator = DistCoordinator() - >>> with dist_coordinator.priority_execution(): - >>> dataset = CIFAR10(root='./data', download=True) + + ```python + from colossalai.cluster import DistCoordinator + dist_coordinator = DistCoordinator() + with dist_coordinator.priority_execution(): + dataset = CIFAR10(root='./data', download=True) + ``` Args: executor_rank (int): the process rank to execute without blocking, all other processes will be blocked @@ -171,19 +178,19 @@ class DistCoordinator(metaclass=SingletonMeta): """ A function wrapper that only executes the wrapped function on the master process (rank 0). - Example: - >>> from colossalai.cluster import DistCoordinator - >>> dist_coordinator = DistCoordinator() - >>> - >>> @dist_coordinator.on_master_only() - >>> def print_on_master(msg): - >>> print(msg) + ```python + from colossalai.cluster import DistCoordinator + dist_coordinator = DistCoordinator() + + @dist_coordinator.on_master_only() + def print_on_master(msg): + print(msg) + ``` """ is_master = self.is_master(process_group) - # define an inner functiuon + # define an inner function def decorator(func): - @functools.wraps(func) def wrapper(*args, **kwargs): if is_master: diff --git a/colossalai/cluster/process_group_manager.py b/colossalai/cluster/process_group_manager.py index e52661846f3ed6d25602252886401837613a75e3..68106b5031265fa1edaf5b9d70d510adaa5318cb 100644 --- a/colossalai/cluster/process_group_manager.py +++ b/colossalai/cluster/process_group_manager.py @@ -19,7 +19,7 @@ class ProcessGroupManager: def __init__(self): self.pg_store = dict() - def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup: + def create_process_group(self, name: str, ranks: List[int], backend: str = "nccl") -> ProcessGroup: """ Get a process group by name. If the process group does not exist, it will be created. @@ -36,7 +36,7 @@ class ProcessGroupManager: self.pg_store[name] = pg return pg else: - raise ValueError(f'Process group {name} already exists.') + raise ValueError(f"Process group {name} already exists.") def get(self, name: str) -> ProcessGroup: """ @@ -51,7 +51,7 @@ class ProcessGroupManager: if name in self.pg_store: return self.pg_store[name] else: - raise ValueError(f'Process group {name} does not exist.') + raise ValueError(f"Process group {name} does not exist.") def destroy(self, name: str) -> None: """ @@ -64,7 +64,7 @@ class ProcessGroupManager: dist.destroy_process_group(self.pg_store[name]) del self.pg_store[name] else: - raise ValueError(f'Process group {name} does not exist.') + raise ValueError(f"Process group {name} does not exist.") def destroy_all(self) -> None: """ diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..3885bc96256163f29fa197455b3d14fd5c2582a8 --- /dev/null +++ b/colossalai/cluster/process_group_mesh.py @@ -0,0 +1,208 @@ +import itertools +from functools import reduce +from operator import mul +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +def prod(nums: List[int]) -> int: + """Product of a list of numbers. + + Args: + nums (List[int]): A list of numbers. + + Returns: + int: The product of the numbers. + """ + return reduce(mul, nums) + + +class ProcessGroupMesh: + """A helper class to manage the process group mesh. It only describes how to organize process groups, and it's decoupled with parallel method. + It just initialize process groups and cache them. The parallel method should manage them and use them to do the parallel computation. + + We use a ND-tuple to represent the process group mesh. And a ND-coordinate is to represent each process. + For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``. + + Args: + *size (int): The size of each dimension of the process group mesh. The product of the size must be equal to the world size. + + Attributes: + shape (Tuple[int, ...]): The shape of the process group mesh. + rank (int): The rank of the current process. + """ + + def __init__(self, *size: int) -> None: + assert dist.is_initialized(), "Please initialize torch.distributed first." + assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size." + self._shape = size + self._rank = dist.get_rank() + self._coord = ProcessGroupMesh.unravel(self._rank, self._shape) + self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {} + self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {} + + @property + def shape(self) -> Tuple[int, ...]: + return self._shape + + @property + def rank(self) -> int: + return self._rank + + def size(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]: + """Get the size of the process group mesh. + + Args: + dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None. + + Returns: + Union[int, Tuple[int, ...]]: Size of the target dimension or the whole process group mesh. + """ + if dim is None: + return self._shape + else: + return self._shape[dim] + + def coordinate(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]: + """Get the coordinate of the process group mesh. + + Args: + dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None. + + Returns: + Union[int, Tuple[int, ...]]: Coordinate of the target dimension or the whole process group mesh. + """ + if dim is None: + return self._coord + else: + return self._coord[dim] + + @staticmethod + def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Convert a rank to a coordinate. + + Args: + rank (int): Rank to be converted. + shape (Tuple[int, ...]): Shape of the process group mesh. + + Returns: + Tuple[int, ...]: Coordinate of the rank. + """ + return np.unravel_index(rank, shape) + + @staticmethod + def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int: + """Convert a coordinate to a rank. + mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html. + with wrap, index out of range would be wrapped around. + For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns (i % 2) + + Args: + coords (Tuple[int, ...]): Coordinate to be converted. + shape (Tuple[int, ...]): Shape of the process group mesh. + mode (Optional[str]): The mode for numpy.ravel_multi_index. + + Returns: + int: Rank of the coordinate. + """ + + assert mode in ["raise", "wrap", "clip"] + return np.ravel_multi_index(coord, shape, mode) + + def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: + """Get the process group with the given ranks. It the process group doesn't exist, it will be created. + + Args: + ranks_in_group (List[int]): Ranks in the process group. + backend (Optional[str], optional): Backend of the process group. Defaults to None. + + Returns: + ProcessGroup: The process group with the given ranks. + """ + ranks_in_group = sorted(ranks_in_group) + if tuple(ranks_in_group) not in self._group_to_ranks: + group = dist.new_group(ranks_in_group, backend=backend) + self._ranks_to_group[tuple(ranks_in_group)] = group + self._group_to_ranks[group] = tuple(ranks_in_group) + return self._ranks_to_group[tuple(ranks_in_group)] + + def get_ranks_in_group(self, group: ProcessGroup) -> List[int]: + """Get the ranks in the given process group. The process group must be created by this class. + + Args: + group (ProcessGroup): The process group. + + Returns: + List[int]: Ranks in the process group. + """ + return list(self._group_to_ranks[group]) + + @staticmethod + def get_coords_along_axis( + base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int] + ) -> List[Tuple[int, ...]]: + """Get coordinates along the given axis. + + Args: + base_coord (Tuple[int, ...]): Base coordinate which the coordinates along the axis are based on. + axis (int): Axis along which the coordinates are generated. + indices_at_axis (List[int]): Indices at the axis. + + Returns: + List[Tuple[int, ...]]: Coordinates along the axis. + """ + coords_in_group = [] + for idx in indices_at_axis: + coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :]) + return coords_in_group + + def create_group_along_axis( + self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None + ) -> ProcessGroup: + """Create all process groups along the given axis, and return the one which the current process belongs to. + + Args: + axis (int): Axis along which the process groups are created. + indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. + backend (Optional[str], optional): Backend of the process group. Defaults to None. + + Returns: + ProcessGroup: The process group along the given axis which the current process belongs to. + """ + indices_at_axis = indices_at_axis or list(range(self._shape[axis])) + reduced_shape = list(self._shape) + # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` + reduced_shape[axis] = 1 + target_group = None + # use Cartesian product to generate all combinations of coordinates + for base_coord in itertools.product(*[range(s) for s in reduced_shape]): + coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) + ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) + group = self.get_group(ranks_in_group, backend=backend) + if self._rank in ranks_in_group: + target_group = group + return target_group + + def get_group_along_axis( + self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None + ) -> ProcessGroup: + """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. + + Args: + axis (int): Axis along which the process groups are created. + indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. + backend (Optional[str], optional): Backend of the process group. Defaults to None. + + Returns: + ProcessGroup: The process group along the given axis which the current process belongs to. + """ + indices_at_axis = indices_at_axis or list(range(self._shape[axis])) + coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis) + ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) + if ranks_in_group not in self._ranks_to_group: + # no need to cache it explicitly, since it will be cached in `create_group_along_axis` + return self.create_group_along_axis(axis, indices_at_axis, backend=backend) + return self._ranks_to_group[ranks_in_group] diff --git a/colossalai/communication/__init__.py b/colossalai/communication/__init__.py deleted file mode 100644 index 220481b7af15bcada443ed7c9f8c91350a5f76b1..0000000000000000000000000000000000000000 --- a/colossalai/communication/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce -from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward, - send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward, - recv_forward, recv_backward) -from .ring import ring_forward -from .utils import send_obj_meta, recv_obj_meta - -__all__ = [ - 'all_gather', - 'reduce_scatter', - 'all_reduce', - 'broadcast', - 'reduce', - 'send_forward', - 'send_forward_recv_forward', - 'send_forward_backward_recv_forward_backward', - 'send_backward', - 'send_backward_recv_backward', - 'send_backward_recv_forward', - 'send_forward_recv_backward', - 'recv_backward', - 'recv_forward', - 'ring_forward', - 'send_obj_meta', - 'recv_obj_meta', -] diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py deleted file mode 100644 index 0200cd3c6553dc8e2b3bbaa60ffb1d416c699370..0000000000000000000000000000000000000000 --- a/colossalai/communication/p2p.py +++ /dev/null @@ -1,405 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import List, Tuple, Union -import torch -import torch.distributed as dist - -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device -from functools import reduce -import operator -from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor - -TensorShape = Union[torch.Size, List[int], Tuple[int]] - - -def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]: - """get the exact tensor shape when communicating and return whether the tensor is a chunk - - Args: - tensor_shape (:class:`torch.Size`): shape of tensor - chunk_tensor (bool, optional): whether to chunk tensor, defaults to False - - Returns: - Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor - """ - if chunk_tensor: - tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) - tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR) - if tensor_chunk_shape % tensor_parallel_world_size == 0: - tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size - else: - tensor_chunk_shape = tensor_shape - chunk_tensor = False - else: - tensor_chunk_shape = tensor_shape - return tensor_chunk_shape, chunk_tensor - - -def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors): - if isinstance(recv_shapes, torch.Size): - recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors) - buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) - return buffer_recv, recv_split - buffer_recv = [] - for recv_shape in recv_shapes: - recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors) - tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) - buffer_recv.append(tensor_recv) - return buffer_recv, recv_split - - -def process_object_to_send(object_send, scatter_gather_tensors): - if isinstance(object_send, torch.Tensor): - send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1] - if send_split: - object_send = split_tensor_into_1d_equal_chunks(object_send) - return object_send - - object_send_list = [] - for tensor_send in object_send: - send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1] - if send_split: - object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send)) - else: - object_send_list.append(tensor_send) - object_send = tuple(object_send_list) - - return object_send - - -def filling_ops_queue(obj, comm_op, comm_rank, ops_queue): - if isinstance(obj, torch.Tensor): - op_to_add = dist.P2POp(comm_op, obj, comm_rank) - ops_queue.append(op_to_add) - else: - for tensor_to_comm in obj: - op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank) - ops_queue.append(op_to_add) - - -def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, - object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, - recv_prev: bool = False, - recv_next: bool = False, - recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, - recv_next_shape: Union[torch.Size, List[torch.Size]] = None, - prev_rank: int = None, - next_rank: int = None, - dtype: torch.dtype = None, - scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: - """ - Adapted from megatron.p2p_communication. - Communicate tensors between stages. Used as helper method in other - communication methods that are used in pipeline schedule. - Takes the following arguments: - object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank (no tensor sent if - set to None). - object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank (no tensor sent if - set to None). - recv_prev (bool): boolean for whether tensor should be received from - previous rank. - recv_next (bool): boolean for whether tensor should be received from - next rank. - recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the previous stage, defaults to None. - recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the next stage, defaults to None. - prev_rank (int): the rank of the previous pipeline stage, defaults to None, - next_rank (int): the rank of the next pipeline stage, defaults to None, - dtype (torch.dtype): data type of intermediate buffers, defaults to None - scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False - - Returns: - Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next - """ - - # Create placeholder tensors for receive in forward and backward directions - # if needed. - tensor_recv_prev = None - tensor_recv_next = None - - if recv_prev: - assert recv_prev_shape is not None - tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype, - scatter_gather_tensors) - - if recv_next: - assert recv_next_shape is not None - tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype, - scatter_gather_tensors) - - if object_send_prev is not None or recv_prev: - if prev_rank is None: - prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - - if object_send_next is not None or recv_next: - if next_rank is None: - next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - - if object_send_prev is not None: - object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors) - - if object_send_next is not None: - object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors) - - ops = [] - if object_send_prev is not None: - filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops) - - if tensor_recv_prev is not None: - filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops) - - if tensor_recv_next is not None: - filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops) - - if object_send_next is not None: - filling_ops_queue(object_send_next, dist.isend, next_rank, ops) - - if len(ops) > 0: - reqs = dist.batch_isend_irecv(ops) - for req in reqs: - req.wait() - # To protect against race condition when using batch_isend_irecv(). - torch.cuda.synchronize() - - if recv_prev and recv_prev_split: - if isinstance(tensor_recv_prev, torch.Tensor): - tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_() - else: - for index in range(len(tensor_recv_prev)): - tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view( - recv_prev_shape[index]).requires_grad_() - - if recv_next and recv_next_split: - if isinstance(tensor_recv_next, torch.Tensor): - tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_() - else: - for index in range(len(tensor_recv_next)): - tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view( - recv_next_shape[index]).requires_grad_() - - return tensor_recv_prev, tensor_recv_next - - -def recv_forward(input_tensor_shape, - prev_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: - """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. - - Args: - input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. - prev_rank (int, optional): The rank of the source of the tensor. - - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list. - """ - if gpc.is_pipeline_first_stage(): - input_tensor = None - else: - input_tensor, _ = _communicate(recv_prev=True, - recv_prev_shape=input_tensor_shape, - prev_rank=prev_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) - return input_tensor - - -def recv_backward(output_grad_shape, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: - """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. - - Args: - output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. - next_rank (int, optional): The rank of the source of the tensor. - - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradident tensor list. - """ - if gpc.is_pipeline_last_stage(): - output_tensor_grad = None - else: - _, output_tensor_grad = _communicate(recv_next=True, - recv_next_shape=output_grad_shape, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) - return output_tensor_grad - - -def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None: - """Sends the input tensor to the next stage in pipeline. - - Args: - output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. - next_rank (int, optional): The rank of the recipient of the tensor. - """ - if not gpc.is_pipeline_last_stage(): - _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors) - - -def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None: - """Sends the gradient tensor to the previous stage in pipeline. - - Args: - input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent - prev_rank (int, optional): The rank of the recipient of the tensor - """ - if not gpc.is_pipeline_first_stage(): - _communicate(object_send_prev=input_tensor_grad, - prev_rank=prev_rank, - scatter_gather_tensors=scatter_gather_tensors) - - -def send_forward_recv_backward(output_tensor, - output_grad_shape, - recv_next=True, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: - """Batched communication operation. Sends the input tensor to the - next stage in pipeline, while receives the gradient tensor from the - next stage in pipeline as the input gradient tensor of this stage. - - Args: - output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. - output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. - - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor. - """ - if gpc.is_pipeline_last_stage(): - output_tensor_grad = None - else: - _, output_tensor_grad = _communicate(object_send_next=output_tensor, - recv_next=recv_next, - recv_next_shape=output_grad_shape, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) - return output_tensor_grad - - -def send_backward_recv_forward(input_tensor_grad, - input_tensor_shape, - recv_prev=True, - prev_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: - """Batched communication operation. Sends the gradient tensor to the - previous stage in pipeline, while receives the output tensor from the - previous stage in pipeline as the input of this stage. - - Args: - input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. - input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. - - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor. - """ - if gpc.is_pipeline_first_stage(): - input_tensor = None - else: - input_tensor, _ = _communicate(object_send_prev=input_tensor_grad, - recv_prev=recv_prev, - recv_prev_shape=input_tensor_shape, - prev_rank=prev_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) - return input_tensor - - -def send_forward_recv_forward(output_tensor, - input_tensor_shape, - recv_prev=True, - prev_rank=None, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: - """Batched communication operation. Sends the input tensor to the - next stage in pipeline, while receives the output tensor from the - previous stage in pipeline as the input of this stage. - - Args: - output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. - input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. - - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor. - """ - input_tensor, _ = _communicate(object_send_next=output_tensor, - recv_prev=recv_prev, - recv_prev_shape=input_tensor_shape, - prev_rank=prev_rank, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) - return input_tensor - - -def send_backward_recv_backward(input_tensor_grad, - output_grad_shape, - recv_next=True, - prev_rank=None, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: - """Batched communication operation. Sends the gradient tensor to the - previous stage in pipeline, while receives the gradient tensor from the - next member in pipeline as the input of this stage. - - Args: - input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. - output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. - - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor. - """ - _, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad, - recv_next=recv_next, - recv_next_shape=output_grad_shape, - prev_rank=prev_rank, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) - return output_tensor_grad - - -def send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - input_tensor_shape, - output_grad_shape, - recv_prev=True, - recv_next=True, - prev_rank=None, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: - """Batched communication operation. Sends the input tensor to the next stage in pipeline and - the gradient tensor to the previous stage, while receives the input gradient tensor from the - next stage and the input tensor from the previous stage. - - Args: - output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next. - input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous. - input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the previous. - output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the next. - - Returns: - Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor) - """ - input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor, - object_send_prev=input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - recv_prev_shape=input_tensor_shape, - recv_next_shape=output_grad_shape, - prev_rank=prev_rank, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) - return input_tensor, output_tensor_grad diff --git a/colossalai/communication/ring.py b/colossalai/communication/ring.py deleted file mode 100644 index aece7574b7c41cac3b16cd5891b1e26d0ede9c36..0000000000000000000000000000000000000000 --- a/colossalai/communication/ring.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch - -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device, synchronize - - -def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor: - """Sends a tensor to the next member and receives a tensor from the previous member. - This function returns the received tensor from the previous member. - - Args: - tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member - parallel_mode (ParallelMode): Parallel group mode used in this communication - - Returns: - :class:`torch.Tensor`: The tensor received from the previous. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_. - """ - buffer_shape = tensor_send_next.size() - - ops = [] - current_rank = gpc.get_global_rank() - - tensor_recv_prev = torch.empty(buffer_shape, - requires_grad=True, - device=get_current_device(), - dtype=tensor_send_next.dtype) - - # send to next rank - send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next, - gpc.get_next_global_rank(parallel_mode)) - ops.append(send_next_op) - - # receive from prev rank - recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev, - gpc.get_prev_global_rank(parallel_mode)) - ops.append(recv_prev_op) - - if current_rank % 2 == 0: - ops = ops[::-1] - - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() - - # To protect against race condition when using batch_isend_irecv(). - synchronize() - - return tensor_recv_prev diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py deleted file mode 100644 index ef9eceea847dd3d6cb036e87e369529dcbe0db41..0000000000000000000000000000000000000000 --- a/colossalai/communication/utils.py +++ /dev/null @@ -1,126 +0,0 @@ -import torch -import torch.distributed as dist - -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device -from typing import Union, List, Tuple - -TensorShape = Union[torch.Size, List[int], Tuple[int]] - - -def send_meta_helper(obj, next_rank, tensor_kwargs): - send_shape = torch.tensor(obj.size(), **tensor_kwargs) - send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs) - dist.send(send_ndims, next_rank) - dist.send(send_shape, next_rank) - - -def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: - """Sends obj meta information before sending a specific obj. - Since the recipient must know the shape of the obj in p2p communications, - meta information of the obj should be sent before communications. This function - synchronizes with :func:`recv_obj_meta`. - - Args: - obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent. - need_meta (bool, optional): If False, meta information won't be sent. - next_rank (int): The rank of the next member in pipeline parallel group. - - Returns: - bool: False - """ - if need_meta: - if next_rank is None: - next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - - tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} - if isinstance(obj, torch.Tensor): - send_obj_nums = torch.tensor(1, **tensor_kwargs) - dist.send(send_obj_nums, next_rank) - send_meta_helper(obj, next_rank, tensor_kwargs) - else: - send_obj_nums = torch.tensor(len(obj), **tensor_kwargs) - dist.send(send_obj_nums, next_rank) - for tensor_to_send in obj: - send_meta_helper(tensor_to_send, next_rank, tensor_kwargs) - - return False - - -def recv_meta_helper(prev_rank, tensor_kwargs): - recv_ndims = torch.empty((), **tensor_kwargs) - dist.recv(recv_ndims, prev_rank) - recv_shape = torch.empty(recv_ndims, **tensor_kwargs) - dist.recv(recv_shape, prev_rank) - return recv_shape - - -def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size: - """Receives obj meta information before receiving a specific obj. - Since the recipient must know the shape of the obj in p2p communications, - meta information of the obj should be received before communications. This function - synchronizes with :func:`send_obj_meta`. - - Args: - obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received. - prev_rank (int): The rank of the source of the obj. - - Returns: - Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received. - """ - if obj_shape is None: - if prev_rank is None: - prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - - tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} - recv_obj_nums = torch.empty((), **tensor_kwargs) - dist.recv(recv_obj_nums, prev_rank) - if recv_obj_nums.item() == 1: - recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) - obj_shape = torch.Size(recv_shape) - else: - obj_shape = [] - for i in range(recv_obj_nums.item()): - recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) - obj_shape.append(torch.Size(recv_shape)) - - return obj_shape - - -def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor: - """Break a tensor into equal 1D chunks. - - Args: - tensor (:class:`torch.Tensor`): Tensor to be split before communication. - new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor. - - Returns: - :class:`torch.Tensor`: The split tensor - """ - partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D) - start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D) - end_index = start_index + partition_size - if new_buffer: - data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) - data.copy_(tensor.view(-1)[start_index:end_index]) - else: - data = tensor.view(-1)[start_index:end_index] - return data - - -def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: - """Opposite of above function, gather values from model parallel ranks. - - Args: - tensor (:class:`torch.Tensor`): Tensor to be gathered after communication. - Returns: - :class:`torch.Tensor`: The gathered tensor. - """ - world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - numel = torch.numel(tensor) - numel_gathered = world_size * numel - gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) - chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)] - dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D)) - return gathered diff --git a/colossalai/constants.py b/colossalai/constants.py deleted file mode 100644 index 6cf9085f9fbb63ea18d2712f99c08f24b539245d..0000000000000000000000000000000000000000 --- a/colossalai/constants.py +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence'] -TENSOR_PARALLEL_MODE = 'tensor_parallel_mode' - -# initializer -INITIALIZER_MAPPING = { - 'data': 'Initializer_Data', - 'tensor': 'Initializer_Tensor', - 'pipeline': 'Initializer_Pipeline', - 'embedding': 'Initializer_Embedding', - '1d': 'Initializer_1D', - '2d': 'Initializer_2D', - '2.5d': 'Initializer_2p5D', - '3d': 'Initializer_3D', - 'sequence': 'Initializer_Sequence', - 'model': 'Initializer_Model', - 'moe': 'Initializer_Moe' -} - -# 3D parallelism groups -INPUT_GROUP_3D = 'input_group_3d' -WEIGHT_GROUP_3D = 'weight_group_3d' -OUTPUT_GROUP_3D = 'output_group_3d' -INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d' -OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d' - -# Attributes of tensor parallel parameters -IS_TENSOR_PARALLEL = 'is_tensor_parallel' -NUM_PARTITIONS = 'num_partitions' -TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS] diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py index 50178b5fa850777f8455798cc6ab9d7254c5a9fe..ab57301bb9108a72ebe59cbdbba07caa35f46bb1 100644 --- a/colossalai/context/__init__.py +++ b/colossalai/context/__init__.py @@ -1,6 +1,8 @@ from .config import Config, ConfigException -from .parallel_context import ParallelContext -from .parallel_mode import ParallelMode -from .moe_context import MOE_CONTEXT -from .process_group_initializer import * -from .random import * + +# from .moe_context import MOE_CONTEXT + +__all__ = [ + "Config", + "ConfigException", +] diff --git a/colossalai/context/config.py b/colossalai/context/config.py index 8903707708df96eac7a0a70343e37e984e6fabed..05a2e4bf044a0561012f62d521cf1fb5c8500a9b 100644 --- a/colossalai/context/config.py +++ b/colossalai/context/config.py @@ -5,6 +5,7 @@ import inspect import sys from importlib.machinery import SourceFileLoader from pathlib import Path + from colossalai.logging import get_dist_logger @@ -41,7 +42,7 @@ class Config(dict): self.__setattr__(key, value) def update(self, config): - assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.' + assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects." for k, v in config.items(): self._add_item(k, v) return self @@ -66,11 +67,11 @@ class Config(dict): elif isinstance(filename, Path): filepath = filename.absolute() - assert filepath.exists(), f'{filename} is not found, please check your configuration path' + assert filepath.exists(), f"{filename} is not found, please check your configuration path" # check extension extension = filepath.suffix - assert extension == '.py', 'only .py files are supported' + assert extension == ".py", "only .py files are supported" # import the config as module remove_path = False @@ -86,13 +87,13 @@ class Config(dict): config = Config() for k, v in module.__dict__.items(): - if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v): + if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v): continue else: config._add_item(k, v) logger = get_dist_logger() - logger.debug('variables which starts with __, is a module or class declaration are omitted in config file') + logger.debug("variables which starts with __, is a module or class declaration are omitted in config file") # remove module del sys.modules[module_name] diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index b41f4072a4052113b3e3a79a20c0278b9fed8295..066dfc7222e116345ae9c3af11438f0fef46db04 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -3,21 +3,19 @@ from typing import Tuple import torch import torch.distributed as dist -from colossalai.context.parallel_mode import ParallelMode from colossalai.context.singleton_meta import SingletonMeta -from colossalai.tensor import ProcessGroup +from colossalai.legacy.tensor import ProcessGroup def _check_sanity(): - from colossalai.core import global_context as gpc + from colossalai.legacy.core import global_context as gpc + if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: - raise NotImplementedError("Moe is not compatible with tensor or " - "pipeline parallel at present.") + raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.") class MoeParallelInfo: - """Moe parallelism information, storing parallel sizes and groups. - """ + """Moe parallelism information, storing parallel sizes and groups.""" def __init__(self, ep_size: int, dp_size: int): _check_sanity() @@ -61,10 +59,12 @@ class MoeContext(metaclass=SingletonMeta): self.world_size = dist.get_world_size() - from colossalai.core import global_context as gpc - self.max_ep_size = gpc.config.get('max_ep_size', self.world_size) - assert self.world_size % self.max_ep_size == 0, \ - "Maximum expert parallel size must be a factor of the number of GPUs" + from colossalai.legacy.core import global_context as gpc + + self.max_ep_size = gpc.config.get("max_ep_size", self.world_size) + assert ( + self.world_size % self.max_ep_size == 0 + ), "Maximum expert parallel size must be a factor of the number of GPUs" self.min_dp_size = self.world_size // self.max_ep_size # Enabling kernel optimization may raise error in some cases @@ -72,6 +72,7 @@ class MoeContext(metaclass=SingletonMeta): self.use_kernel_optim = use_kernel_optim from .random import moe_set_seed + moe_set_seed(seed) self.has_setup = True @@ -89,11 +90,13 @@ class MoeContext(metaclass=SingletonMeta): number of local experts, the MoeParallelInfo of the current ep_size """ - gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater - lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less + gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater + lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less - assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ - " is not a multiple of ep size or vice versa." + assert gt_flag or lt_flag, ( + "Automatic experts placement dose not not support expert number" + " is not a multiple of ep size or vice versa." + ) # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size, # there are multiple experts in each GPU and each GPU has different experts diff --git a/colossalai/context/parallel_mode.py b/colossalai/context/parallel_mode.py deleted file mode 100644 index 1cf6fa53dc1e5c31fbaf1c9140e0915419af704c..0000000000000000000000000000000000000000 --- a/colossalai/context/parallel_mode.py +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from enum import Enum - - -# parallel modes -class ParallelMode(Enum): - """This is an enumeration class containing all possible parallel modes. - """ - - GLOBAL = 'global' - - # common parallel - DATA = 'data' - - # model parallel - containing tensor and pipeline parallel groups - # this is added to facilitate amp and grad clipping in hybrid parallel - MODEL = 'model' - - # pipeline parallel - PIPELINE = 'pipe' - - # containing all ranks in tensor parallel - TENSOR = 'tensor' - - # sequence parallel - SEQUENCE = 'sequence' - SEQUENCE_DP = 'sequence_dp' - - # 1D Parallel - PARALLEL_1D = '1d' - - # 2D parallel - PARALLEL_2D_ROW = '2d_row' - PARALLEL_2D_COL = '2d_col' - - # 3D parallel - PARALLEL_3D_INPUT = '3d_input' - PARALLEL_3D_WEIGHT = '3d_weight' - PARALLEL_3D_OUTPUT = '3d_output' - PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight" - PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight" - - # 2.5D parallel - PARALLEL_2P5D_ROW = '2p5d_row' - PARALLEL_2P5D_COL = '2p5d_col' - PARALLEL_2P5D_DEP = '2p5d_dep' - PARALLEL_2P5D_XZ = '2p5d_xz' diff --git a/colossalai/context/process_group_initializer/__init__.py b/colossalai/context/process_group_initializer/__init__.py deleted file mode 100644 index d3937a9474376f0ecb7af612121bb4c3e5f5a497..0000000000000000000000000000000000000000 --- a/colossalai/context/process_group_initializer/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .initializer_1d import Initializer_1D -from .initializer_2d import Initializer_2D -from .initializer_2p5d import Initializer_2p5D -from .initializer_3d import Initializer_3D -from .initializer_data import Initializer_Data -from .initializer_pipeline import Initializer_Pipeline -from .initializer_sequence import Initializer_Sequence -from .initializer_tensor import Initializer_Tensor -from .initializer_model import Initializer_Model -from .process_group_initializer import ProcessGroupInitializer - -__all__ = [ - 'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D', - 'Initializer_2D', 'Initializer_3D', 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model' -] diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/context/process_group_initializer/initializer_pipeline.py deleted file mode 100644 index 0ddb52f63e22f29aff9920d5cdd2aba1748e1eb6..0000000000000000000000000000000000000000 --- a/colossalai/context/process_group_initializer/initializer_pipeline.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from torch import distributed as dist - -from colossalai.registry import DIST_GROUP_INITIALIZER - -from ..parallel_mode import ParallelMode -from .process_group_initializer import ProcessGroupInitializer - - -@DIST_GROUP_INITIALIZER.register_module -class Initializer_Pipeline(ProcessGroupInitializer): - """A ProcessGroupInitializer for pipeline parallelism. - - Args: - rank (int): The rank of current process - world_size (int): Size of whole communication world - config (Config): Running configuration - data_parallel_size (int): Size of data parallel - pipeline_parallel_size (int): Size of pipeline parallel - tensor_parallel_size (int): Size of tensor parallel - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.data_group_size = self.world_size // self.data_parallel_size - self.pipeline_stage_size = self.data_group_size // self.pipeline_parallel_size - - def init_dist_group(self): - """Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu. - - Returns: - List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: - A Pipeline parallelism's information in list of tuples. - """ - dist_settings = list() - for i in range(self.data_parallel_size): - for j in range(self.pipeline_stage_size): - pipe_ranks = list( - range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size)) - pipe_group_size = len(pipe_ranks) - pipe_group = dist.new_group(pipe_ranks) - group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group - - if self.rank in pipe_ranks: - local_rank = pipe_ranks.index(self.rank) - group_world_size = pipe_group_size - process_group = pipe_group - cpu_group = group_cpu - ranks_in_group = pipe_ranks - dist_settings.append( - tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, - ParallelMode.PIPELINE))) - - return dist_settings diff --git a/colossalai/context/random/__init__.py b/colossalai/context/random/__init__.py deleted file mode 100644 index d64b993257c1574706ee5028224692b4e666fc19..0000000000000000000000000000000000000000 --- a/colossalai/context/random/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from ._helper import ( - add_seed, - get_current_mode, - get_seeds, - get_states, - moe_set_seed, - reset_seeds, - seed, - set_mode, - set_seed_states, - sync_states, - with_seed, -) - -__all__ = [ - 'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states', - 'sync_states', 'moe_set_seed', 'reset_seeds' -] diff --git a/colossalai/context/singleton_meta.py b/colossalai/context/singleton_meta.py index 8ca335119d52ad2a212b1e0c578202b2fc6bb60f..3088b0dffaace5da66de7525e8598a58e545b80d 100644 --- a/colossalai/context/singleton_meta.py +++ b/colossalai/context/singleton_meta.py @@ -16,6 +16,7 @@ class SingletonMeta(type): instance = super().__call__(*args, **kwargs) cls._instances[cls] = instance else: - assert len(args) == 0 and len( - kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.' + assert ( + len(args) == 0 and len(kwargs) == 0 + ), f"{cls.__name__} is a singleton class and a instance has been created." return cls._instances[cls] diff --git a/colossalai/core.py b/colossalai/core.py deleted file mode 100644 index 153247bbed9c65db0b2255247137fa9a64a693fa..0000000000000000000000000000000000000000 --- a/colossalai/core.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from colossalai.context.parallel_context import global_context - -__all__ = ['global_context'] \ No newline at end of file diff --git a/colossalai/device/__init__.py b/colossalai/device/__init__.py index 689189998c3f6490145ba2648c522570c6f40b4c..34a7d2526fdad1fa1e3622d7b5f0a6bf419d9b50 100644 --- a/colossalai/device/__init__.py +++ b/colossalai/device/__init__.py @@ -1,4 +1,4 @@ from .alpha_beta_profiler import AlphaBetaProfiler from .calc_pipeline_strategy import alpa_dp -__all__ = ['AlphaBetaProfiler', 'alpa_dp'] +__all__ = ["AlphaBetaProfiler", "alpa_dp"] diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py index af2b10928c6f2f99a429aaa413d527d77d52faf0..88520b2a14d08cb379e255485dfb9a228bf0da63 100644 --- a/colossalai/device/alpha_beta_profiler.py +++ b/colossalai/device/alpha_beta_profiler.py @@ -13,7 +13,7 @@ FRAMEWORK_LATENCY = 0 class AlphaBetaProfiler: - ''' + """ Profile alpha and beta value for a given device list. Usage: @@ -27,17 +27,19 @@ class AlphaBetaProfiler: (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12), (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11), (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)} - ''' - - def __init__(self, - physical_devices: List[int], - alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None, - ctype: str = 'a', - warmup: int = 5, - repeat: int = 25, - latency_iters: int = 5, - homogeneous_tolerance: float = 0.1): - ''' + """ + + def __init__( + self, + physical_devices: List[int], + alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None, + ctype: str = "a", + warmup: int = 5, + repeat: int = 25, + latency_iters: int = 5, + homogeneous_tolerance: float = 0.1, + ): + """ Args: physical_devices: A list of device id, each element inside it is the global rank of that device. alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs. @@ -45,7 +47,7 @@ class AlphaBetaProfiler: warmup: Number of warmup iterations. repeat: Number of iterations to measure. latency_iters: Number of iterations to measure latency. - ''' + """ self.physical_devices = physical_devices self.ctype = ctype self.world_size = len(physical_devices) @@ -123,7 +125,7 @@ class AlphaBetaProfiler: return (None, None) def profile_latency(self, process_group, pg_handler): - ''' + """ This function is used to profile the latency of the given process group with a series of bytes. Args: @@ -132,7 +134,7 @@ class AlphaBetaProfiler: Returns: latency: None if the latency is not measured, otherwise the median of the latency_list. - ''' + """ latency_list = [] for i in range(self.latency_iters): nbytes = int(BYTE << i) @@ -148,26 +150,26 @@ class AlphaBetaProfiler: return latency def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)): - ''' + """ This function is used to profile the bandwidth of the given process group. Args: process_group: A tuple of global rank of the process group. pg_handler: The handler of the process group. - ''' + """ (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes) return bandwidth def profile_ab(self): - ''' + """ This method is used to profiling the alpha and beta value for a given device list. Returns: alpha_beta_dict: A dict which maps process group to its alpha and beta value. - ''' + """ alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {} rank = dist.get_rank() - global_pg_handler = dist.new_group(self.physical_devices) + dist.new_group(self.physical_devices) def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup): assert rank in process_group @@ -197,7 +199,7 @@ class AlphaBetaProfiler: dist.broadcast_object_list(broadcast_list, src=process_group[0]) alpha_beta_dict[process_group] = tuple(broadcast_list) - # add symmetry pair to the apha_beta_dict + # add symmetry pair to the alpha_beta_dict symmetry_ab_dict = {} for process_group, alpha_beta_pair in alpha_beta_dict.items(): symmetry_process_group = (process_group[1], process_group[0]) @@ -208,7 +210,7 @@ class AlphaBetaProfiler: return alpha_beta_dict def search_best_logical_mesh(self): - ''' + """ This method is used to search the best logical mesh for the given device list. The best logical mesh is searched in following steps: @@ -232,19 +234,19 @@ class AlphaBetaProfiler: >>> best_logical_mesh = profiler.search_best_logical_mesh() >>> print(best_logical_mesh) [[0, 1], [2, 3]] - ''' + """ def _power_of_two(integer): return integer & (integer - 1) == 0 def _detect_homogeneous_device(alpha_beta_dict): - ''' + """ This function is used to detect whether the devices in the alpha_beta_dict are homogeneous. Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)] * base_beta. - ''' + """ homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {} for process_group, (_, beta) in alpha_beta_dict.items(): if homogeneous_device_dict is None: @@ -254,7 +256,8 @@ class AlphaBetaProfiler: match_beta = None for beta_value in homogeneous_device_dict.keys(): if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * ( - 1 - self.homogeneous_tolerance): + 1 - self.homogeneous_tolerance + ): match_beta = beta_value break @@ -267,9 +270,9 @@ class AlphaBetaProfiler: return homogeneous_device_dict def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]): - ''' + """ This function is used to check whether the homogeneous_group contains all physical devices. - ''' + """ flatten_mesh = [] for process_group in homogeneous_group: flatten_mesh.extend(process_group) @@ -277,9 +280,9 @@ class AlphaBetaProfiler: return len(non_duplicated_flatten_mesh) == len(self.physical_devices) def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): - ''' + """ This function is used to construct the largest ring in the homogeneous_group for each rank. - ''' + """ # Construct the ring ring = [] ranks_in_ring = [] @@ -300,7 +303,9 @@ class AlphaBetaProfiler: check_rank = check_rank_list.pop() for process_group in homogeneous_group: if check_rank in process_group: - rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1] + rank_to_append = ( + process_group[0] if process_group[1] == check_rank else process_group[1] + ) if rank_to_append not in ring_for_rank: stable_status = False rank_to_check_list.append(rank_to_append) @@ -314,7 +319,7 @@ class AlphaBetaProfiler: assert _power_of_two(self.world_size) power_of_two = int(math.log2(self.world_size)) median = power_of_two // 2 - balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median)) + balanced_logical_mesh_shape = (2**median, 2 ** (power_of_two - median)) row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1] balanced_logical_mesh = [] for row_index in range(row_size): @@ -348,7 +353,7 @@ class AlphaBetaProfiler: return best_logical_mesh def extract_alpha_beta_for_device_mesh(self): - ''' + """ Extract the mesh_alpha list and mesh_beta list based on the best logical mesh, which will be used to initialize the device mesh. @@ -360,7 +365,7 @@ class AlphaBetaProfiler: [2.5917552411556242e-05, 0.00010312341153621673] >>> print(mesh_beta) [5.875573704655635e-11, 4.7361584445959614e-12] - ''' + """ best_logical_mesh = self.search_best_logical_mesh() first_axis = [row[0] for row in best_logical_mesh] @@ -381,7 +386,7 @@ class AlphaBetaProfiler: first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group) second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group) mesh_alpha = [first_latency, second_latency] - # The beta values have been enlarged by 1e10 times temporarilly because the computation cost + # The beta values have been enlarged by 1e10 times temporarily because the computation cost # is still estimated in the unit of TFLOPs instead of time. We will remove this factor in future. mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth] diff --git a/colossalai/device/calc_pipeline_strategy.py b/colossalai/device/calc_pipeline_strategy.py index 4ab72dfe60f0c73f0e4f5186ed54205b68513bc0..72d432701ada2f9c1c8c541404cab10a07ea22b7 100644 --- a/colossalai/device/calc_pipeline_strategy.py +++ b/colossalai/device/calc_pipeline_strategy.py @@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"): while i <= num_devices_per_host: i *= 2 p += 1 - assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, " - f"while now num_devices_per_host = {num_devices_per_host}") + assert pow(2, p) == num_devices_per_host, ( + "Only supports the cases where num_devices_per_host is power of two, " + f"while now num_devices_per_host = {num_devices_per_host}" + ) if mode == "alpa": for i in range(p + 1): submesh_choices.append((1, pow(2, i))) @@ -24,18 +26,19 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"): return submesh_choices -def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, - best_configs): +def alpa_dp_impl( + num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, best_configs +): """Implementation of Alpa DP for pipeline strategy - Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf + Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf - Arguments: - num_layers: K - num_devices: N*M - num_microbatches: B - submesh_choices: List[(n_i,m_i)] - compute_cost: t_intra - """ + Arguments: + num_layers: K + num_devices: N*M + num_microbatches: B + submesh_choices: List[(n_i,m_i)] + compute_cost: t_intra + """ # For f, layer ID start from 0 # f[#pipeline stages, layer id that is currently being considered, number of devices used] f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32) @@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com for i in range(num_layers, k, -1): stage_cost = compute_cost[k, i, m] new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost - if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]): + if stage_cost <= max_stage_cost and new_cost < f[s, k, d]: f[s, k, d] = new_cost f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices]) f_argmin[s, k, d] = (i, m, best_configs[k, i, m]) @@ -75,34 +78,34 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com res = [] while current_s > 0 and current_layer < num_layers and current_devices > 0: - next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices]) + next_start_layer, submesh_choice, autosharding_choice = f_argmin[current_s, current_layer, current_devices] assert next_start_layer != -1 and current_devices != -1 res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice)) current_s -= 1 current_layer = next_start_layer current_devices -= np.prod(np.array(submesh_choices[submesh_choice])) - assert (current_s == 0 and current_layer == num_layers and current_devices == 0) + assert current_s == 0 and current_layer == num_layers and current_devices == 0 return total_cost, res -def alpa_dp(num_layers, - num_devices, - num_microbatches, - submesh_choices, - num_autosharding_configs, - compute_cost, - gap=1e-6): +def alpa_dp( + num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, gap=1e-6 +): """Alpa auto stage dynamic programming. - Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py + Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py Arguments: submesh_choices: List[(int,int)] num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh) compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs) """ - assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices), - num_autosharding_configs), "Cost shape wrong." + assert np.shape(compute_cost) == ( + num_layers, + num_layers, + len(submesh_choices), + num_autosharding_configs, + ), "Cost shape wrong." all_possible_stage_costs = np.sort(np.unique(compute_cost)) best_cost = np.inf best_solution = None @@ -117,8 +120,9 @@ def alpa_dp(num_layers, break if max_stage_cost - last_max_stage_cost < gap: continue - cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, - max_stage_cost, best_configs) + cost, solution = alpa_dp_impl( + num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, max_stage_cost, best_configs + ) if cost < best_cost: best_cost = cost best_solution = solution diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index 2a5f747fbc238c7799b3076e695030020f491d5b..72f199203a9d12c9dd75869a96065715811afb75 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -3,11 +3,19 @@ with some changes. """ import operator +from dataclasses import dataclass from functools import reduce -from typing import List, Tuple +from typing import Dict, List, Union import torch import torch.distributed as dist +from torch.distributed import ProcessGroup + + +@dataclass +class ProcessGroupContainer: + process_group: ProcessGroup + ranks: List[int] # modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py) @@ -27,223 +35,491 @@ class DeviceMesh: during initializing the DeviceMesh instance if the init_process_group set to True. Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group. (default: False) - need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True. + device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda') """ - def __init__(self, - physical_mesh_id: torch.Tensor, - mesh_shape: torch.Size = None, - logical_mesh_id: torch.Tensor = None, - mesh_alpha: List[float] = None, - mesh_beta: List[float] = None, - init_process_group: bool = False, - need_flatten: bool = True): - self.physical_mesh_id = physical_mesh_id + _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"} + + def __init__( + self, + physical_mesh_id: torch.Tensor, + mesh_shape: torch.Size = None, + logical_mesh_id: torch.Tensor = None, + mesh_alpha: List[float] = None, + mesh_beta: List[float] = None, + init_process_group: bool = False, + device: str = "cuda", + ): + # ============================ + # Physical & Logical Mesh IDs + # ============================ + self._physical_mesh_id = physical_mesh_id + assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor." + + # logical mesh ids can be obtained via two ways + # 1. provide physical mesh id and provide mesh shape + # 2. directly supply the logical mesh id + assert mesh_shape is None or logical_mesh_id is None, ( + "Only one of mesh_shape and logical_mesh_id can be specified." + "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id" + ) + if logical_mesh_id is None: - self.mesh_shape = mesh_shape - self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape) + self._mesh_shape = mesh_shape + self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape) else: self._logical_mesh_id = logical_mesh_id - self.mesh_shape = self._logical_mesh_id.shape + self._mesh_shape = self._logical_mesh_id.shape + + # ensure two things: + # 1. logical and physical mesh IDs should contain the same elements + # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed + assert torch.equal( + torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id) + ), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." + assert ( + torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel() + ), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again." + assert ( + torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel() + ), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." - # map global rank into logical rank - self.convert_map = {} - self._global_rank_to_logical_rank_map(self._logical_mesh_id, []) + # =============================================== # coefficient for alpha-beta communication model + # alpha is latency and beta is bandwidth + # =============================================== + # if the values are not provided, we assume they are 1 for simplicity if mesh_alpha is None: - mesh_alpha = [1] * len(self.mesh_shape) + mesh_alpha = [1] * len(self._mesh_shape) if mesh_beta is None: - mesh_beta = [1] * len(self.mesh_shape) + mesh_beta = [1] * len(self._mesh_shape) + self.mesh_alpha = tuple(mesh_alpha) self.mesh_beta = tuple(mesh_beta) - self.init_process_group = init_process_group - self.need_flatten = need_flatten - if self.init_process_group: - self.process_groups_dict = self.create_process_groups_for_logical_mesh() - if self.need_flatten and self._logical_mesh_id.dim() > 1: - self.flatten_device_mesh = self.flatten() - # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten()) - # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha, - # self.mesh_beta) + + # ensure the alpha and beta have the same shape + assert len(self.mesh_alpha) == len( + self.mesh_beta + ), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again." + + # ========================= + # Device for Process Group + # ========================= + self._device = device + self._dist_backend = self._DIST_BACKEND[device] + + # ========================= + # Process Group Management + # ========================= + # the _global_to_local_rank_mapping is structured as follows + # { + # : [ , , , ...] + # } + self._global_to_local_rank_mapping = dict() + self._init_global_to_logical_rank_mapping( + mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id + ) + + # create process group + self._process_group_dict = {} + self._ranks_in_the_process_group = {} + self._global_rank_of_current_process = None + self._is_initialized = False + + # attribute used to indicate whether this object + # is created using DeviceMesh.from_process_group + # this attribute can be used to do some check in methods + # such get_process_group as no global rank information + # is known if created with from_process_group + self._is_init_from_process_group = False + + # initialize process group if specified + self._init_ranks_in_the_same_group() + self._init_process_group = init_process_group + if init_process_group: + self.init_logical_process_group() @property - def shape(self): - return self.mesh_shape + def shape(self) -> torch.Size: + """ + Return the shape of the logical mesh. + """ + return self._mesh_shape @property - def num_devices(self): - return reduce(operator.mul, self.physical_mesh_id.shape, 1) + def num_devices(self) -> int: + """ + Return the number of devices contained in the device mesh. + """ + return reduce(operator.mul, self._physical_mesh_id.shape, 1) @property - def logical_mesh_id(self): + def logical_mesh_id(self) -> torch.Tensor: + """ + Return the logical mesh id. + """ return self._logical_mesh_id - def __deepcopy__(self, memo): + @property + def is_initialized(self) -> bool: + """ + Return whether the process group is initialized. + """ + return self._is_initialized + + @staticmethod + def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh": + """ + Create a DeviceMesh instance from the current process group. Please note that the DeviceMesh object created with this method + will not have information about the physical mesh id, and thus will not be able to query for other ranks and perform alpha-beta communication. + + Args: + process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh. + If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects, + the ProcessGroup at the ith index will correspond to the process group in the ith axis of the device mesh. + + Returns: + DeviceMesh: the device mesh instance. + """ + + def _get_device_by_backend(process_group): + """ + Get the device type given a process group's backend. + """ + backend = dist.get_backend(process_group) + for _device, _backend in DeviceMesh._DIST_BACKEND.items(): + if _backend == backend: + return _device + return None + + if isinstance(process_group, ProcessGroup): + process_group = [process_group] + + # get mesh shape + mesh_shape = [dist.get_world_size(pg) for pg in process_group] + + # get device + device_list = [_get_device_by_backend(pg) for pg in process_group] + + # make sure all devices are the same + assert all( + [device == device_list[0] for device in device_list] + ), "All devices should be the same, please check your input process groups are created with the same distributed backend." + + # create a fake physical mesh id + # as we only get the process group associated with the current process, + # we cannot get the global ranks for all processes in the mesh + # therefore, we only use this fake physical mesh id to create the device mesh + # and will remove this fake physical mesh id later + fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1)) + + # create the device mesh + device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0]) + + # hack the device attribute + device_mesh._physical_mesh_id = None + device_mesh._logical_mesh_id = None + device_mesh._global_rank_of_current_process = dist.get_rank() + device_mesh._is_initialized = False + device_mesh._process_group_dict = { + device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)} + } + + return device_mesh + + def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup: + """ + Return the process group on the specified axis. + + Args: + axis (int): the axis of the process group. + global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None) + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + return self._process_group_dict[global_rank][axis] + + def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]: + """ + Return the process groups for all axes. + + Args: + global_rank (int, optional): the global rank of the process + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + return self._process_group_dict[global_rank] + + def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]: + """ + Return the ranks in the process group on the specified axis. + + Args: + axis (int): the axis of the process group. + global_rank (int, optional): the global rank of the process + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + return self._ranks_in_the_process_group[global_rank][axis] + + def __deepcopy__(self, memo) -> "DeviceMesh": cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k != 'process_groups_dict': + if k != "_process_group_dict": setattr(result, k, __import__("copy").deepcopy(v, memo)) else: + # process group cannot be copied + # thus, we share them directly setattr(result, k, v) - return result - def flatten(self): + def _init_global_to_logical_rank_mapping( + self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = [] + ) -> Dict[int, List[int]]: """ - Flatten the logical mesh into an effective 1d logical mesh, + Build a global rank to local rank mapping for each process group in different axis in the logical device mesh. + + Args: + mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh. + tensor (torch.Tensor): the tensor that contains the logical mesh ids. + index_list (List[int]) + + Returns: + mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh. + The value is a list of integers and each integer represents the local rank in the indexed axis. """ - flatten_mesh_shape_size = len(self.mesh_shape) - flatten_mesh_shape = [self.num_devices] - return DeviceMesh(self.physical_mesh_id, - tuple(flatten_mesh_shape), - mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), - mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), - init_process_group=self.init_process_group, - need_flatten=False) - - def _global_rank_to_logical_rank_map(self, tensor, index_list): - ''' - This method is a helper function to build convert_map recursively. - ''' for index, inner_tensor in enumerate(tensor): + # index means the local rank in the current axis + # inner_tensor refers to the processes with the same local rank + if inner_tensor.numel() == 1: - self.convert_map[int(inner_tensor)] = index_list + [index] + # if the inner_tensor only has one element, it means that + # it already reaches the last axis + # we append its local_rank in the last axis to the index_list + # and assign to the mapping + # the value of the mapping is the the local rank at the indexed axis of the device mesh + mapping[int(inner_tensor)] = index_list + [index] else: - self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index]) + # we recursively go into the function until we reach the last axis + # meanwhile, we should add the local rank in the current axis in the index_list + self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index]) - def create_process_groups_for_logical_mesh(self): - ''' + def init_logical_process_group(self): + """ This method is used to initialize the logical process groups which will be used in communications among logical device mesh. Note: if init_process_group set to False, you have to call this method manually. Otherwise, the communication related function, such as ShapeConsistencyManager.apply will raise errors. - ''' - process_groups_dict = {} - check_duplicate_list = [] - global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist() + """ + # sanity check + assert ( + dist.is_initialized + ), "The torch.distributed should be initialized before calling init_logical_process_group" + assert ( + not self._is_initialized + ), "The logical process group has been initialized, do not call init_logical_process_group twice" + + # update the global rank of the current process + self._global_rank_of_current_process = dist.get_rank() + duplicate_check_list = [] + + # flatten the global ranks to 1D list + global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist() + for global_rank in global_rank_flatten_list: - process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank) - for axis, process_group in process_groups.items(): - if axis not in process_groups_dict: - process_groups_dict[axis] = [] - if process_group not in check_duplicate_list: - check_duplicate_list.append(process_group) - process_group_handler = dist.new_group(process_group) - process_groups_dict[axis].append((process_group, process_group_handler)) - - return process_groups_dict - - def global_rank_to_logical_rank(self, rank): - return self.convert_map[rank] - - def global_rank_to_process_groups_with_logical_rank(self, rank): - ''' - Give a global rank and return all logical process groups of this rank. - for example: - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) - mesh_shape = (4, 4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7], - # [8, 9, 10,11], - # [12,13,14,15]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - print(device_mesh.global_rank_to_process_groups_with_logical_rank(0)) - output: - # key is axis name - # value is a list of logical ranks in same axis with rank 0 - {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]} - ''' - process_groups = {} - for d in range(self.logical_mesh_id.dim()): - for replacer in range(self.logical_mesh_id.shape[d]): - if d not in process_groups: - process_groups[d] = [] - process_group_member = self.convert_map[rank].copy() - process_group_member[d] = replacer - process_groups[d].append(process_group_member) - return process_groups - - def global_rank_to_process_groups_with_global_rank(self, rank): - ''' - Give a global rank and return all process groups of this rank. - for example: - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) - mesh_shape = (4, 4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7], - # [8, 9, 10,11], - # [12,13,14,15]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - print(device_mesh.global_rank_to_process_groups_with_global_rank(0)) - output: - # key is axis name - # value is a list of global ranks in same axis with rank 0 - {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]} - ''' - logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank) - process_groups = {} - for dim, logical_ranks in logical_process_groups.items(): - process_groups[dim] = [] - for logical_rank in logical_ranks: - for g_rank, l_rank in self.convert_map.items(): - if l_rank == logical_rank: - process_groups[dim].append(g_rank) - return process_groups + # find the other ranks which are in the same process group as global_rank + ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) + + for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): + # skip duplicated process group creation + if ranks_in_same_group in duplicate_check_list: + continue + + # create the process group + pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend) + + # keep this process group in the process_groups_dict + for rank in ranks_in_same_group: + if rank not in self._process_group_dict: + self._process_group_dict[rank] = dict() + self._process_group_dict[rank][axis] = pg_handler + + # update the init flag + # we only allow init for once + self._is_initialized = True + + def _init_ranks_in_the_same_group(self): + """ + This method is used to initialize the ranks_in_the_same_group dictionary. + """ + # flatten the global ranks to 1D list + global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist() + + for global_rank in global_rank_flatten_list: + # find the other ranks which are in the same process group as global_rank + ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) + + for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): + # create dict for each rank + if global_rank not in self._process_group_dict: + self._ranks_in_the_process_group[global_rank] = dict() + + # keep this process group in the process_groups_dict + self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group + + def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]: + """ + Return the local rank of the given global rank in the logical device mesh. + + Args: + rank (int): the global rank in the logical device mesh. + axis (int): the axis of the logical device mesh. + """ + if self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + + local_ranks = self._global_to_local_rank_mapping[rank] + if axis: + return local_ranks[axis] + else: + return local_ranks + + def _collate_global_ranks_in_same_process_group(self, global_rank): + """ + Give a global rank and return all global ranks involved in its associated process group in each axis. + + Example: + + ```python + physical_mesh_id = torch.arange(0, 16) + mesh_shape = (4, 4) + + # logical mesh will look like + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + print(device_mesh.collate_global_ranks_in_same_process_group(0)) + + # key is axis name + # value is a list of global ranks in same axis with rank 0 + # output will look like + # { + 0: [0, 4, 8, 12], + 1: [0, 1, 2, 3] + # } + """ + # We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping + # for self._global_to_local_rank_mapping + # the key is the global rank + # the value is the list of local ranks corresponding to the global rank with respect of different axes + # we can see the list of local ranks as the process coordinates for simplicity + # the key and value are all unique, therefore, + # we can also to use the coordinates to find the global rank + + # ========================================================================= + # Step 1 + # find all the process_coordinates for processes in the same process group + # as the given global rank + # ========================================================================= + + # each + processes_in_the_same_process_group = {} + + for dim in range(self.logical_mesh_id.dim()): + # iterate over the dimension size so that we can include all processes + # in the same process group in the given axis + # the _local_rank refers to the local rank of the current process + for _local_rank in range(self.logical_mesh_id.shape[dim]): + # if this dimension is not initialized yet, + # initialize it with an empty array + if dim not in processes_in_the_same_process_group: + processes_in_the_same_process_group[dim] = [] + + # get the local rank corresponding to the global rank + process_coordinates = self._global_to_local_rank_mapping[global_rank].copy() + + # replace the local rank in the given dimension with the + # local rank of the current process iterated + process_coordinates[dim] = _local_rank + processes_in_the_same_process_group[dim].append(process_coordinates) + + # ================================================================= + # Step 2 + # Use local rank combination to find its corresponding global rank + # ================================================================= + # the key of the dict is the axis + # the value is the list of global ranks which are in the same process group as the given global rank + global_pg_ranks = {} + for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items(): + global_pg_ranks[dim] = [] + for process_coordinates in coordinates_of_all_processes: + # find the global rank by local rank combination + for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items(): + if process_coordinates == _process_coordinates: + global_pg_ranks[dim].append(_global_rank) + return global_pg_ranks + + def flatten(self): + """ + Flatten the logical mesh into an effective 1d logical mesh, + """ + if self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + + flatten_mesh_shape_size = len(self._mesh_shape) + flatten_mesh_shape = [self.num_devices] + return DeviceMesh( + self._physical_mesh_id, + tuple(flatten_mesh_shape), + mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), + mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), + init_process_group=self._init_process_group, + ) def all_gather_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + - 0.1) + return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1 def all_reduce_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes + - 0.01) + return ( + self.mesh_alpha[mesh_dim] + + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes + + 0.01 + ) def reduce_scatter_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + - 0.001) + return ( + self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001 + ) def all_to_all_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] penalty_factor = num_devices / 2.0 - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * - (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) - - -class FlattenDeviceMesh(DeviceMesh): - - def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None): - super().__init__(physical_mesh_id, - mesh_shape, - mesh_alpha, - mesh_beta, - init_process_group=False, - need_flatten=False) - # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars - self.mesh_alpha = max(self.mesh_alpha) - self.mesh_beta = min(self.mesh_beta) - # Different from original process_groups_dict, rank_list is not stored - self.process_number_dict = self.create_process_numbers_for_logical_mesh() - - def create_process_numbers_for_logical_mesh(self): - ''' - Build 1d DeviceMesh in column-major(0) and row-major(1) - for example: - mesh_shape = (2,4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7]] - # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} - ''' - num_devices = reduce(operator.mul, self.mesh_shape, 1) - process_numbers_dict = {} - process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist() - process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist() - return process_numbers_dict - - def mix_gather_cost(self, num_bytes): - num_devices = reduce(operator.mul, self.mesh_shape, 1) - return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1) + return ( + self.mesh_alpha[mesh_dim] + + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + + 0.001 + ) diff --git a/colossalai/engine/__init__.py b/colossalai/engine/__init__.py deleted file mode 100644 index 158796befb312755ed92f77f7828557f55800e4c..0000000000000000000000000000000000000000 --- a/colossalai/engine/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from ._base_engine import Engine -from .gradient_handler import * - -__all__ = ['Engine'] diff --git a/colossalai/engine/gradient_accumulation/__init__.py b/colossalai/engine/gradient_accumulation/__init__.py deleted file mode 100644 index 4cb6f4ad7384dda6136d98a0a73521e37d4027ba..0000000000000000000000000000000000000000 --- a/colossalai/engine/gradient_accumulation/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Iterable, List - -import torch.nn as nn -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler - -from colossalai.engine import BaseGradientHandler - -from ._gradient_accumulation import ( - GradAccumDataloader, - GradAccumGradientHandler, - GradAccumLrSchedulerByStep, - GradAccumOptimizer, -) - -__all__ = [ - 'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep', - 'GradAccumGradientHandler' -] - - -def accumulate_gradient(model: nn.Module, - optimizer: Optimizer, - dataloader: Iterable, - accumulate_size: int, - gradient_handlers: List[BaseGradientHandler] = None, - lr_scheduler: _LRScheduler = None): - r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation. - - Args: - model (:class:`torch.nn.Module`): your model object for gradient accumulation. - optimizer (:class:`torch.optim.Optimizer`): your optimizer object for gradient accumulation. - dataloader (:class:`torch.utils.data.DataLoader` or iterable objects): - your dataloader object, would be called like iter(dataloader) - accumulate_size (int): the number of steps to accumulate gradients - gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]): - list of gradient handler objects. Default is None. - lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`): - your ``lr_scheduler`` object for gradient accumulation. Defaults to None. - - More details about `gradient_handlers` could be found in - `Gradient_handler `_. - - More details about `lr_scheduler` could be found - `lr_scheduler `_. and - `how to adjust learning rate `_. - """ - optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model) - dataloader = GradAccumDataloader(dataloader, accumulate_size=accumulate_size) - - if gradient_handlers is not None: - gradient_handlers = [GradAccumGradientHandler(handler, accumulate_size) for handler in gradient_handlers] - - if lr_scheduler is not None: - lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size) - - return optimizer, dataloader, gradient_handlers, lr_scheduler diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/engine/gradient_handler/__init__.py deleted file mode 100644 index 2dea768bad7ecf1feed8bae69f733cda943509b5..0000000000000000000000000000000000000000 --- a/colossalai/engine/gradient_handler/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from ._base_gradient_handler import BaseGradientHandler -from ._data_parallel_gradient_handler import DataParallelGradientHandler -from ._moe_gradient_handler import MoeGradientHandler -from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler -from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler -from ._zero_gradient_handler import ZeROGradientHandler - -__all__ = [ - 'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler', - 'MoeGradientHandler', 'SequenceParallelGradientHandler' -] diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/engine/schedule/__init__.py deleted file mode 100644 index 0f2c039d7057324676d30938c6ec112279077b61..0000000000000000000000000000000000000000 --- a/colossalai/engine/schedule/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from ._base_schedule import BaseSchedule -from ._non_pipeline_schedule import NonPipelineSchedule -from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape - -__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape'] diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py deleted file mode 100644 index 38175fe0941c1c053bf91fe1df558ee9e763c360..0000000000000000000000000000000000000000 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ /dev/null @@ -1,833 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import inspect -from typing import Callable, List, Tuple, Union - -import torch.cuda - -import colossalai.communication as comm -from colossalai.amp.naive_amp import NaiveAMPModel -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.utils import switch_virtual_pipeline_parallel_rank -from colossalai.utils.cuda import get_current_device - -from ._base_schedule import BaseSchedule - - -def get_tensor_shape(): - if hasattr(gpc.config, 'TENSOR_SHAPE'): - return gpc.config.TENSOR_SHAPE - - if not gpc.is_initialized(ParallelMode.PIPELINE): - return None - - if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr( - gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'): - if gpc.is_initialized(ParallelMode.DATA): - dp_size = gpc.get_world_size(ParallelMode.DATA) - else: - dp_size = 1 - if gpc.is_initialized(ParallelMode.SEQUENCE): - seq_size = gpc.get_world_size(ParallelMode.SEQUENCE) - else: - seq_size = 1 - - tensor_shape = (gpc.config.SEQ_LENGTH // seq_size, - gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE) - return tensor_shape - else: - return None - - -def pack_return_tensors(return_tensors): - output, label = tuple(zip(*return_tensors)) - if isinstance(output[0], torch.Tensor): - output = torch.cat(output, dim=0) - elif isinstance(output[0], (list, tuple)): - output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) - else: - raise TypeError(f'Output of model must be tensor or list/tuple of tensors') - if isinstance(label[0], torch.Tensor): - label = torch.cat(label, dim=0) - else: - merged_label = {k: [] for k in label[0].keys()} - for d in label: - for k, v in d.items(): - merged_label[k].append(v) - label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()} - return output, label - - -class PipelineSchedule(BaseSchedule): - """A helper schedule class for pipeline parallelism running environment. - It uses non-interleaved 1F1B strategy. Other properties are similar as - :class:`NonPipelineSchedule`. - - Args: - num_microbatches (int): The number of microbatches. - data_process_func (Callable, optional): - The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. - tensor_shape (torch.Size, optional): Specified shape in pipeline communication. - scatter_gather_tensors (bool, optional): - If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. - - Example: - - # this shows an example of customized data_process_func - def data_process_func(stage_output, dataloader_output): - output1, output2 = stage_output - item1, item2, item3 = dataloader_output - - # assume item2 is not needed - data = (output1, output2, item1) - label = item3 - return data, label - - """ - - def __init__(self, - num_microbatches, - data_process_func: Callable = None, - tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, - scatter_gather_tensors: bool = False): - - # we need to make sure that the signature of the data_process_func is valid - if data_process_func: - sig = inspect.signature(data_process_func) - assert len(sig.parameters) == 2, \ - 'The data_process_func only takes in two parameters for NonPipelineSchedule, ' \ - 'which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \ - 'i.e. data_process_func(stage_output, dataloader_output).' - - super().__init__(data_process_func=data_process_func) - - assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}' - - self.num_microbatches = num_microbatches - self.dtype = torch.float - assert not isinstance(tensor_shape, - int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]." - if tensor_shape is None: - self.tensor_shape = tensor_shape - elif isinstance(tensor_shape, torch.Size): - self.tensor_shape = tensor_shape - else: - self.tensor_shape = torch.Size(tensor_shape) - self.scatter_gather_tensors = False - if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1: - self.scatter_gather_tensors = scatter_gather_tensors - self._logger = get_dist_logger() - - # cache for the batch data - self.batch_data = None - - def load_batch(self, data_iter): - # Pipeline schedule just puts data in memory - batch_data = super().load_batch(data_iter, to_gpu=False) - self.microbatch_offset = 0 - assert self.batch_size % self.num_microbatches == 0, \ - "Batch size should divided by the number of microbatches" - self.microbatch_size = self.batch_size // self.num_microbatches - self.batch_data = batch_data - - def _get_data_slice(self, data, offset): - if isinstance(data, torch.Tensor): - return data[offset:offset + self.microbatch_size] - elif isinstance(data, (list, tuple)): - data_dict = {} - for element in data: - if isinstance(element, dict): - data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()}) - elif data_dict: - data_dict['label'] = element[offset:offset + self.microbatch_size] - if data_dict: - return data_dict - return [val[offset:offset + self.microbatch_size] for val in data] - elif isinstance(data, dict): - return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()} - else: - raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") - - def load_micro_batch(self): - mciro_batch_data = self._get_data_slice(self.batch_data, self.microbatch_offset) - self.microbatch_offset += self.microbatch_size - return self._move_to_device(mciro_batch_data) - - def pre_processing(self, engine): - from colossalai.zero.legacy import ShardedModelV2 - - # TODO: remove this after testing new zero with pipeline parallelism - model = engine.model - if isinstance(model, NaiveAMPModel): - self.dtype = torch.half - model = model.model - if isinstance(model, ShardedModelV2): - self.dtype = torch.half - model = model.module - # sig = inspect.signature(model.forward) - # for p in sig.parameters.values(): - # assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' - - @staticmethod - def _call_engine(model, data): - if data is not None: - if isinstance(data, torch.Tensor): - return model(data) - elif isinstance(data, (list, tuple)): - return model(*data) - elif isinstance(data, dict): - stage_output = None - if 'stage_output' in data: - stage_output = data.pop('stage_output') - if stage_output is None: - return model(**data) - elif isinstance(stage_output, torch.Tensor): - return model(stage_output, **data) - elif isinstance(stage_output, (tuple, list)): - return model(*stage_output, **data) - else: - raise TypeError( - f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}" - ) - else: - raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") - - def _get_actual_forward_func(self, module): - if isinstance(module, NaiveAMPModel): - sig = inspect.signature(module.model.forward) - elif hasattr(module, 'colo_attr'): - sig = inspect.signature(module.module.forward) - else: - sig = inspect.signature(module.forward) - return sig - - def _get_data_label_for_current_step(self, stage_output, micro_batch_data, criterion, model): - if self.data_process_func: - # use customized function to get data and label - data, label = self.data_process_func(stage_output, micro_batch_data) - else: - if isinstance(micro_batch_data, (tuple, list)): - if gpc.is_first_rank(ParallelMode.PIPELINE): - # for the first stage, we use the data from the - # dataloader output by default - data, label = micro_batch_data - else: - # for non-first stage, we use the output passed - # by the previous as the model input - data = stage_output - _, label = micro_batch_data - elif isinstance(micro_batch_data, dict): - data = {} - data['stage_output'] = stage_output - if 'label' in micro_batch_data: - label = micro_batch_data.pop('label') - else: - label = None - load_data = micro_batch_data - data.update(load_data) - return data, label - - def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): - """Forward step for passed-in model. If it is the first stage, the input tensor - is obtained from data_iterator, otherwise the passed-in input_obj is used. - Returns output tensor. This is a helper function and can be ignored by users. - - Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. - input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. - return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. - return_output_label (bool, optional): Whether returns output labels. - accum_loss (optional): Where accumulated loss stores. - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. - """ - micro_batch_data = self.load_micro_batch() - - data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, engine.model) - - output_obj = self._call_engine(engine.model, data) - - if gpc.is_last_rank(ParallelMode.PIPELINE): - if return_output_label: - return_tensors.append((output_obj, label)) - if accum_loss is not None: - loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches - accum_loss.add_(loss_reduced.detach()) - return loss_reduced - else: - # forward only, it's useless since backward is not needed - return output_obj - else: - if isinstance(output_obj, torch.Tensor): - self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' - ) - return output_obj - - def _backward_step(self, engine, input_obj, output_obj, output_obj_grad): - """Backward step through the passed-in output tensor. If it is the last stage, the - output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. - Returns the gradients with respect to the input tensor (None if first stage). - This is a helper function and can be ignored by users. - - Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. - input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage. - output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage. - output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage. - - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor. - """ - - # Retain the grad on the input_obj. - if input_obj is not None: - if isinstance(input_obj, torch.Tensor): - input_obj.retain_grad() - else: - for in_tensor in input_obj: - if in_tensor is not None: - in_tensor.retain_grad() - # Backward pass. - if output_obj_grad is None: - engine.backward(output_obj) - else: - engine.backward_by_grad(output_obj, output_obj_grad) - - # Collect the grad of the input_obj. - input_obj_grad = None - if input_obj is not None: - if isinstance(input_obj, torch.Tensor): - input_obj_grad = input_obj.grad - else: - input_obj_grad = [] - for in_tensor in input_obj: - input_obj_grad.append(in_tensor.grad) - - return input_obj_grad - - def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): - """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. - Returns a tuple with losses if the last stage, an empty tuple otherwise. - - Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. - data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). - forward_only (bool, optional): - Whether run forward step only. Default is false. If true, no backward will be run. - return_loss (bool, optional): Whether returns the loss value. Default is true. - return_output_label (bool, optional): If False, the output and label won't be returned. - - Returns: - Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. - """ - - assert forward_only or return_loss, \ - 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' - self.load_batch(data_iter) - num_warmup_microbatches = \ - (gpc.get_world_size(ParallelMode.PIPELINE) - - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) - num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) - num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches - - # Input, output tensors only need to be saved when doing backward passes - input_objs = None - output_objs = None - if not forward_only: - input_objs = [] - output_objs = [] - return_tensors = [] - if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) - else: - accum_loss = None - # Used for tensor meta information communication - ft_shapes = self.tensor_shape - bt_shapes = None - fs_checker = self.tensor_shape is None - - # Run warmup forward passes. - for i in range(num_warmup_microbatches): - if not gpc.is_first_rank(ParallelMode.PIPELINE): - ft_shapes = comm.recv_obj_meta(ft_shapes) - input_obj = comm.recv_forward(ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) - if not gpc.is_last_rank(ParallelMode.PIPELINE): - if isinstance(output_obj, torch.Tensor): - bt_shapes = output_obj.shape - else: - bt_shapes = [] - for out_tensor in output_obj: - bt_shapes.append(out_tensor.shape) - fs_checker = comm.send_obj_meta(output_obj, fs_checker) - comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) - - if not forward_only: - input_objs.append(input_obj) - output_objs.append(output_obj) - - # Before running 1F1B, need to receive first forward tensor. - # If all microbatches are run in warmup / cooldown phase, then no need to - # receive this tensor here. - if num_microbatches_remaining > 0: - if not gpc.is_first_rank(ParallelMode.PIPELINE): - ft_shapes = comm.recv_obj_meta(ft_shapes) - input_obj = comm.recv_forward(ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - - # Run 1F1B in steady state. - for i in range(num_microbatches_remaining): - last_iteration = (i == (num_microbatches_remaining - 1)) - - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) - if forward_only: - comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) - - if not last_iteration: - input_obj = comm.recv_forward(ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - - else: - output_obj_grad = comm.send_forward_recv_backward(output_obj, - bt_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - - # Add input_obj and output_obj to end of list. - input_objs.append(input_obj) - output_objs.append(output_obj) - - # Pop output_obj and output_obj from the start of the list for - # the backward pass. - input_obj = input_objs.pop(0) - output_obj = output_objs.pop(0) - - input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) - - if last_iteration: - input_obj = None - comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) - else: - input_obj = comm.send_backward_recv_forward(input_obj_grad, - ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - - # Run cooldown backward passes. - if not forward_only: - for i in range(num_warmup_microbatches): - input_obj = input_objs.pop(0) - output_obj = output_objs.pop(0) - - output_obj_grad = comm.recv_backward(bt_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - - input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) - - comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) - - if len(return_tensors) > 0: - output, label = pack_return_tensors(return_tensors) - return output, label, accum_loss - else: - return None, None, accum_loss - - -class InterleavedPipelineSchedule(PipelineSchedule): - - def __init__(self, - num_microbatches: int, - num_model_chunks: int, - data_process_func: Callable = None, - tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, - scatter_gather_tensors: bool = False): - """A helper schedule class for pipeline parallelism running environment. - It uses interleaved 1F1B strategy. Other properties are similar as - :class:`NonPipelineSchedule`. - - Args: - num_microbatches (int): The number of microbatches. - num_model_chunks (int): The number of model chunks. - data_process_func (Callable, optional): - The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. - tensor_shape (torch.Size, optional): Specified shape in pipeline communication. - scatter_gather_tensors (bool, optional): - If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. - """ - assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \ - 'num_microbatches must be an integer multiple of pipeline parallel world size' - assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \ - f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}' - super().__init__(num_microbatches, - data_process_func=data_process_func, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather_tensors) - gpc.set_virtual_pipeline_parallel_size(num_model_chunks) - gpc.set_virtual_pipeline_parallel_rank(0) - self.num_model_chunks = num_model_chunks - - def pre_processing(self, engine): - from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 - if isinstance(engine.model, ShardedModelV2): - self.dtype = torch.half - elif isinstance(engine.model[0], NaiveAMPModel): - self.dtype = torch.half - for model in engine.model: - if isinstance(model, NaiveAMPModel): - model = model.model - sig = inspect.signature(model.forward) - for p in sig.parameters.values(): - assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' - - def load_batch(self, data_iter): - super().load_batch(data_iter) - # overwrite microbatch_offset, since model chunks load the same microbatch, and should tract the offset - self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] - - def load_micro_batch(self, model_chunk_id): - data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id]) - self.microbatch_offset[model_chunk_id] += self.microbatch_size - return self._move_to_device(data) - - def _forward_step(self, - engine, - model_chunk_id, - input_obj, - return_tensors, - return_output_label=True, - accum_loss=None): - """Forward step for passed-in model. If it is the first stage, the input tensor - is obtained from data_iterator, otherwise the passed-in input_obj is used. - Returns output tensor. This is a helper function and can be ignored by users. - - Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. - model_chunk_id (int): The id of model chunks. - input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. - return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. - return_output_label (bool, optional): Whether returns output labels. - accum_loss (optional): Where accumulated loss stores. - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. - """ - micro_batch_data = self.load_micro_batch(model_chunk_id) - data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, - engine.model[model_chunk_id]) - - output_obj = self._call_engine(engine.model[model_chunk_id], data) - - if gpc.is_pipeline_last_stage(): - if return_output_label: - return_tensors.append((output_obj, label)) - if accum_loss is not None: - loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches - accum_loss.add_(loss_reduced.detach()) - return loss_reduced - else: - # forward only, it's useless since backward is not needed - return output_obj - else: - if isinstance(output_obj, torch.Tensor): - self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' - ) - return output_obj - - def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): - """Run interleaved 1F1B schedule (model split into model chunks), with - communication between pipeline stages as needed. - - Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. - data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). - forward_only (bool, optional): - Whether run forward step only. Default is false. If true, no backward will be run. - return_loss (bool, optional): Whether returns the loss value. Default is true. - return_output_label (bool, optional): If False, the output and label won't be returned. - - Returns: - Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. - The loss would be returned only in the last stage. - """ - assert forward_only or return_loss, \ - 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' - self.load_batch(data_iter) - model = engine.model - input_objs = [[] for _ in range(len(model))] - output_objs = [[] for _ in range(len(model))] - return_tensors = [] - if not forward_only: - output_obj_grads = [[] for _ in range(len(model))] - if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) - else: - accum_loss = None - - # Used for obj meta information communication - input_obj_shapes = [self.tensor_shape for _ in range(len(model))] - output_obj_shapes = [None for _ in range(len(model))] - send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))] - - pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - # Compute number of warmup and remaining microbatches. - num_model_chunks = len(model) - num_microbatches = self.num_microbatches * num_model_chunks - all_warmup_microbatches = False - if forward_only: - num_warmup_microbatches = num_microbatches - else: - # Run all forward passes and then all backward passes if number of - # microbatches is just the number of pipeline stages. - # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on - # all workers, followed by more microbatches after depending on - # stage ID (more forward passes for earlier stages, later stages can - # immediately start with 1F1B). - if self.num_microbatches == pipeline_parallel_size: - num_warmup_microbatches = num_microbatches - all_warmup_microbatches = True - else: - num_warmup_microbatches = \ - (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 - num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining = \ - num_microbatches - num_warmup_microbatches - - def get_model_chunk_id(microbatch_id, forward): - """Helper method to get the model chunk ID given the iteration number.""" - microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) - model_chunk_id = microbatch_id_in_group // pipeline_parallel_size - if not forward: - model_chunk_id = (num_model_chunks - model_chunk_id - 1) - return model_chunk_id - - def _forward_step_helper(microbatch_id): - """Helper method to run forward step with model split into chunks - (run set_virtual_pipeline_model_parallel_rank() before calling - forward_step()).""" - model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) - gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) - - # forward step - if gpc.is_pipeline_first_stage(): - if len(input_objs[model_chunk_id]) == \ - len(output_objs[model_chunk_id]): - input_objs[model_chunk_id].append(None) - input_obj = input_objs[model_chunk_id][-1] - output_obj = self._forward_step(engine, - model_chunk_id, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) - output_objs[model_chunk_id].append(output_obj) - - # if forward-only, no need to save tensors for a backward pass - if forward_only: - input_objs[model_chunk_id].pop() - output_objs[model_chunk_id].pop() - - return output_obj - - def _backward_step_helper(microbatch_id): - """Helper method to run backward step with model split into chunks - (run set_virtual_pipeline_model_parallel_rank() before calling - backward_step()).""" - model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) - gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) - - if gpc.is_pipeline_last_stage(): - if len(output_obj_grads[model_chunk_id]) == 0: - output_obj_grads[model_chunk_id].append(None) - input_obj = input_objs[model_chunk_id].pop(0) - output_obj = output_objs[model_chunk_id].pop(0) - output_obj_grad = output_obj_grads[model_chunk_id].pop(0) - input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) - - return input_obj_grad - - # Run warmup forward passes. - gpc.set_virtual_pipeline_parallel_rank(0) - if not gpc.is_pipeline_first_stage(): - input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0]) - input_objs[0].append( - comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors)) - - for k in range(num_warmup_microbatches): - model_chunk_id = get_model_chunk_id(k, forward=True) - output_obj = _forward_step_helper(k) - if not gpc.is_pipeline_last_stage(): - if isinstance(output_obj, torch.Tensor): - output_obj_shapes[model_chunk_id] = output_obj.shape - else: - output_obj_shapes[model_chunk_id] = [] - for out_tensor in output_obj: - output_obj_shapes[model_chunk_id].append(out_tensor.shape) - send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj, - send_tensor_shape_flags[model_chunk_id]) - # Determine if tensor should be received from previous stage. - next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) - recv_prev = True - if gpc.is_pipeline_first_stage(ignore_virtual=True): - if next_forward_model_chunk_id == 0: - recv_prev = False - if k == (num_microbatches - 1): - recv_prev = False - - # Don't send tensor downstream if on last stage. - if gpc.is_pipeline_last_stage(): - output_obj = None - - with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id): - if not gpc.is_pipeline_first_stage(): - input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta( - input_obj_shapes[next_forward_model_chunk_id]) - # Send and receive tensors as appropriate (send tensors computed - # in this iteration; receive tensors for next iteration). - input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None - if k == (num_warmup_microbatches - 1) and not forward_only and \ - not all_warmup_microbatches: - input_obj_grad = None - recv_next = True - if gpc.is_pipeline_last_stage(ignore_virtual=True): - recv_next = False - output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None - input_obj, output_obj_grad = \ - comm.send_forward_backward_recv_forward_backward( - output_obj, input_obj_grad, - input_shape, - output_shape, - recv_prev=recv_prev, recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - output_obj_grads[num_model_chunks - 1].append(output_obj_grad) - else: - input_obj = \ - comm.send_forward_recv_forward( - output_obj, - input_shape, - recv_prev=recv_prev, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - input_objs[next_forward_model_chunk_id].append(input_obj) - - # Run 1F1B in steady state. - for k in range(num_microbatches_remaining): - # Forward pass. - forward_k = k + num_warmup_microbatches - output_obj = _forward_step_helper(forward_k) - - # Backward pass. - backward_k = k - input_obj_grad = _backward_step_helper(backward_k) - - # Send output_obj and input_obj_grad, receive input_obj - # and output_obj_grad. - - # Determine if current stage has anything to send in either direction, - # otherwise set obj to None. - forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) - gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id) - if gpc.is_pipeline_last_stage(): - output_obj = None - - backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) - gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id) - if gpc.is_pipeline_first_stage(): - input_obj_grad = None - - # Determine if peers are sending, and where in data structure to put - # received tensors. - recv_prev = True - if gpc.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) - - recv_next = True - if gpc.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). - next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1), - forward=False) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) - - # If last iteration, don't receive; we already received one extra - # before the start of the for loop. - if k == (num_microbatches_remaining - 1): - recv_prev = False - - input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None - output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None - # Communicate objs. - input_obj, output_obj_grad = \ - comm.send_forward_backward_recv_forward_backward( - output_obj, input_obj_grad, - input_shape, - output_shape, - recv_prev=recv_prev, recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - - # Put input_obj and output_obj_grad in data structures in the - # right location. - if recv_prev: - input_objs[next_forward_model_chunk_id].append(input_obj) - if recv_next: - output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad) - - # Run cooldown backward passes (flush out pipeline). - if not forward_only: - if all_warmup_microbatches: - output_obj_grads[num_model_chunks - 1].append( - comm.recv_backward(output_obj_shapes[num_model_chunks - 1], - scatter_gather_tensors=self.scatter_gather_tensors)) - for k in range(num_microbatches_remaining, num_microbatches): - input_obj_grad = _backward_step_helper(k) - next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) - recv_next = True - if gpc.is_pipeline_last_stage(ignore_virtual=True): - if next_backward_model_chunk_id == (num_model_chunks - 1): - recv_next = False - if k == (num_microbatches - 1): - recv_next = False - output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None - output_obj_grads[next_backward_model_chunk_id].append( - comm.send_backward_recv_backward(input_obj_grad, - output_shape, - recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors)) - - if len(return_tensors) > 0: - output, label = pack_return_tensors(return_tensors) - return output, label, accum_loss - else: - return None, None, accum_loss diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py index 0444a481627356e202912c491495364134aa65dd..4d40d5badfd0e4d6dd9b1083f21fb5e584c7d78e 100644 --- a/colossalai/fx/_compatibility.py +++ b/colossalai/fx/_compatibility.py @@ -2,16 +2,14 @@ from typing import Callable import torch -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) if TORCH_MAJOR == 1 and TORCH_MINOR < 12: META_COMPATIBILITY = False elif TORCH_MAJOR == 1 and TORCH_MINOR == 12: - from . import _meta_regist_12 META_COMPATIBILITY = True elif TORCH_MAJOR == 1 and TORCH_MINOR == 13: - from . import _meta_regist_13 META_COMPATIBILITY = True elif TORCH_MAJOR == 2: META_COMPATIBILITY = True @@ -36,7 +34,7 @@ def compatibility(is_backward_compatible: bool = False) -> Callable: else: def wrapper(*args, **kwargs): - raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}') + raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}") return wrapper diff --git a/colossalai/fx/_meta_regist_12.py b/colossalai/fx/_meta_regist_12.py index 52e8d63ae54355a0d3e27dfc6a3347a2304dc0d7..63f88682e85a77b8ec277fdc099c1689a2b7160d 100644 --- a/colossalai/fx/_meta_regist_12.py +++ b/colossalai/fx/_meta_regist_12.py @@ -3,7 +3,7 @@ # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml # for more meta_registrations -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from torch.utils._pytree import tree_map @@ -16,13 +16,11 @@ meta_table = {} def register_meta(op, register_dispatcher=True): - def wrapper(f): - def add_func(op): meta_table[op] = f if register_dispatcher: - name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) + name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__ try: meta_lib.impl(name, f) except: @@ -48,7 +46,6 @@ def meta_conv( output_padding: List[int], groups: int, ): - def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ Formula to apply to calculate the length of some dimension of the output @@ -125,7 +122,8 @@ def meta_conv( kernel_size[i], stride[i], output_padding_list[i], - )) + ) + ) else: ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) return ret_shape @@ -159,22 +157,42 @@ def meta_conv( shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) mem_fmt = pick_memory_format() - out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] return out @register_meta(aten._convolution.default) -def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], - padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, - *extra_args): +def meta_conv_1( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, + *extra_args, +): out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) return out @register_meta(aten.convolution_backward.default) -def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, - padding, dilation, transposed, output_padding, groups, output_mask): - return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') +def meta_conv_backward( + grad_output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, +): + return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device="meta") # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -208,7 +226,6 @@ def meta_cuda_rnn( batch_sizes, dropout_state, ): - is_input_packed = len(batch_sizes) != 0 if is_input_packed: seq_length = len(batch_sizes) @@ -224,8 +241,11 @@ def meta_cuda_rnn( if is_input_packed: out_shape = [batch_sizes_sum, out_size * num_directions] else: - out_shape = ([mini_batch, seq_length, out_size * - num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) + out_shape = ( + [mini_batch, seq_length, out_size * num_directions] + if batch_first + else [seq_length, mini_batch, out_size * num_directions] + ) output = input.new_empty(out_shape) cell_shape = [num_layers * num_directions, mini_batch, hidden_size] @@ -242,18 +262,20 @@ def meta_cuda_rnn( # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp @register_meta(aten._cudnn_rnn_backward.default) -def meta_cudnn_rnn_backward(input: torch.Tensor, - weight: torch.Tensor, - weight_stride0: int, - hx: torch.Tensor, - cx: Optional[torch.Tensor] = None, - *args, - **kwargs): +def meta_cudnn_rnn_backward( + input: torch.Tensor, + weight: torch.Tensor, + weight_stride0: int, + hx: torch.Tensor, + cx: Optional[torch.Tensor] = None, + *args, + **kwargs, +): print(input, weight, hx, cx) grad_input = torch.empty_like(input) grad_weight = torch.empty_like(weight) grad_hx = torch.empty_like(hx) - grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta') + grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device="meta") return grad_input, grad_weight, grad_hx, grad_cx @@ -298,15 +320,25 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini n_input = input.size(1) output = torch.empty_like(input) - running_mean = torch.empty((n_input), device='meta') - running_var = torch.empty((n_input), device='meta') + running_mean = torch.empty((n_input), device="meta") + running_var = torch.empty((n_input), device="meta") return output, running_mean, running_var # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.native_batch_norm_backward.default) -def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean, - save_invstd, train, eps, output_mask): +def meta_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, +): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(weight) @@ -319,9 +351,9 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, n_input = input.size(1) output = torch.empty_like(input) - running_mean = torch.empty((n_input), device='meta') - running_var = torch.empty((n_input), device='meta') - reserve = torch.empty((0), dtype=torch.uint8, device='meta') + running_mean = torch.empty((n_input), device="meta") + running_var = torch.empty((n_input), device="meta") + reserve = torch.empty((0), dtype=torch.uint8, device="meta") return output, running_mean, running_var, reserve @@ -330,8 +362,17 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, # in training mode (evaluation mode batchnorm has a different algorithm), # which is why this doesn't accept a 'training' parameter. @register_meta(aten.cudnn_batch_norm_backward.default) -def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, - save_mean, save_invstd, eps, reserve): +def meta_cudnn_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + eps, + reserve, +): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(weight) @@ -345,15 +386,16 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): n_input = input.size(1) output = torch.empty_like(input) - running_mean = torch.empty((bs, n_input, 1), device='meta') - running_var = torch.empty((bs, n_input, 1), device='meta') + running_mean = torch.empty((bs, n_input, 1), device="meta") + running_var = torch.empty((bs, n_input, 1), device="meta") return output, running_mean, running_var # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm_backward.default) -def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, - grad_input_mask): +def meta_ln_backward( + dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask +): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(bias) @@ -397,16 +439,19 @@ def meta_index_Tensor(self, indices): result: List[Optional[torch.Tensor]] = [] for i, index in enumerate(indices): if index is not None: - assert index.dtype in [torch.long, torch.int8, torch.bool],\ - "tensors used as indices must be long, byte or bool tensors" + assert index.dtype in [ + torch.long, + torch.int8, + torch.bool, + ], "tensors used as indices must be long, byte or bool tensors" if index.dtype in [torch.int8, torch.bool]: nonzero = index.nonzero() k = len(result) assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" for j in range(index.ndim): - assert index.shape[j] == self.shape[ - k + - j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" + assert ( + index.shape[j] == self.shape[k + j] + ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" result.append(nonzero.select(1, j)) else: result.append(index) @@ -482,12 +527,15 @@ def meta_index_Tensor(self, indices): # ============================== Embedding ========================================= # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp @register_meta(aten.embedding_dense_backward.default) -def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, - scale_grad_by_freq): - return torch.empty((num_weights, grad_output.size(-1)), - dtype=grad_output.dtype, - device=grad_output.device, - layout=grad_output.layout) +def meta_embedding_dense_backward( + grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq +): + return torch.empty( + (num_weights, grad_output.size(-1)), + dtype=grad_output.dtype, + device=grad_output.device, + layout=grad_output.layout, + ) # ============================== Dropout =========================================== diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 5a72cb9ca923bcdf5f9b3daddad9b5e7c339d5c5..dfb5754d71c17e17386f265526cb4ff02f19d173 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Tuple import torch @@ -18,6 +18,7 @@ try: magic_methods, ) from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + CODEGEN_AVAILABLE = True except: from torch.fx.graph import ( @@ -32,12 +33,13 @@ except: magic_methods, ) from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + CODEGEN_AVAILABLE = False if CODEGEN_AVAILABLE: - __all__ = ['ActivationCheckpointCodeGen'] + __all__ = ["ActivationCheckpointCodeGen"] else: - __all__ = ['python_code_with_activation_checkpoint'] + __all__ = ["python_code_with_activation_checkpoint"] def _gen_saved_tensors_hooks(): @@ -125,15 +127,14 @@ def _find_ckpt_regions(nodes: List[Node]): Find the checkpoint regions given a list of consecutive nodes. The outputs will be list of tuples, each tuple is in the form of (start_index, end_index). """ - ckpt_nodes = [] ckpt_regions = [] start = -1 end = -1 current_region = None for idx, node in enumerate(nodes): - if 'activation_checkpoint' in node.meta: - act_ckpt_label = node.meta['activation_checkpoint'] + if "activation_checkpoint" in node.meta: + act_ckpt_label = node.meta["activation_checkpoint"] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -150,7 +151,7 @@ def _find_ckpt_regions(nodes: List[Node]): current_region = act_ckpt_label start = idx end = -1 - elif current_region is not None and not 'activation_checkpoint' in node.meta: + elif current_region is not None and not "activation_checkpoint" in node.meta: # used to check the case below # node ckpt states = [ckpt, ckpt, non-ckpt] end = idx - 1 @@ -178,8 +179,8 @@ def _find_offload_regions(nodes: List[Node]): current_region = None for idx, node in enumerate(nodes): - if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable): - act_offload_label = node.meta['activation_offload'] + if "activation_offload" in node.meta and isinstance(node.meta["activation_offload"], Iterable): + act_offload_label = node.meta["activation_offload"] if current_region == None: current_region = act_offload_label @@ -226,9 +227,9 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen """ Generate the checkpoint function call code text """ - outputs = ', '.join(output_vars) - inputs = ', '.join(input_vars) - return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' + outputs = ", ".join(output_vars) + inputs = ", ".join(input_vars) + return f"{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})" def _end_of_ckpt(node: Node, check_idx: int) -> bool: @@ -240,9 +241,9 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool: Returns: bool """ - if 'activation_checkpoint' in node.meta: - if isinstance(node.meta['activation_checkpoint'], list): - return node.meta['activation_checkpoint'][check_idx] == None + if "activation_checkpoint" in node.meta: + if isinstance(node.meta["activation_checkpoint"], list): + return node.meta["activation_checkpoint"][check_idx] == None else: return False else: @@ -260,11 +261,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0): current_region = None for idx, node in enumerate(nodes): - if 'activation_checkpoint' in node.meta: - if isinstance(node.meta['activation_checkpoint'], int): - act_ckpt_label = node.meta['activation_checkpoint'] + if "activation_checkpoint" in node.meta: + if isinstance(node.meta["activation_checkpoint"], int): + act_ckpt_label = node.meta["activation_checkpoint"] else: - act_ckpt_label = node.meta['activation_checkpoint'][check_idx] + act_ckpt_label = node.meta["activation_checkpoint"][check_idx] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -298,13 +299,9 @@ def _find_nested_ckpt_regions(nodes, check_idx=0): return ckpt_regions -def emit_ckpt_func(body, - ckpt_func, - node_list: List[Node], - emit_node_func, - delete_unused_value_func, - level=0, - in_ckpt=False): +def emit_ckpt_func( + body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, level=0, in_ckpt=False +): """Emit ckpt function in nested way Args: body: forward code, in recursive calls, this part will be checkpoint @@ -321,17 +318,17 @@ def emit_ckpt_func(body, inputs, outputs = _find_input_and_output_nodes(node_list) # if the current checkpoint function use int as label, using old generation method - if isinstance(node_list[0].meta['activation_checkpoint'], int): - label = node_list[0].meta['activation_checkpoint'] + if isinstance(node_list[0].meta["activation_checkpoint"], int): + label = node_list[0].meta["activation_checkpoint"] ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") for node in node_list: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = node_list[0].meta.get('activation_offload', False) + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") + activation_offload = node_list[0].meta.get("activation_offload", False) usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) usage += "\n" body.append(usage) @@ -340,12 +337,12 @@ def emit_ckpt_func(body, else: # label given by each layer, e.g. if you are currently at level [0, 1, 1] # the label will be '0_1_1' - label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]]) + label = "_".join([str(idx) for idx in node_list[0].meta["activation_checkpoint"][: level + 1]]) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") # if there is more level to fetch - if level + 1 < len(node_list[0].meta['activation_checkpoint']): + if level + 1 < len(node_list[0].meta["activation_checkpoint"]): ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1) start_idx = [item[0] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions] @@ -358,38 +355,45 @@ def emit_ckpt_func(body, break if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] - emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, - delete_unused_value_func, level + 1, True) + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func( + ckpt_func, + ckpt_func_buffer, + ckpt_node_list, + emit_node_func, + delete_unused_value_func, + level + 1, + True, + ) node_idx += len(ckpt_node_list) else: node = node_list[node_idx] emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) node_idx += 1 - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") ckpt_func += ckpt_func_buffer - activation_offload = node_list[0].meta.get('activation_offload', False) - usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + activation_offload = node_list[0].meta.get("activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n" if in_ckpt: - usage = ' ' + usage + usage = " " + usage body.append(usage) # last level else: for node in node_list: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = node_list[0].meta.get('activation_offload', False) - usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") + activation_offload = node_list[0].meta.get("activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n" if in_ckpt: - usage = ' ' + usage + usage = " " + usage body.append(usage) @@ -420,7 +424,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod # find the input and output var names for each offload region for idx, (start, end) in enumerate(offload_regions): - offload_node_list = node_list[start:end + 1] + offload_node_list = node_list[start : end + 1] inputs, outputs = _find_input_and_output_nodes(offload_node_list) offload_inputs.append(inputs) offload_outputs.append(outputs) @@ -436,7 +440,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod # process ckpt_regions if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) node_idx += len(ckpt_node_list) @@ -470,7 +474,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod if within_offload_region: emit_node_func(node, body) - body[-1] = ' ' + body[-1] + body[-1] = " " + body[-1] delete_unused_value_func(node, body) else: @@ -508,14 +512,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # find the input and output var names for each region for idx, (start, end) in enumerate(ckpt_regions): - ckpt_node_list = node_list[start:end + 1] + ckpt_node_list = node_list[start : end + 1] inputs, outputs = _find_input_and_output_nodes(ckpt_node_list) input_vars.append(inputs) output_vars.append(outputs) # find the input and output var names for each offload region for idx, (start, end) in enumerate(offload_regions): - offload_node_list = node_list[start:end + 1] + offload_node_list = node_list[start : end + 1] inputs, outputs = _find_input_and_output_nodes(offload_node_list) offload_inputs.append(inputs) offload_outputs.append(outputs) @@ -523,11 +527,11 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # append code text to body for idx, node in enumerate(node_list): # if this is the first node of the ckpt region - # append the ckpt function defition + # append the ckpt function definition if idx in start_idx: label = start_idx.index(idx) ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label]) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") within_ckpt_region = True if idx in offload_starts: @@ -559,12 +563,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # NOTE: currently we separate body and ckpt_func definition if within_ckpt_region: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) elif within_offload_region: emit_node_func(node, body) - body[-1] = ' ' + body[-1] + body[-1] = " " + body[-1] delete_unused_value_func(node, body) else: @@ -576,13 +580,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # generate return statement label = end_idx.index(idx) return_statement = _gen_ckpt_output(output_vars[label]) - return_statement = f' {return_statement}\n\n' + return_statement = f" {return_statement}\n\n" ckpt_func.append(return_statement) # we need to check if the checkpoint need to offload the input start_node_idx = start_idx[label] - if 'activation_offload' in node_list[start_node_idx].meta: - activation_offload = node_list[start_node_idx].meta['activation_offload'] + if "activation_offload" in node_list[start_node_idx].meta: + activation_offload = node_list[start_node_idx].meta["activation_offload"] else: activation_offload = False @@ -594,8 +598,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, if input_node.op != "placeholder": non_leaf_input = 1 for user in input_node.users: - if 'activation_checkpoint' in user.meta: - if user.meta['activation_checkpoint'] == label: + if "activation_checkpoint" in user.meta: + if user.meta["activation_checkpoint"] == label: if user.op == "call_module": if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"): use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace @@ -610,7 +614,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # generate checkpoint function call in a new line usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant) - usage += '\n' + usage += "\n" body.append(usage) within_ckpt_region = False @@ -621,7 +625,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, if CODEGEN_AVAILABLE: class ActivationCheckpointCodeGen(CodeGen): - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] @@ -629,7 +632,7 @@ if CODEGEN_AVAILABLE: wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] + maybe_return_annotation: List[str] = [""] def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -637,7 +640,7 @@ if CODEGEN_AVAILABLE: Graph, like functions or types. Returns: the global name that should be used to reference 'obj' in generated source. """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -662,16 +665,16 @@ if CODEGEN_AVAILABLE: def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): # This is a generic type, e.g. typing.List[torch.Tensor] origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) - if hasattr(o, '__args__'): + if hasattr(o, "__args__"): # Assign global names for each of the inner type variables. args = [type_repr(arg) for arg in o.__args__] @@ -690,19 +693,18 @@ if CODEGEN_AVAILABLE: return add_global(typename, o) def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, '_fields'): + if isinstance(arg, tuple) and hasattr(arg, "_fields"): qualified_name = _get_qualified_name(type(arg)) global_name = add_global(qualified_name, type(arg)) return f"{global_name}{repr(tuple(arg))}" return repr(arg) - args_s = ', '.join(_get_repr(a) for a in args) - kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) if args_s and kwargs_s: - return f'{args_s}, {kwargs_s}' + return f"{args_s}, {kwargs_s}" return args_s or kwargs_s # Run through reverse nodes and record the first instance of a use @@ -728,90 +730,101 @@ if CODEGEN_AVAILABLE: not used in the remainder of the code are freed and the memory usage of the code is optimal. """ - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) + body.append(f"; {to_delete_str}\n") else: - body.append('\n') + body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert callable(node.target) # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: - body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' - f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') + if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods: + body.append( + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" body.append(self.generate_output(node.args[0])) return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") # Modified for activation checkpointing ckpt_func = [] # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes): + if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in nodes): emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) else: emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) @@ -820,13 +833,13 @@ if CODEGEN_AVAILABLE: # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') + body.append("pass\n") if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: - wrap_stmts = '' + wrap_stmts = "" if self._body_transformer: body = self._body_transformer(body) @@ -837,11 +850,11 @@ if CODEGEN_AVAILABLE: # as we need colossalai.utils.checkpoint, we need to import colossalai # in forward function prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = ''.join(ckpt_func) + prologue + prologue = "".join(ckpt_func) + prologue prologue = prologue - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) fn_code = f""" {wrap_stmts} {prologue} @@ -861,7 +874,7 @@ else: wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] + maybe_return_annotation: List[str] = [""] def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -869,7 +882,7 @@ else: Graph, like functions or types. Returns: the global name that should be used to reference 'obj' in generated source. """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -894,12 +907,12 @@ else: def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) # This is a generic type, e.g. typing.List[torch.Tensor] - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) @@ -934,84 +947,94 @@ else: not used in the remainder of the code are freed and the memory usage of the code is optimal. """ - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) + body.append(f"; {to_delete_str}\n") else: - body.append('\n') + body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert callable(node.target) # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" if self._pytree_info is None: - body.append(f'return {repr(node.args[0])}') + body.append(f"return {repr(node.args[0])}") else: - body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)') + body.append(f"return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)") return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") # Modified for activation checkpointing ckpt_func = [] # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes): + if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in self.nodes): emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) else: emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) @@ -1020,33 +1043,34 @@ else: # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') + body.append("pass\n") if self._pytree_info is not None: orig_args = self._pytree_info.orig_args - has_orig_self = (orig_args[0] == 'self') + has_orig_self = orig_args[0] == "self" if has_orig_self: - free_vars.insert(0, 'self') - if len(free_vars) > 0: # pytree has placeholders in it + free_vars.insert(0, "self") + if len(free_vars) > 0: # pytree has placeholders in it body.insert( 0, - f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n") + f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n", + ) else: orig_args = free_vars if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: - wrap_stmts = '' + wrap_stmts = "" - ckpt_func = ''.join(ckpt_func) + ckpt_func = "".join(ckpt_func) # If the original function didn't have self as its first argument, we # would have added it. - if len(orig_args) == 0 or orig_args[0] != 'self': - orig_args.insert(0, 'self') - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) + if len(orig_args) == 0 or orig_args[0] != "self": + orig_args.insert(0, "self") + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) # as we need colossalai.utils.checkpoint, we need to import colossalai # in forward function diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py index ebb9975f27dbf312870fdb9d7beea98faf7889b7..8429a9607f7a7aca4d15f89a2a6207dbd94083e4 100644 --- a/colossalai/fx/graph_module.py +++ b/colossalai/fx/graph_module.py @@ -1,32 +1,35 @@ import os import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Type, Union +from typing import Any, Dict, Optional, Union import torch import torch.nn as nn from torch.nn.modules.module import _addindent try: - from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen - from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall + from torch.fx.graph import Graph, PythonCode, _PyTreeCodeGen + from torch.fx.graph_module import GraphModule, _exec_with_source, _forward_from_src, _WrappedCall from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen + COLOGM = True except: from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule + COLOGM = False if COLOGM: class ColoGraphModule(GraphModule): - - def __init__(self, - root: Union[torch.nn.Module, Dict[str, Any]], - graph: Graph, - class_name: str = 'GraphModule', - ckpt_codegen: bool = True): + def __init__( + self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: Graph, + class_name: str = "GraphModule", + ckpt_codegen: bool = True, + ): if ckpt_codegen: graph.set_codegen(ActivationCheckpointCodeGen()) super().__init__(root, graph, class_name) @@ -60,7 +63,7 @@ if COLOGM: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module='self') + python_code = self._graph.python_code(root_module="self") self._code = python_code.src # To split ckpt functions code and forward code @@ -83,8 +86,8 @@ if COLOGM: # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. cls_call = cls.__call__ if "__call__" in vars(cls) else None - if '_wrapped_call' not in vars(cls): - cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + if "_wrapped_call" not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] def call_wrapped(self, *args, **kwargs): return self._wrapped_call(self, *args, **kwargs) @@ -108,7 +111,7 @@ if COLOGM: """ folder = Path(folder) Path(folder).mkdir(exist_ok=True) - torch.save(self.state_dict(), folder / 'state_dict.pt') + torch.save(self.state_dict(), folder / "state_dict.pt") tab = " " * 4 # we add import colossalai here @@ -125,7 +128,13 @@ class {module_name}(torch.nn.Module): def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: safe_reprs = [ - nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, ] if type(module) in safe_reprs: return f"{module.__repr__()}" @@ -136,10 +145,10 @@ class {module_name}(torch.nn.Module): for module_name, module in self.named_children(): module_str = _gen_model_repr(module_name, module) if module_str is None: - module_file = folder / f'{module_name}.pt' + module_file = folder / f"{module_name}.pt" torch.save(module, module_file) blobified_modules.append(module_name) - module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") module_str = f"torch.load(r'{module_file}') # {module_repr}" model_str += f"{tab*2}self.{module_name} = {module_str}\n" @@ -156,19 +165,20 @@ class {module_name}(torch.nn.Module): model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" model_str += f"{_addindent(self.code, 4)}\n" - module_file = folder / 'module.py' + module_file = folder / "module.py" module_file.write_text(model_str) - init_file = folder / '__init__.py' - init_file.write_text('from .module import *') + init_file = folder / "__init__.py" + init_file.write_text("from .module import *") if len(blobified_modules) > 0: - warnings.warn("Was not able to save the following children modules as reprs -" - f"saved as pickled files instead: {blobified_modules}") + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) else: class ColoGraphModule(GraphModule): - - def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = "GraphModule"): super().__init__(root, graph, class_name) diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py index 2c7b842b530cc12fb669154456dc51e489ccc85d..99c8faaa0cc6927b3aaf7649b3048a071bbeb2dd 100644 --- a/colossalai/fx/passes/adding_split_node_pass.py +++ b/colossalai/fx/passes/adding_split_node_pass.py @@ -1,8 +1,6 @@ import numpy as np import torch import tqdm -from torch.fx import symbolic_trace -from torch.fx.node import Node from colossalai.fx.passes.split_module import split_module @@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01): accumulate_bwd_flop = 0 block_nodes = [] for node in gm.graph.nodes: - if 'block_split' in node.name: + if "block_split" in node.name: continue accumulate_fwd_flop += node.fwd_flop accumulate_bwd_flop += node.bwd_flop if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop: with gm.graph.inserting_after(node): - block_node = gm.graph.create_node('call_function', block_split) - setattr(block_node, 'fwd_flop', accumulate_fwd_flop) - setattr(block_node, 'bwd_flop', accumulate_bwd_flop) + block_node = gm.graph.create_node("call_function", block_split) + setattr(block_node, "fwd_flop", accumulate_fwd_flop) + setattr(block_node, "bwd_flop", accumulate_bwd_flop) accumulate_fwd_flop = 0 accumulate_bwd_flop = 0 block_nodes.append(block_node) @@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01): def remove_blocks(gm: torch.fx.GraphModule): for node in gm.graph.nodes: - if (node.op, node.target) == ('call_function', block_split): + if (node.op, node.target) == ("call_function", block_split): gm.graph.erase_node(node) @@ -55,8 +53,8 @@ def get_compute_costs(node_list): num_nodes = len(node_list) all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64) - for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0): - for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False): + for start in tqdm.tqdm(range(num_nodes), desc="start pos", position=0): + for end in tqdm.tqdm(range(start, num_nodes), desc="end pos", position=1, leave=False): selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)] all_compute_cost[start, end] = sum(selected_flops) @@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost # record start node index for next stage in this partition f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32) f[0, num_nodes] = 0 - for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks - for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False): - for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False): + for s in tqdm.tqdm( + range(1, num_stages + 1), desc="stage", position=2, leave=False + ): # pylint: disable=too-many-nested-blocks + for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc="start node", position=3, leave=False): + for k in tqdm.tqdm(range(num_nodes, i, -1), desc="mid node", position=4, leave=False): stage_cost = compute_costs[i, k - 1] new_cost = f[s - 1, k] + stage_cost - if (stage_cost <= max_compute_cost and new_cost < f[s, i]): + if stage_cost <= max_compute_cost and new_cost < f[s, i]: f[s, i] = new_cost f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost) f_argmin[s, i] = k @@ -113,7 +113,7 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche best_cost = np.inf best_solution = None last_max_compute_cost = 0.0 - gap = 1e6 # temporary magic number, unit: flops + gap = 1e6 # temporary magic number, unit: flops for max_compute_cost in tqdm.tqdm(max_compute_costs): # Pruning to reduce search space. @@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche if max_compute_cost - last_max_compute_cost < gap: continue - cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs, - max_compute_cost) + cost, solution = do_dp_split_gpipe_impl( + len(node_list), num_stages, num_microbatches, compute_costs, max_compute_cost + ) if cost < best_cost: best_cost = cost @@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche # split_mode: # 'node': fx_node # 'block': many fx_nodes construct a block -def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01): - assert mode in ['node', 'block'] +def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode="block", block_limit=0.01): + assert mode in ["node", "block"] # nodes or blocks will be used in partition. node_list = [] - if mode == 'node': + if mode == "node": for node in gm.graph.nodes: node_list.append(node) - elif mode == 'block': + elif mode == "block": node_list = construct_blocks(gm, limit=block_limit) else: pass @@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches) - for (_, next_start_node) in best_solution: + for _, next_start_node in best_solution: if pp_size <= 1: break node = node_list[next_start_node] with gm.graph.inserting_before(node): - split_node = gm.graph.create_node('call_function', pipe_split) + split_node = gm.graph.create_node("call_function", pipe_split) pp_size -= 1 # remove block node if possible - if mode == 'block': + if mode == "block": remove_blocks(gm) gm.recompile() @@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): # To use avgcompute_split_pass, we need run meta_info_prop interpreter first. # If nodes don't have meta info, this pass will fall back to normal balanced split pass. check_node = list(mod_graph.nodes)[0] - if 'tensor_meta' not in check_node.meta: + if "tensor_meta" not in check_node.meta: return balanced_split_pass(gm, pp_size) total_fwd_flop = 0 @@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): for node in mod_graph.nodes: if pp_size <= 1: break - if 'pipe_split' in node.name: + if "pipe_split" in node.name: continue accumulate_fwd_flop += node.fwd_flop if accumulate_fwd_flop >= partition_flop: @@ -199,14 +200,14 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 partition_flop = total_fwd_flop // pp_size with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int): """ - In avgnode_split_pass, simpliy split graph by node number. + In avgnode_split_pass, simply split graph by node number. """ mod_graph = gm.graph avg_num_node = len(mod_graph.nodes) // pp_size @@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int): if accumulate_num_node >= avg_num_node: accumulate_num_node = 0 pp_size -= 1 - if node.next.op == 'output': + if node.next.op == "output": with mod_graph.inserting_before(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) else: with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 # If the next node is output node, we will insert split annotation before # node to make sure there is at least one node in last partition. - if node.next.op == 'output': + if node.next.op == "output": with mod_graph.inserting_before(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) else: with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) if pp_size > 1: node_counter = 0 for node in mod_graph.nodes: if pp_size <= 1: break - if node.op == 'placeholder': + if node.op == "placeholder": continue elif node_counter == 0: node_counter += 1 @@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 node_counter = 0 with mod_graph.inserting_before(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): # To use balanced_split_pass_v2, we need run meta_info_prop interpreter first. # If nodes don't have meta info, this pass will fall back to normal balanced split pass. check_node = list(mod_graph.nodes)[0] - if 'tensor_meta' not in check_node.meta: + if "tensor_meta" not in check_node.meta: return balanced_split_pass(gm, pp_size) total_element_size = 0 @@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): for node in mod_graph.nodes: if pp_size <= 1: break - if 'pipe_split' in node.name: + if "pipe_split" in node.name: continue accumulate_node_size += node.node_size if accumulate_node_size >= partition_size: @@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 partition_size = total_element_size // pp_size with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int): accumulate_layer_amount = 0 pp_size -= 1 with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output def split_callback(n: torch.fx.Node): nonlocal part_idx - if (n.op, n.target) == ('call_function', pipe_split): + if (n.op, n.target) == ("call_function", pipe_split): part_idx += 1 return part_idx @@ -355,7 +356,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output for name, submodule in split_mod.named_modules(): if isinstance(submodule, torch.fx.GraphModule): for node in submodule.graph.nodes: - if (node.op, node.target) == ('call_function', pipe_split): + if (node.op, node.target) == ("call_function", pipe_split): submodule.graph.erase_node(node) submodule.recompile() split_submodules.append(submodule) diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py index 81ac6420552815a5cea2d3fe5aef175efb36ece0..5440a4eadbbfbd336668e63a773c652323bac12e 100644 --- a/colossalai/fx/passes/concrete_info_prop.py +++ b/colossalai/fx/passes/concrete_info_prop.py @@ -1,5 +1,5 @@ from dataclasses import asdict -from typing import Any, Dict, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import torch.fx @@ -85,10 +85,10 @@ class ConcreteInfoProp(torch.fx.Interpreter): self._is_proped = True result, meta_info = super().run_node(n) - n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0)) - n.meta['type'] = type(result) + setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0)) + n.meta["type"] = type(result) # retain the autograd graph for param in self.module.parameters(): @@ -98,7 +98,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): # Main Node running APIs @compatibility(is_backward_compatible=True) - def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``placeholder`` node. Note that this is stateful: ``Interpreter`` maintains an internal iterator over @@ -119,7 +119,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): return super().placeholder(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``get_attr`` node. Will retrieve an attribute value from the ``Module`` hierarchy of ``self.module``. @@ -138,7 +138,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): return super().get_attr(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_function`` node with meta tensor and return the result and its meta profile. @@ -157,7 +157,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): return profile_function(target, self.device)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node with meta tensor and return the result and its meta profile. @@ -175,7 +175,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): return profile_method(target, self.device)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_module`` node with meta tensor and return the result and its meta profile. @@ -197,7 +197,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): return profile_module(submod, self.device)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute an ``output`` node. This really just retrieves the value referenced by the ``output`` node and returns it. @@ -228,7 +228,7 @@ class ConcreteInfoProp(torch.fx.Interpreter): """ return self.run(*args) - def summary(self, unit: str = 'MB') -> str: + def summary(self, unit: str = "MB") -> str: """ Summarizes the memory and FLOPs statistics of the `GraphModule` in tabular format. Note that this API requires the ``tabulate`` module @@ -238,9 +238,11 @@ class ConcreteInfoProp(torch.fx.Interpreter): try: from tabulate import tabulate except ImportError: - print("`summary` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") + print( + "`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." @@ -249,10 +251,10 @@ class ConcreteInfoProp(torch.fx.Interpreter): def mem_repr(mem: int) -> str: unit_divisor_map = { - 'kb': 1024, - 'mb': 1024**2, - 'gb': 1024**3, - 'tb': 1024**4, + "kb": 1024, + "mb": 1024**2, + "gb": 1024**3, + "tb": 1024**4, } return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" @@ -261,30 +263,32 @@ class ConcreteInfoProp(torch.fx.Interpreter): for node in self.module.graph.nodes: node: Node - node_summaries.append([ - node.op, - str(node), - time_repr(node.meta['fwd_time']), - time_repr(node.meta['bwd_time']), - node.meta['save_fwd_in'], - mem_repr(node.meta['fwd_mem_out']), - mem_repr(node.meta['fwd_mem_tmp']), - mem_repr(node.meta['bwd_mem_out']), - mem_repr(node.meta['bwd_mem_tmp']), - ]) + node_summaries.append( + [ + node.op, + str(node), + time_repr(node.meta["fwd_time"]), + time_repr(node.meta["bwd_time"]), + node.meta["save_fwd_in"], + mem_repr(node.meta["fwd_mem_out"]), + mem_repr(node.meta["fwd_mem_tmp"]), + mem_repr(node.meta["bwd_mem_out"]), + mem_repr(node.meta["bwd_mem_tmp"]), + ] + ) # Use the ``tabulate`` library to create a well-formatted table # presenting our summary information headers: List[str] = [ - 'Op type', - 'Op', - 'Forward time', - 'Backward time', - 'SAVE_FWD_IN', - 'FWD_OUT', - 'FWD_TMP', - 'BWD_OUT', - 'BWD_TMP', + "Op type", + "Op", + "Forward time", + "Backward time", + "SAVE_FWD_IN", + "FWD_OUT", + "FWD_TMP", + "BWD_OUT", + "BWD_TMP", ] - return tabulate(node_summaries, headers=headers, stralign='right') + return tabulate(node_summaries, headers=headers, stralign="right") diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py index f28d65e2668ac39e7b189c7d181d018468648614..3d032a27db638b4aea1fff17afdb5a7117a30431 100644 --- a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py +++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py @@ -1,14 +1,11 @@ -import torch -from typing import List -from torch.fx import symbolic_trace -from torch.fx.node import Node -from colossalai.fx.passes.split_module import split_module -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec import builtins import operator -from copy import deepcopy +from typing import List + +import torch + +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec def apply(*args, **kwargs): @@ -16,7 +13,7 @@ def apply(*args, **kwargs): return shape_consistency_manager.apply(*args, **kwargs) -def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh): +def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh): mod_graph = gm.graph nodes = tuple(mod_graph.nodes) @@ -24,16 +21,16 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de origin_node_sharding_spec_dict = {} for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): strategies_vector = node.strategies_vector - setattr(node, 'best_strategy', strategies_vector[strategy_index]) - setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec) + setattr(node, "best_strategy", strategies_vector[strategy_index]) + setattr(node, "sharding_spec", strategies_vector[strategy_index].output_sharding_spec) origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec # apply the sharding spec of parameters for node in nodes: - if node.op == 'call_module': + if node.op == "call_module": target_module = node.graph.owning_module.get_submodule(node.target) origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {}) - setattr(target_module.weight, 'sharding_spec', origin_sharding_spec) + setattr(target_module.weight, "sharding_spec", origin_sharding_spec) target_weight_sharding_spec = node.best_strategy.input_shardings[1] target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3)) apply(target_module.weight, target_weight_sharding_spec) @@ -51,10 +48,10 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de # add above dicts into graph for node in nodes: - if node.op != 'placeholder': + if node.op != "placeholder": with mod_graph.inserting_before(node): - input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') - origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') + input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict") + origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict") break return sharding_spec_convert_dict, origin_node_sharding_spec_dict @@ -70,13 +67,13 @@ def shape_consistency_pass(gm: torch.fx.GraphModule): node_to_index_dict = {} index = 0 for node in nodes: - if node.target == 'sharding_spec_convert_dict': + if node.target == "sharding_spec_convert_dict": input_dict_node = node continue - if node.target == 'origin_node_sharding_spec_dict': + if node.target == "origin_node_sharding_spec_dict": origin_dict_node = node continue - if not hasattr(node, 'best_strategy'): + if not hasattr(node, "best_strategy"): continue node_to_index_dict[node] = index index += 1 @@ -84,28 +81,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule): # add shape consistency apply function into graph for node in nodes: - if not hasattr(node, 'best_strategy'): + if not hasattr(node, "best_strategy"): continue with mod_graph.inserting_after(node): - origin_spec_node = mod_graph.create_node('call_function', - operator.getitem, - args=(origin_dict_node, node_to_index_dict[node])) + origin_spec_node = mod_graph.create_node( + "call_function", operator.getitem, args=(origin_dict_node, node_to_index_dict[node]) + ) with mod_graph.inserting_after(origin_spec_node): - set_sharding_spec_node = mod_graph.create_node('call_function', - builtins.setattr, - args=(node, 'sharding_spec', origin_spec_node)) + set_sharding_spec_node = mod_graph.create_node( + "call_function", builtins.setattr, args=(node, "sharding_spec", origin_spec_node) + ) for user_node in node.strategies_vector.successor_nodes: node_index = user_node.strategies_vector.predecessor_nodes.index(node) with mod_graph.inserting_before(user_node): - input_specs_node = mod_graph.create_node('call_function', - operator.getitem, - args=(input_dict_node, node_to_index_dict[node])) + input_specs_node = mod_graph.create_node( + "call_function", operator.getitem, args=(input_dict_node, node_to_index_dict[node]) + ) with mod_graph.inserting_before(user_node): - sharding_spec_node = mod_graph.create_node('call_function', - operator.getitem, - args=(input_specs_node, node_index)) + sharding_spec_node = mod_graph.create_node( + "call_function", operator.getitem, args=(input_specs_node, node_index) + ) with mod_graph.inserting_before(user_node): - shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node)) + shape_consistency_node = mod_graph.create_node("call_function", apply, args=(node, sharding_spec_node)) return gm diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 2b4a8749cfd776e3ac22d75fc2e47c2475c521d6..1720aa58da2b0f2a36befa9c4270030194fb0c32 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -31,7 +31,7 @@ class TensorMetadata(NamedTuple): numel: int is_tensor: bool # TODO: we can add a list of sharding spec here, and record the sharding - # behaviour by appending sharding spec into list. + # behavior by appending sharding spec into list. def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: @@ -109,13 +109,13 @@ class MetaInfoProp(torch.fx.Interpreter): return TensorMetadata(None, None, False, None, 0, False) tensor_meta = tree_map(extract_tensor_meta, result) - n.meta['tensor_meta'] = tensor_meta - n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + n.meta["tensor_meta"] = tensor_meta + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0))) - setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0)) - setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0)) - n.meta['type'] = type(result) + setattr(n, "node_size", activation_size(n.meta.get("fwd_out", 0)) + activation_size(n.meta.get("fwd_tmp", 0))) + setattr(n, "fwd_flop", n.meta.get("fwd_flop", 0)) + setattr(n, "bwd_flop", n.meta.get("bwd_flop", 0)) + n.meta["type"] = type(result) # retain the autograd graph for param in self.module.parameters(): @@ -125,7 +125,7 @@ class MetaInfoProp(torch.fx.Interpreter): # Main Node running APIs @compatibility(is_backward_compatible=True) - def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``placeholder`` node. Note that this is stateful: ``Interpreter`` maintains an internal iterator over @@ -146,7 +146,7 @@ class MetaInfoProp(torch.fx.Interpreter): return super().placeholder(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``get_attr`` node. Will retrieve an attribute value from the ``Module`` hierarchy of ``self.module``. @@ -165,7 +165,7 @@ class MetaInfoProp(torch.fx.Interpreter): return super().get_attr(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_function`` node with meta tensor and return the result and its meta profile. @@ -184,7 +184,7 @@ class MetaInfoProp(torch.fx.Interpreter): return profile_function(target)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node with meta tensor and return the result and its meta profile. @@ -202,7 +202,7 @@ class MetaInfoProp(torch.fx.Interpreter): return profile_method(target)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_module`` node with meta tensor and return the result and its meta profile. @@ -224,7 +224,7 @@ class MetaInfoProp(torch.fx.Interpreter): return profile_module(submod)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute an ``output`` node. This really just retrieves the value referenced by the ``output`` node and returns it. @@ -240,7 +240,7 @@ class MetaInfoProp(torch.fx.Interpreter): result (Any): The argument value that was retrieved meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - if hasattr(args[0], '_tensor'): + if hasattr(args[0], "_tensor"): return args[0], GraphInfo(fwd_in=[args[0]._tensor]) return args[0], GraphInfo(save_fwd_in=True) @@ -257,7 +257,7 @@ class MetaInfoProp(torch.fx.Interpreter): """ return super().run(*args) - def summary(self, unit: str = 'MB') -> str: + def summary(self, unit: str = "MB") -> str: """ Summarizes the memory and FLOPs statistics of the `GraphModule` in tabular format. Note that this API requires the ``tabulate`` module @@ -267,9 +267,11 @@ class MetaInfoProp(torch.fx.Interpreter): try: from tabulate import tabulate except ImportError: - print("`summary` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") + print( + "`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." @@ -278,10 +280,10 @@ class MetaInfoProp(torch.fx.Interpreter): def mem_repr(mem: int) -> str: unit_divisor_map = { - 'kb': 1024, - 'mb': 1024**2, - 'gb': 1024**3, - 'tb': 1024**4, + "kb": 1024, + "mb": 1024**2, + "gb": 1024**3, + "tb": 1024**4, } return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" @@ -292,35 +294,37 @@ class MetaInfoProp(torch.fx.Interpreter): for node in self.module.graph.nodes: node: Node accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node) - node_summaries.append([ - node.op, - str(node), - flops_repr(node.meta['fwd_flop']), - flops_repr(node.meta['bwd_flop']), - mem_repr(accumulate_size), - mem_repr(calculate_fwd_in(node)), - mem_repr(calculate_fwd_out(node)), - mem_repr(calculate_fwd_tmp(node)), - mem_repr(node.meta['bwd_mem_out']), - mem_repr(node.meta['bwd_mem_tmp']), - ]) + node_summaries.append( + [ + node.op, + str(node), + flops_repr(node.meta["fwd_flop"]), + flops_repr(node.meta["bwd_flop"]), + mem_repr(accumulate_size), + mem_repr(calculate_fwd_in(node)), + mem_repr(calculate_fwd_out(node)), + mem_repr(calculate_fwd_tmp(node)), + mem_repr(node.meta["bwd_mem_out"]), + mem_repr(node.meta["bwd_mem_tmp"]), + ] + ) # Use the ``tabulate`` library to create a well-formatted table # presenting our summary information headers: List[str] = [ - 'Op type', - 'Op', - 'Forward FLOPs', - 'Backward FLOPs', - 'Accumulated Memory', - 'FWD_IN', - 'FWD_OUT', - 'FWD_TMP', - 'BWD_OUT', - 'BWD_TMP', + "Op type", + "Op", + "Forward FLOPs", + "Backward FLOPs", + "Accumulated Memory", + "FWD_IN", + "FWD_OUT", + "FWD_TMP", + "BWD_OUT", + "BWD_TMP", ] - return tabulate(node_summaries, headers=headers, stralign='right') + return tabulate(node_summaries, headers=headers, stralign="right") def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None: @@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: Returns: torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo. """ - device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") interp = MetaInfoProp(gm.to(device)) if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor + args = tree_map(lambda x: MetaTensor(x, fake_device=device), args) kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs) interp.propagate(*args, **kwargs) if verbose: interp.summary(unit) - gm.to('cpu') + gm.to("cpu") del interp return gm diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py index abc1a089e9a90ebaff0681879aaa68d488edb624..73379f73689c4da70feb3650a553678e57156945 100644 --- a/colossalai/fx/passes/passes_for_gpt2_test.py +++ b/colossalai/fx/passes/passes_for_gpt2_test.py @@ -5,7 +5,6 @@ import torch from packaging import version from torch.fx._compatibility import compatibility from torch.fx.graph_module import GraphModule -from torch.fx.node import Node from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split from colossalai.fx.passes.meta_info_prop import TensorMetadata @@ -13,9 +12,9 @@ from colossalai.fx.passes.split_module import Partition def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]): - ''' + """ This pass is only used to do the gpt2 performance test, it may move into adding_split_node_pass.py, and will be deprecated in future. - ''' + """ mod_graph = gm.graph valid_children_size = 0 valid_children = [] @@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti part_index += 1 pp_size -= 1 with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule): - ''' + """ This pass will be used in gpt2 test, only a part of changes may be added into split_with_split_nodes_pass, and it will be deprecated in future. - ''' + """ part_idx = 0 def eliminate_unused_placeholders(gm): for node in gm.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": if not len(node.users): gm.graph.erase_node(node) gm.recompile() return gm def refill_outputs_and_placeholders(gm, next_partition_placeholders): - ''' + """ This method is used to eliminate the outputs in previous partition which is unused in next partition. In split module pass, it treats partitions as a DAG, but we need treat them as a single direction linked list in pipeline parallel. The difference is if a output from partition 0 is an input argument of partition 3, the DAG will not transfer it to partition 1 and partition 2. However, in single direction linked list, we need to do so. - ''' + """ output_type = None output_args = [] non_output_list = [] new_placeholder_list = [] for node in gm.graph.nodes: - if node.op == 'output': + if node.op == "output": if isinstance(node.args[0], (tuple, list)): output_type = node.args[0].__class__ output_args.extend([n.name for n in node.args[0]]) @@ -114,7 +113,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule) continue for node in gm.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": new_placeholder_list.append(node.name) if output_type is not None: gm.graph.output(output_type(output_args)) @@ -125,7 +124,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule) def split_callback(n: torch.fx.Node): nonlocal part_idx - if (n.op, n.target) == ('call_function', pipe_split): + if (n.op, n.target) == ("call_function", pipe_split): part_idx += 1 return part_idx @@ -134,7 +133,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule) for name, submodule in split_mod.named_modules(): if isinstance(submodule, torch.fx.GraphModule): for node in submodule.graph.nodes: - if (node.op, node.target) == ('call_function', pipe_split): + if (node.op, node.target) == ("call_function", pipe_split): submodule.graph.erase_node(node) submodule.recompile() split_submodules.append(submodule) @@ -200,13 +199,12 @@ def split_module_for_gpt2_test( _gen_all_ancestors_set(node) for n in list(all_ancestors): - if n.op != 'placeholder' and n._fx_partition > partition_name: + if n.op != "placeholder" and n._fx_partition > partition_name: n._fx_partition = partition_name - def record_cross_partition_use(def_node: torch.fx.node.Node, - use_node: Optional[torch.fx.node.Node]): # noqa: B950 - def_partition_name = getattr(def_node, '_fx_partition', None) - use_partition_name = getattr(use_node, '_fx_partition', None) + def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def_partition_name = getattr(def_node, "_fx_partition", None) + use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: # if 'tensor_meta' in def_node.meta: # if not _node_with_all_tensor_element(def_node.meta['tensor_meta']): @@ -230,14 +228,14 @@ def split_module_for_gpt2_test( use_partition.partitions_dependent_on.setdefault(def_partition_name) node_process_list = list(m.graph.nodes) - # split nodes into parititons + # split nodes into partitions while node_process_list: node = node_process_list.pop(0) orig_nodes[node.name] = node if node.op in ["placeholder"]: continue - if node.op == 'output': + if node.op == "output": # partition_name = str(split_callback(node)) # def _set_output_args_partition(n, partition_name): # n._fx_partition = partition_name @@ -252,12 +250,12 @@ def split_module_for_gpt2_test( partitions[partition_name] = partition = Partition(partition_name) partition.node_names.append(node.name) - origin_partition_name = getattr(node, '_fx_partition', None) + origin_partition_name = getattr(node, "_fx_partition", None) if origin_partition_name is None: node._fx_partition = partition_name torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) - torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 + torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 # find partitions with no dependencies root_partitions: List[str] = [] @@ -277,7 +275,7 @@ def split_module_for_gpt2_test( if len(sorted_partitions) != len(partitions): raise RuntimeError("cycle exists between partitions!") - # add placeholders to parititons + # add placeholders to partitions for partition_name in sorted_partitions: partition = partitions[partition_name] for input in partition.inputs: @@ -287,7 +285,7 @@ def split_module_for_gpt2_test( # Transform nodes and collect targets for partition's submodule for node in m.graph.nodes: - if hasattr(node, '_fx_partition'): + if hasattr(node, "_fx_partition"): partition = partitions[node._fx_partition] # swap out old graph nodes in kw/args with references to new nodes in this submodule @@ -295,26 +293,24 @@ def split_module_for_gpt2_test( gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) - if node.op not in ['call_module', 'get_attr']: + if node.op not in ["call_module", "get_attr"]: target = node.target else: - target_atoms = node.target.split('.') + target_atoms = node.target.split(".") target_attr = m for atom in target_atoms: if not hasattr(target_attr, atom): - raise RuntimeError(f'Operator target {node.target} not found!') + raise RuntimeError(f"Operator target {node.target} not found!") target_attr = getattr(target_attr, atom) # target = target_atoms[-1] - target = '_'.join(target_atoms) + target = "_".join(target_atoms) partition.targets[target] = target_attr assert isinstance(gathered_args, tuple) assert isinstance(gathered_kwargs, dict) - new_node = partition.graph.create_node(op=node.op, - target=target, - args=gathered_args, - kwargs=gathered_kwargs, - name=node.name) + new_node = partition.graph.create_node( + op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs, name=node.name + ) new_node.meta = node.meta.copy() partition.environment[node] = new_node @@ -323,14 +319,14 @@ def split_module_for_gpt2_test( base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} for node in m.graph.nodes: - if node.op == 'placeholder': - if version.parse(torch.__version__) < version.parse('1.11.0'): + if node.op == "placeholder": + if version.parse(torch.__version__) < version.parse("1.11.0"): base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type) else: default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty - base_mod_env[node.name] = base_mod_graph.placeholder(node.name, - type_expr=node.type, - default_value=default_value) + base_mod_env[node.name] = base_mod_graph.placeholder( + node.name, type_expr=node.type, default_value=default_value + ) base_mod_env[node.name].meta = node.meta.copy() # Do some things iterating over the partitions in topological order again: @@ -344,13 +340,14 @@ def split_module_for_gpt2_test( # Set correct output values output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) - output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] + output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] partition.graph.output(output_vals) # Construct GraphModule for this partition - submod_name = f'submod_{partition_name}' - base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, - partition.graph) # noqa: B950 + submod_name = f"submod_{partition_name}" + base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule( + partition.targets, partition.graph + ) # noqa: B950 # Emit call in base graph to this submodule output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) @@ -358,14 +355,14 @@ def split_module_for_gpt2_test( # Unpack multiple return values from submodule output_val_proxy = torch.fx.proxy.Proxy(output_val) for i, output_name in enumerate(partition.outputs): - base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] else: if not partition.outputs: continue base_mod_env[list(partition.outputs)[0]] = output_val for node in m.graph.nodes: - if node.op == 'output': - base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 + if node.op == "output": + base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py index d2bad06bb45a1393543039ef55dea6c8d1e9f50b..be8261f2a3f4fd408824ad4003d455217d456960 100644 --- a/colossalai/fx/passes/shard_1d_pass.py +++ b/colossalai/fx/passes/shard_1d_pass.py @@ -1,19 +1,32 @@ +import operator + import torch import torch.nn as nn -import operator -from colossalai.tensor import ProcessGroup -from colossalai.tensor.distspec import ShardSpec -from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec + +from colossalai.legacy.tensor import ProcessGroup +from colossalai.legacy.tensor.compute_spec import ComputePattern, ComputeSpec +from colossalai.legacy.tensor.distspec import ShardSpec ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] ELEMENTWISE_FUNC_OP = [ - torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, - operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout + torch.add, + operator.add, + torch.abs, + torch.cos, + torch.exp, + torch.mul, + operator.mul, + operator.floordiv, + operator.truediv, + operator.neg, + torch.multiply, + torch.nn.functional.relu, + torch.nn.functional.dropout, ] def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter: - """weight_split + """weight_split split a nn.Parameter Args: @@ -60,9 +73,9 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule): def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup): """ - This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers. + This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers. """ - #TODO: Needs to handle special cases, like x = linear(x) + linear(x) + # TODO: Needs to handle special cases, like x = linear(x) + linear(x) graph = graph_module.graph world_size = process_group.world_size() @@ -70,7 +83,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc # traverse the graph to look for consecutive linear layers is_linear_module = False - if node.op == 'call_module': + if node.op == "call_module": # look for the linear layer module = node.graph.owning_module.get_submodule(node.target) if isinstance(module, nn.Linear): @@ -80,31 +93,31 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc # it means the first linear has been found and the current module # is the second linear # set the current linear module to be row-sharded - annotation_record['row'] = module + annotation_record["row"] = module for shard_type, module in annotation_record.items(): # add row sharding spec - if shard_type == 'row': + if shard_type == "row": dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size]) comp_spec = ComputeSpec(ComputePattern.TP1D) - setattr(module.weight, 'pg', process_group) - setattr(module.weight, 'dist_spec', dist_spec) - setattr(module.weight, 'comp_spec', comp_spec) - elif shard_type == 'col': + setattr(module.weight, "pg", process_group) + setattr(module.weight, "dist_spec", dist_spec) + setattr(module.weight, "comp_spec", comp_spec) + elif shard_type == "col": weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) weight_comp_spec = ComputeSpec(ComputePattern.TP1D) weight_comp_spec.output_replicate = False - setattr(module.weight, 'pg', process_group) - setattr(module.weight, 'dist_spec', weight_dist_spec) - setattr(module.weight, 'comp_spec', weight_comp_spec) + setattr(module.weight, "pg", process_group) + setattr(module.weight, "dist_spec", weight_dist_spec) + setattr(module.weight, "comp_spec", weight_comp_spec) if module.bias is not None: bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) bias_comp_spec = ComputeSpec(ComputePattern.TP1D) bias_comp_spec.output_replicate = False - setattr(module.bias, 'pg', process_group) - setattr(module.bias, 'dist_spec', bias_dist_spec) - setattr(module.bias, 'comp_spec', bias_comp_spec) + setattr(module.bias, "pg", process_group) + setattr(module.bias, "dist_spec", bias_dist_spec) + setattr(module.bias, "comp_spec", bias_comp_spec) start_tracking = False annotation_record.clear() else: @@ -112,16 +125,16 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc # it means the current layer is the first linear # set the linear layer to be col-sharded start_tracking = True - annotation_record['col'] = module + annotation_record["col"] = module if start_tracking and not is_linear_module: # check against the white list # if non-element wise op is found, we reset the tracking - if node.op == 'call_module': + if node.op == "call_module": module = node.graph.owning_module.get_submodule(node.target) if module.__class__ not in ELEMENTWISE_MODULE_OP: start_tracking = False - elif node.op == 'call_function' or node.op == 'call_method': + elif node.op == "call_function" or node.op == "call_method": if node.target not in ELEMENTWISE_FUNC_OP: start_tracking = False elif len(node.users.keys()) > 1: diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py index 5ce5b969cbdefc0e347eb9f156f5268658e988fc..67a2432595d691197861fb5d72c27ce137dbd808 100644 --- a/colossalai/fx/passes/split_module.py +++ b/colossalai/fx/passes/split_module.py @@ -25,12 +25,14 @@ class Partition: self.targets: Dict[str, Any] = {} def __repr__(self) -> str: - return f"name: {self.name},\n" \ - f" nodes: {self.node_names},\n" \ - f" inputs: {self.inputs},\n" \ - f" outputs: {self.outputs},\n" \ - f" partitions depenent on: {self.partitions_dependent_on},\n" \ - f" parition dependents: {self.partition_dependents}" + return ( + f"name: {self.name},\n" + f" nodes: {self.node_names},\n" + f" inputs: {self.inputs},\n" + f" outputs: {self.outputs},\n" + f" partitions dependent on: {self.partitions_dependent_on},\n" + f" partition dependents: {self.partition_dependents}" + ) # Creates subgraphs out of main graph @@ -117,10 +119,9 @@ def split_module( partitions: Dict[str, Partition] = {} orig_nodes: Dict[str, torch.fx.node.Node] = {} - def record_cross_partition_use(def_node: torch.fx.node.Node, - use_node: Optional[torch.fx.node.Node]): # noqa: B950 - def_partition_name = getattr(def_node, '_fx_partition', None) - use_partition_name = getattr(use_node, '_fx_partition', None) + def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def_partition_name = getattr(def_node, "_fx_partition", None) + use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: if def_partition_name is not None: def_partition = partitions[def_partition_name] @@ -134,7 +135,7 @@ def split_module( if def_partition_name is not None: use_partition.partitions_dependent_on.setdefault(def_partition_name) - def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 def_partition_name = getattr(def_node, "_fx_partition", None) use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: @@ -161,7 +162,7 @@ def split_module( if node.op in ["placeholder"]: continue - if node.op == 'output': + if node.op == "output": if merge_output: torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev)) else: @@ -178,7 +179,7 @@ def split_module( node._fx_partition = partition_name torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) - torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 + torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 # find partitions with no dependencies root_partitions: List[str] = [] @@ -208,7 +209,7 @@ def split_module( # Transform nodes and collect targets for partition's submodule for node in m.graph.nodes: - if hasattr(node, '_fx_partition'): + if hasattr(node, "_fx_partition"): partition = partitions[node._fx_partition] # swap out old graph nodes in kw/args with references to new nodes in this submodule @@ -216,25 +217,24 @@ def split_module( gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) - if node.op not in ['call_module', 'get_attr']: + if node.op not in ["call_module", "get_attr"]: target = node.target else: - target_atoms = node.target.split('.') + target_atoms = node.target.split(".") target_attr = m for atom in target_atoms: if not hasattr(target_attr, atom): - raise RuntimeError(f'Operator target {node.target} not found!') + raise RuntimeError(f"Operator target {node.target} not found!") target_attr = getattr(target_attr, atom) # target = target_atoms[-1] - target = '_'.join(target_atoms) + target = "_".join(target_atoms) partition.targets[target] = target_attr assert isinstance(gathered_args, tuple) assert isinstance(gathered_kwargs, dict) - new_node = partition.graph.create_node(op=node.op, - target=target, - args=gathered_args, - kwargs=gathered_kwargs) + new_node = partition.graph.create_node( + op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs + ) new_node.meta = node.meta.copy() partition.environment[node] = new_node @@ -243,14 +243,14 @@ def split_module( base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} for node in m.graph.nodes: - if node.op == 'placeholder': - if version.parse(torch.__version__) < version.parse('1.11.0'): + if node.op == "placeholder": + if version.parse(torch.__version__) < version.parse("1.11.0"): base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type) else: default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty - base_mod_env[node.name] = base_mod_graph.placeholder(node.target, - type_expr=node.type, - default_value=default_value) + base_mod_env[node.name] = base_mod_graph.placeholder( + node.target, type_expr=node.type, default_value=default_value + ) base_mod_env[node.name].meta = node.meta.copy() # Do some things iterating over the partitions in topological order again: @@ -264,13 +264,14 @@ def split_module( # Set correct output values output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) - output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] + output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] partition.graph.output(output_vals) # Construct GraphModule for this partition - submod_name = f'submod_{partition_name}' - base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, - partition.graph) # noqa: B950 + submod_name = f"submod_{partition_name}" + base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule( + partition.targets, partition.graph + ) # noqa: B950 # Emit call in base graph to this submodule output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) @@ -278,15 +279,15 @@ def split_module( # Unpack multiple return values from submodule output_val_proxy = torch.fx.proxy.Proxy(output_val) for i, output_name in enumerate(partition.outputs): - base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] else: if not partition.outputs: continue base_mod_env[list(partition.outputs)[0]] = output_val for node in m.graph.nodes: - if node.op == 'output': - base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 + if node.op == "output": + base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 for partition_name in sorted_partitions: partition = partitions[partition_name] diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py index bb4f3cd6a4908177ca13f9d7fb82ff42b5ad1d5e..c51f49a30e8ae4c703ce2f7820f59ad901f4bbb0 100644 --- a/colossalai/fx/passes/utils.py +++ b/colossalai/fx/passes/utils.py @@ -1,7 +1,9 @@ -import torch from typing import Dict -from torch.fx.node import Node, map_arg + +import torch from torch.fx.graph import Graph +from torch.fx.node import Node, map_arg + def get_comm_size(prev_partition, next_partition): """ @@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition): map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) for n in input_nodes: if n.name in parent_node_names and n not in visited_nodes: - comm_size += n.meta['tensor_meta'].numel + comm_size += n.meta["tensor_meta"].numel visited_nodes.add(n) return comm_size @@ -36,12 +38,12 @@ def get_leaf(graph: Graph): """ input_nodes: Dict[Node, None] = {} for node in graph.nodes: - if node.op == 'output': + if node.op == "output": map_arg(node.args, lambda n: input_nodes.setdefault(n)) map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) placeholder_nodes = [] for node in input_nodes.keys(): - if node.op == 'placeholder': + if node.op == "placeholder": placeholder_nodes.append(node) for node in placeholder_nodes: input_nodes.pop(node) @@ -60,13 +62,13 @@ def get_top(graph: Graph): """ top_node_list = set() for node in graph.nodes: - if node.op == 'output': + if node.op == "output": continue is_top = False def _get_top(node): nonlocal is_top - if node.op == 'placeholder': + if node.op == "placeholder": is_top = True map_arg(node.args, lambda n: _get_top(n)) @@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node): def get_all_consumers(graph: Graph, node: Node): """ Given a graph and a node of this graph, return all consumers of the node. - + Returns: List of ``Nodes`` that node appear in these nodes ``args`` and ``kwargs``. """ @@ -120,7 +122,7 @@ def assign_bfs_level_to_nodes(graph: Graph): for node in gm.graph.nodes: if hasattr(node, 'bfs_level'): print(node.name, node.bfs_level) - + Output: graph(): %x : [#users=2] = placeholder[target=x] @@ -148,7 +150,7 @@ def assign_bfs_level_to_nodes(graph: Graph): while nodes_to_process: new_process_list = [] for node in nodes_to_process: - if node.op == 'output': + if node.op == "output": continue node.bfs_level = current_level new_process_list.extend(get_all_consumers(graph, node)) @@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module: torch.nn.Module: the module associated with the given node """ - assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object' - assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}' + assert ( + node.graph.owning_module is not None + ), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object" + assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}" module = node.graph.owning_module.get_submodule(node.target) return module - diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index 8bcbde0eb23b806b7e37e407d57a962d8ff71573..89dd2b3df617e352b324d9dab421bea69a84a6a1 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -12,7 +12,16 @@ if is_compatible_with_meta(): ) from .tensor import MetaTensor else: - from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out + from .experimental import ( + meta_profiler_function, + meta_profiler_module, + profile_function, + profile_method, + profile_module, + calculate_fwd_in, + calculate_fwd_tmp, + calculate_fwd_out, + ) from .dataflow import GraphInfo from .memory_utils import activation_size, is_inplace, parameter_size diff --git a/colossalai/fx/profiler/constants.py b/colossalai/fx/profiler/constants.py index 5763a46dc83f19dadefbebb32dbcf9a59578a2b3..fad9bb272bff76022928aa518f99bd92ffb27892 100644 --- a/colossalai/fx/profiler/constants.py +++ b/colossalai/fx/profiler/constants.py @@ -1,6 +1,6 @@ import torch -__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD'] +__all__ = ["ALIAS_ATEN", "INPLACE_NEW", "INPLACE_MATH_ATEN", "CLONE_ATEN", "RELU_LIKE_OPS", "RELU_LIKE_MOD"] aten = torch.ops.aten diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index a5e8880322b84f54e6c4742a821f4de76dfb664a..05f9b50ce575312011d9b0ac9937959f7b22e298 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field from enum import Enum -from functools import partial from typing import Dict, List from torch.fx import Graph, Node @@ -69,8 +68,8 @@ class GraphInfo: def is_phase(n: Node, phase: Phase) -> bool: - assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' - return n.meta['phase'] == phase + assert "phase" in n.meta, f"Node meta of {n} has no key `phase`!" + return n.meta["phase"] == phase @compatibility(is_backward_compatible=False) @@ -103,9 +102,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: peak_mem = 0 for k, v in deps.items(): if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k): - peak_mem += activation_size(k.meta['saved_tensor']) - if v <= float('-inf') and is_phase(k, Phase.FORWARD): - peak_mem -= activation_size(k.meta['saved_tensor']) + peak_mem += activation_size(k.meta["saved_tensor"]) + if v <= float("-inf") and is_phase(k, Phase.FORWARD): + peak_mem -= activation_size(k.meta["saved_tensor"]) return peak_mem # deps is used to track all the memory dependencies of the graph. @@ -123,19 +122,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint # the node, `fwd_mem_tmp` can be freed. if is_phase(n, Phase.PLACEHOLDER): - graph_info.fwd_in += n.meta['saved_tensor'] + graph_info.fwd_in += n.meta["saved_tensor"] if is_phase(n, Phase.FORWARD): - graph_info.fwd_tmp += n.meta['saved_tensor'] + graph_info.fwd_tmp += n.meta["saved_tensor"] elif is_phase(n, Phase.BACKWARD): if len(n.users): graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) else: # TODO: some of the bwd_mem_out might be model parameters. # basically a backward node without user is a `grad_out` node - graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor']) + graph_info.bwd_mem_out += activation_size(n.meta["saved_tensor"]) for input_n in n.all_input_nodes: if input_n in deps: deps[input_n] -= 1 if deps[input_n] <= 0: - deps[input_n] = float('-inf') + deps[input_n] = float("-inf") return graph_info diff --git a/colossalai/fx/profiler/experimental/constants.py b/colossalai/fx/profiler/experimental/constants.py index 57ff3fd91299b5bb8938125bf2d3243c9a9c4c2b..02758e7643af0936a5c7907ca6675ef10ec4db36 100644 --- a/colossalai/fx/profiler/experimental/constants.py +++ b/colossalai/fx/profiler/experimental/constants.py @@ -2,7 +2,7 @@ from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub import torch -__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] +__all__ = ["INPLACE_OPS", "INPLACE_METHOD", "NON_INPLACE_METHOD"] # TODO fill out the inplace ops INPLACE_OPS = [ @@ -20,25 +20,25 @@ INPLACE_OPS = [ # TODO: list all call_methods that are inplace here INPLACE_METHOD = [ - 'transpose', - 'permute', + "transpose", + "permute", # TODO: reshape may return a copy of the data if the data is not contiguous - 'reshape', - 'dim', - 'flatten', - 'size', - 'view', - 'unsqueeze', - 'to', - 'type', - 'flatten', + "reshape", + "dim", + "flatten", + "size", + "view", + "unsqueeze", + "to", + "type", + "flatten", ] # TODO: list all call_methods that are not inplace here NON_INPLACE_METHOD = [ - 'chunk', - 'contiguous', - 'expand', - 'mean', - 'split', + "chunk", + "contiguous", + "expand", + "mean", + "split", ] diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index 5c545260e72b723bfa54beacfb20def3e758413f..d890fdb66fc2dc7497a6c6bf2f4a2534c1885d60 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -9,7 +9,7 @@ from ..memory_utils import activation_size from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD from .registry import meta_profiler_function, meta_profiler_module -__all__ = ['profile_function', 'profile_module', 'profile_method'] +__all__ = ["profile_function", "profile_module", "profile_method"] # this is for compatibility use @@ -42,6 +42,7 @@ class GraphInfo: bwd_mem_tmp (int): See the above illustration. bwd_mem_out (int): See the above illustration. """ + fwd_flop: int = 0 bwd_flop: int = 0 fwd_mem_in: int = 0 @@ -50,8 +51,7 @@ class GraphInfo: bwd_mem_out: int = 0 -CALL_FUNCTION_MSG = \ -""" +CALL_FUNCTION_MSG = """ Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n from colossalai.fx.profiler.experimental import meta_profiler_function @meta_profiler_function.register(YOUR_FUNCTION) @@ -60,9 +60,8 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]: macs = ... return flops, macs """ -CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}' -CALL_MODULE_MSG = \ -""" +CALL_METHOD_MSG = "Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}" +CALL_MODULE_MSG = """ Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n from colossalai.fx.profiler.experimental import meta_profiler_module @meta_profiler_module.register(YOUR_MODULE) @@ -74,7 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int @compatibility(is_backward_compatible=True) -def profile_function(target: 'Target') -> Callable: +def profile_function(target: "Target") -> Callable: """ Wrap a `call_function` node or `torch.nn.functional` in order to record the memory cost and FLOPs of the execution. @@ -92,12 +91,13 @@ def profile_function(target: 'Target') -> Callable: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: assert meta_profiler_function.has(target) or meta_profiler_function.has( - target.__name__), CALL_FUNCTION_MSG.format(target) + target.__name__ + ), CALL_FUNCTION_MSG.format(target) fwd_tmp = 0 fwd_out = 0 out = func(*args, **kwargs) - if target not in INPLACE_OPS and not kwargs.get('inplace', False): + if target not in INPLACE_OPS and not kwargs.get("inplace", False): fwd_out = activation_size(out) if meta_profiler_function.has(target): profiler = meta_profiler_function.get(target) @@ -112,7 +112,7 @@ def profile_function(target: 'Target') -> Callable: @compatibility(is_backward_compatible=True) -def profile_method(target: 'Target') -> Callable: +def profile_method(target: "Target") -> Callable: """ Wrap a `call_method` node record the memory cost and FLOPs of the execution. @@ -126,11 +126,12 @@ def profile_method(target: 'Target') -> Callable: self_obj, *args_tail = args # execute the method and return the result - assert isinstance(target, str), f'{target} instance is not str.' + assert isinstance(target, str), f"{target} instance is not str." out = getattr(self_obj, target)(*args_tail, **kwargs) assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format( - target, INPLACE_METHOD, NON_INPLACE_METHOD) + target, INPLACE_METHOD, NON_INPLACE_METHOD + ) # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) @@ -161,7 +162,7 @@ def profile_module(module: torch.nn.Module) -> Callable: fwd_tmp = 0 fwd_out = 0 out = func(*args, **kwargs) - if getattr(module, 'inplace', False): + if getattr(module, "inplace", False): fwd_out = activation_size(out) profiler = meta_profiler_module.get(type(module)) fwd_flop, _ = profiler(module, *args, **kwargs) diff --git a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py index a43aef063e197de12c23fc5a81fb13e8183eaae9..c518ec28da418ed86716a5839dace4f1fd8bc160 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py +++ b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_function # TODO: different activation has different FLOPs count, currently unused. diff --git a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py index 8d1c8a8c6877fc12a7ad47b7b0103c309b8ed597..f1b9bb97c6c679c9b297f9feda3deec5e490631e 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py +++ b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py @@ -41,15 +41,15 @@ def _elementwise_flops_compute(input, other): @meta_profiler_function.register(torch.sub) @meta_profiler_function.register(torch.mul) @meta_profiler_function.register(torch.floor_divide) -@meta_profiler_function.register('add') # for built-in op + -@meta_profiler_function.register('iadd') # for built-in op += -@meta_profiler_function.register('eq') # for built-in op = -@meta_profiler_function.register('sub') # for built-in op - -@meta_profiler_function.register('isub') # for built-in op -= -@meta_profiler_function.register('mul') # for built-in op * -@meta_profiler_function.register('imul') # for built-in op *= -@meta_profiler_function.register('floordiv') # for built-in op // -@meta_profiler_function.register('ifloordiv') # for built-in op //= +@meta_profiler_function.register("add") # for built-in op + +@meta_profiler_function.register("iadd") # for built-in op += +@meta_profiler_function.register("eq") # for built-in op = +@meta_profiler_function.register("sub") # for built-in op - +@meta_profiler_function.register("isub") # for built-in op -= +@meta_profiler_function.register("mul") # for built-in op * +@meta_profiler_function.register("imul") # for built-in op *= +@meta_profiler_function.register("floordiv") # for built-in op // +@meta_profiler_function.register("ifloordiv") # for built-in op //= def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: return _elementwise_flops_compute(input, other) @@ -62,7 +62,7 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = N @meta_profiler_function.register(torch.matmul) -@meta_profiler_function.register('matmul') # for built-in op @ +@meta_profiler_function.register("matmul") # for built-in op @ @meta_profiler_function.register(torch.Tensor.matmul) def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: macs = reduce(operator.mul, input.shape) * other.shape[-1] @@ -78,13 +78,15 @@ def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.T @meta_profiler_function.register(torch.var_mean) -def torch_var_mean(input: torch.Tensor, - dim: Union[int, Tuple[int, ...]], - unbiased: Optional[bool] = True, - keepdim: Optional[bool] = False, - *, - out: Optional[torch.Tensor] = None) -> Tuple[int, int]: - assert out is None, 'saving to out is not supported yet' +def torch_var_mean( + input: torch.Tensor, + dim: Union[int, Tuple[int, ...]], + unbiased: Optional[bool] = True, + keepdim: Optional[bool] = False, + *, + out: Optional[torch.Tensor] = None, +) -> Tuple[int, int]: + assert out is None, "saving to out is not supported yet" flops = input.numel() * 3 macs = 0 return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/embedding.py b/colossalai/fx/profiler/experimental/profiler_function/embedding.py index d6e43d781b8b64ab78cf3299daba3df1d17a5420..1d362015fc8b8c8fcfbcf5985e459f1de0aaa5b9 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/embedding.py +++ b/colossalai/fx/profiler/experimental/profiler_function/embedding.py @@ -1,5 +1,7 @@ -import torch from typing import Optional + +import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/linear.py b/colossalai/fx/profiler/experimental/profiler_function/linear.py index 01fe4c87137083db2458c560a88cc6faa0af377e..ecc578d61b911c4e09086035a0f35082aa1ddd55 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/linear.py +++ b/colossalai/fx/profiler/experimental/profiler_function/linear.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/normalization.py b/colossalai/fx/profiler/experimental/profiler_function/normalization.py index c4ea508d70f80f33bbc8ae354e9743a4939d5e8c..2ad029eda03922e782e1f42aaafa2ca7577eb7db 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/normalization.py +++ b/colossalai/fx/profiler/experimental/profiler_function/normalization.py @@ -1,5 +1,7 @@ from typing import List, Optional, Tuple + import torch + from ..registry import meta_profiler_function @@ -21,11 +23,13 @@ def torch_nn_func_instancenorm( @meta_profiler_function.register(torch.nn.functional.group_norm) -def torch_nn_func_groupnorm(input: torch.Tensor, - num_groups: int, - weight: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - eps: float = 1e-5) -> Tuple[int, int]: +def torch_nn_func_groupnorm( + input: torch.Tensor, + num_groups: int, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5, +) -> Tuple[int, int]: has_affine = weight is not None flops = input.numel() * (5 if has_affine else 4) macs = 0 diff --git a/colossalai/fx/profiler/experimental/profiler_function/pooling.py b/colossalai/fx/profiler/experimental/profiler_function/pooling.py index a639f5ee83c1f4d2b75a3c120ea6ae3884fc422f..c91deab906d44c49201da4e08600d677ae8daccc 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/pooling.py +++ b/colossalai/fx/profiler/experimental/profiler_function/pooling.py @@ -1,5 +1,7 @@ -from typing import Tuple, Union +from typing import Tuple + import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py index 1e8561206ba0e7a874b202a31dd13a040533d1db..58c9889ad98e24f5f163107f5095c07c8bf7f3a7 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py +++ b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py @@ -1,6 +1,6 @@ import operator from typing import Any, Tuple -import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py index abdd7ad565ba237d7d6eab9e3c9b77d7afb10abf..67e90fb69acd0a1eebb4343118c601adbfd38e1a 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py +++ b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py @@ -1,7 +1,9 @@ -from functools import reduce import operator +from functools import reduce from typing import Any, Optional, Tuple + import torch + from ..registry import meta_profiler_function @@ -43,13 +45,11 @@ def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]: @meta_profiler_function.register(torch.max) -def torch_max(input: torch.Tensor, - dim: int = None, - keepdim: bool = False, - *, - out: Optional[torch.Tensor] = None) -> Tuple[int, int]: +def torch_max( + input: torch.Tensor, dim: int = None, keepdim: bool = False, *, out: Optional[torch.Tensor] = None +) -> Tuple[int, int]: macs = 0 - assert out is None, 'assigning value to out is not supported yet' + assert out is None, "assigning value to out is not supported yet" if dim is not None: shape = list(input.shape) shape.pop(int(dim)) diff --git a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py index 2ebf514ad2699cc4e71741b9c3e143cedcb63041..ae065e0c7c176461090e700640c4dba0436174bf 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py +++ b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module # TODO: different activation has different FLOPs count, currently unused. diff --git a/colossalai/fx/profiler/experimental/profiler_module/attention.py b/colossalai/fx/profiler/experimental/profiler_module/attention.py index 8daf74b232bf91d41933a2184e7b0c30d516d51a..dfaee75e04320a0451b7cf7eb02154082a8c5209 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/attention.py +++ b/colossalai/fx/profiler/experimental/profiler_module/attention.py @@ -1,19 +1,23 @@ from typing import Optional, Tuple + import torch + from ..registry import meta_profiler_module # TODO: This is hard to compute memory cost @meta_profiler_module.register(torch.nn.MultiheadAttention) -def torch_nn_msa(self: torch.nn.MultiheadAttention, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_padding_mask: Optional[torch.Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[torch.Tensor] = None, - average_attn_weights: bool = True) -> Tuple[int, int]: - if getattr(self, 'batch_first', False): +def torch_nn_msa( + self: torch.nn.MultiheadAttention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[torch.Tensor] = None, + average_attn_weights: bool = True, +) -> Tuple[int, int]: + if getattr(self, "batch_first", False): batch_size = query.shape[0] len_idx = 1 else: @@ -44,15 +48,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention, flops += qlen * qdim # Initial projections - flops += 2 * ((qlen * qdim * qdim) # QW - + (klen * kdim * kdim) # KW - + (vlen * vdim * vdim) # VW - ) + flops += 2 * ((qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim)) # QW # KW # VW - macs += ((qlen * qdim * qdim) # QW - + (klen * kdim * kdim) # KW - + (vlen * vdim * vdim) # VW - ) + macs += (qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim) # QW # KW # VW if self.in_proj_bias is not None: flops += (qlen + klen + vlen) * qdim @@ -62,13 +60,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention, v_head_dim = vdim // num_heads head_flops = ( - 2 * (qlen * klen * qk_head_dim) # QK^T - + (qlen * klen) # softmax - + 2 * (qlen * klen * v_head_dim) # AV + 2 * (qlen * klen * qk_head_dim) + (qlen * klen) + 2 * (qlen * klen * v_head_dim) # QK^T # softmax # AV ) - head_macs = ((qlen * klen * qk_head_dim) # QK^T - + 2 * (qlen * klen * v_head_dim) # AV - ) + head_macs = (qlen * klen * qk_head_dim) + 2 * (qlen * klen * v_head_dim) # QK^T # AV flops += num_heads * head_flops macs += num_heads * head_flops diff --git a/colossalai/fx/profiler/experimental/profiler_module/convolution.py b/colossalai/fx/profiler/experimental/profiler_module/convolution.py index a4c15b91e611d5d1398eeb38ec5c106c2652749b..90e494c77f5b5d38e942e4e749f6400e8886868e 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/convolution.py +++ b/colossalai/fx/profiler/experimental/profiler_module/convolution.py @@ -17,8 +17,9 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, in # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html c_in, l_in = input.shape[-2:] c_out = self.out_channels - l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + l_out = math.floor( + (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, @@ -38,10 +39,12 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, in # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html c_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + h_out = math.floor( + (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, @@ -62,12 +65,15 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, in # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html c_in, d_in, h_in, w_in = input.shape[-4:] c_out = self.out_channels - d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) - w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * - (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) + d_out = math.floor( + (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + h_out = math.floor( + (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, @@ -89,8 +95,13 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html c_in, l_in = input.shape[-2:] c_out = self.out_channels - l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + l_out = math.floor( + (l_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, @@ -98,7 +109,7 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups num_elem = reduce( operator.mul, input.shape - ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604 + ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604 macs = macs_per_elem * num_elem flops = 2 * macs if self.bias is not None: @@ -112,10 +123,20 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html c_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + h_out = math.floor( + (h_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, @@ -136,12 +157,27 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html c_in, d_in, h_in, w_in = input.shape[-4:] c_out = self.out_channels - d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) - w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * - (self.kernel_size[2] - 1) + self.output_padding[2] + 1) + d_out = math.floor( + (d_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + h_out = math.floor( + (h_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[2] + - 2 * self.padding[2] + + self.dilation[2] * (self.kernel_size[2] - 1) + + self.output_padding[2] + + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, diff --git a/colossalai/fx/profiler/experimental/profiler_module/dropout.py b/colossalai/fx/profiler/experimental/profiler_module/dropout.py index 417e0ed468637a5ce049ffa8137a73e5b266c971..7361239eb1bdf087c4ae3fd457e121cc8390ffa7 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/dropout.py +++ b/colossalai/fx/profiler/experimental/profiler_module/dropout.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/linear.py b/colossalai/fx/profiler/experimental/profiler_module/linear.py index e1ffb6f244d2ed7d5764339d61fdb46f71ae59a2..71fed3196c1326c0c3121d9f6357548d699536ca 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/linear.py +++ b/colossalai/fx/profiler/experimental/profiler_module/linear.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/normalization.py b/colossalai/fx/profiler/experimental/profiler_module/normalization.py index 49e5e6fa5384b07412abe7ecc947d7963e88bd1a..5a64e44947b748b8b8ff7554dcef54d9933fd016 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/normalization.py +++ b/colossalai/fx/profiler/experimental/profiler_module/normalization.py @@ -16,8 +16,12 @@ from ..registry import meta_profiler_module @meta_profiler_module.register(torch.nn.BatchNorm1d) @meta_profiler_module.register(torch.nn.BatchNorm2d) @meta_profiler_module.register(torch.nn.BatchNorm3d) -def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]: +def torch_nn_normalize( + self: Union[ + torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d + ], + input: torch.Tensor, +) -> Tuple[int, int]: # adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615 has_affine = self.weight is not None if self.training: @@ -30,6 +34,7 @@ def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch try: import apex + meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) diff --git a/colossalai/fx/profiler/experimental/profiler_module/pooling.py b/colossalai/fx/profiler/experimental/profiler_module/pooling.py index e429ac3eea28055f42af2ea8f84663a5a6fd2a83..b3b630b2dee91f16358b9102963287cb81883bd5 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/pooling.py +++ b/colossalai/fx/profiler/experimental/profiler_module/pooling.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/rnn.py b/colossalai/fx/profiler/experimental/profiler_module/rnn.py index 6e733d6da9156db13b2bac35af63a52ad89ad5a3..8a4c828dbd276dc6cab6d2e36209a16e4b48230a 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/rnn.py +++ b/colossalai/fx/profiler/experimental/profiler_module/rnn.py @@ -1,12 +1,15 @@ -from functools import reduce import operator +from functools import reduce +from typing import Optional, Tuple + import torch + from ..registry import meta_profiler_module -from typing import Optional, Tuple, Union -def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, - w_hh: torch.Tensor) -> Tuple[int, int]: +def _rnn_flops( + flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, w_hh: torch.Tensor +) -> Tuple[int, int]: # copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py # matrix matrix mult ih state and internal state @@ -42,12 +45,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch flops = 0 macs = 0 for i in range(self.num_layers): - w_ih = self.__getattr__('weight_ih_l' + str(i)) - w_hh = self.__getattr__('weight_hh_l' + str(i)) + w_ih = self.__getattr__("weight_ih_l" + str(i)) + w_hh = self.__getattr__("weight_hh_l" + str(i)) flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) if self.bias: - b_ih = self.__getattr__('bias_ih_l' + str(i)) - b_hh = self.__getattr__('bias_hh_l' + str(i)) + b_ih = self.__getattr__("bias_ih_l" + str(i)) + b_hh = self.__getattr__("bias_hh_l" + str(i)) flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) flops *= reduce(operator.mul, input.shape[:2]) macs *= reduce(operator.mul, input.shape[:2]) @@ -63,12 +66,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]: flops = 0 macs = 0 - w_ih = self.__getattr__('weight_ih_l') - w_hh = self.__getattr__('weight_hh_l') + w_ih = self.__getattr__("weight_ih_l") + w_hh = self.__getattr__("weight_hh_l") flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) if self.bias: - b_ih = self.__getattr__('bias_ih_l') - b_hh = self.__getattr__('bias_hh_l') + b_ih = self.__getattr__("bias_ih_l") + b_hh = self.__getattr__("bias_hh_l") flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) flops *= input.shape[0] macs *= input.shape[0] diff --git a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py index d3aed874eb10af76dc94e21b23c566178afe6264..06be25246a712ae6af9596b5f398d9887b8f42bb 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py +++ b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py @@ -1,7 +1,8 @@ -import operator +from typing import Tuple + import torch + from ..registry import meta_profiler_module -from typing import Optional, Tuple, Union @meta_profiler_module.register(torch.nn.Flatten) diff --git a/colossalai/fx/profiler/experimental/registry.py b/colossalai/fx/profiler/experimental/registry.py index 7d73bce321e43d7c1284bf4d78dbac2bc7c4abfc..d47129cd2978422b8d669a61cbe684afe399e337 100644 --- a/colossalai/fx/profiler/experimental/registry.py +++ b/colossalai/fx/profiler/experimental/registry.py @@ -1,11 +1,9 @@ class ProfilerRegistry: - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): self.store[source] = func return func @@ -21,5 +19,5 @@ class ProfilerRegistry: return source in self.store -meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile') -meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile') +meta_profiler_function = ProfilerRegistry(name="patched_functions_for_meta_profile") +meta_profiler_module = ProfilerRegistry(name="patched_modules_for_meta_profile") diff --git a/colossalai/fx/profiler/experimental/shard_utils.py b/colossalai/fx/profiler/experimental/shard_utils.py index 1e53ed0bf8ec657d8916d49c8c4b97f1d996010a..90e8c3b7cfe46be47f29198efaf1f9a21a2fd174 100644 --- a/colossalai/fx/profiler/experimental/shard_utils.py +++ b/colossalai/fx/profiler/experimental/shard_utils.py @@ -1,8 +1,6 @@ # for PyTorch 1.11 compatibility uses -from typing import Dict, List, Tuple, Union -import torch -from torch.fx import GraphModule, Node +from torch.fx import Node from ..._compatibility import compatibility @@ -19,7 +17,7 @@ def calculate_fwd_in(n: Node) -> bool: Returns: save_fwd_in (bool): the result of `save_fwd_in` """ - return n.meta['save_fwd_in'] + return n.meta["save_fwd_in"] @compatibility(is_backward_compatible=True) @@ -45,4 +43,4 @@ def calculate_fwd_out(n: Node) -> int: Returns: fwd_out (int): the result of `fwd_out` """ - return n.meta['fwd_mem_out'] + return n.meta["fwd_mem_out"] diff --git a/colossalai/fx/profiler/memory_utils.py b/colossalai/fx/profiler/memory_utils.py index 6ccbcb01cdc14045fbdd4906f1fc6f2a5ad728db..e8eb5f25cb6c663936125ff38fed69bbddbf8ae4 100644 --- a/colossalai/fx/profiler/memory_utils.py +++ b/colossalai/fx/profiler/memory_utils.py @@ -1,11 +1,11 @@ from typing import Dict, List, Tuple, Union import torch -from torch.fx import GraphModule, Node +from torch.fx import Node from .._compatibility import compatibility, is_compatible_with_meta -__all__ = ['activation_size', 'parameter_size', 'is_inplace'] +__all__ = ["activation_size", "parameter_size", "is_inplace"] @compatibility(is_backward_compatible=True) @@ -63,6 +63,7 @@ def is_inplace(n: Node): inplace = n.kwargs.get("inplace", False) if is_compatible_with_meta(): from .constants import ALIAS_ATEN + if n.target in ALIAS_ATEN: inplace = True elif n.op == "call_module": diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index ba090a2ec51bd7d1d83a6dd5d75c877c0708577f..8fae0f2ecb45a4a2aa12449248f7a451a6907b01 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -173,8 +173,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable: # Inputs[0] contains the shape of the input. input_shape = inputs[input_arg_index].shape - has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], - 'shape') else inputs[affine_arg_index] + has_affine = ( + inputs[affine_arg_index].shape is not None + if hasattr(inputs[affine_arg_index], "shape") + else inputs[affine_arg_index] + ) assert 2 <= len(input_shape) <= 5, input_shape # 5 is just a rough estimate flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) @@ -188,7 +191,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N training = inputs[-3] assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" if training: - return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore + return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore has_affine = inputs[1].shape is not None input_shape = reduce(operator.mul, inputs[0].shape) return input_shape * (2 if has_affine else 1) @@ -218,15 +221,16 @@ def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> def zero_flop_jit(*args): """ - Count flops for zero flop layers. + Count flops for zero flop layers. """ return 0 -if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( - torch.__version__) < version.parse('2.0.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0") and version.parse(torch.__version__) < version.parse( + "2.0.0" +): flop_mapping = { - # gemm, gemv and dot + # gemm, gemv and dot aten.mm.default: matmul_flop_jit, aten.mv.default: matmul_flop_jit, aten.dot.default: matmul_flop_jit, @@ -234,13 +238,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse aten.addmm.default: addmm_flop_jit, aten.bmm.default: bmm_flop_jit, aten.baddbmm.default: baddbmm_flop_jit, - - # convolution + # convolution aten.convolution.default: conv_flop_jit, aten._convolution.default: conv_flop_jit, aten.convolution_backward.default: conv_backward_flop_jit, - - # normalization + # normalization aten.native_batch_norm.default: batchnorm_flop_jit, aten.native_batch_norm_backward.default: batchnorm_flop_jit, aten.cudnn_batch_norm.default: batchnorm_flop_jit, @@ -249,8 +251,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), aten.native_group_norm.default: norm_flop_counter(2, 0), aten.native_group_norm_backward.default: norm_flop_counter(2, 0), - - # pooling + # pooling aten.avg_pool1d.default: elementwise_flop_counter(1, 0), aten.avg_pool2d.default: elementwise_flop_counter(1, 0), aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1), @@ -275,7 +276,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse } elementwise_flop_aten = [ - # basic op + # basic op aten.add.Tensor, aten.add_.Tensor, aten.div.Tensor, @@ -296,8 +297,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse aten.exp.default, aten.sin.default, aten.cos.default, - - # activation op + # activation op aten.hardswish.default, aten.hardswish_.default, aten.hardswish_backward.default, @@ -320,8 +320,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse aten.tanh.default, aten.tanh_backward.default, aten.threshold_backward.default, - - # dropout + # dropout aten.native_dropout.default, aten.native_dropout_backward.default, ] @@ -362,7 +361,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse aten.zero_.default, aten.zeros_like.default, aten.fill_.Scalar, - aten.stack.default + aten.stack.default, ] # yapf: disable for op in zero_flop_aten: diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index c87cd4321d31c59ebe369a2903b679a439fb4f96..97e70db6290e58d62c113ca37f4323be07ef3e58 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -15,7 +15,7 @@ from .memory_utils import activation_size, parameter_size from .opcount import flop_mapping from .tensor import MetaTensor -__all__ = ['profile_function', 'profile_module', 'profile_method'] +__all__ = ["profile_function", "profile_module", "profile_method"] # super-dainiu: this cache should be global, otherwise it cannot # track duplicated tensors between nodes @@ -174,7 +174,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G # backward is executed. # Hopefully, this attempt will provide a better estimation of memory. class FlopTensor(MetaTensor): - _node: Node = None def __repr__(self): @@ -186,24 +185,24 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G def __torch_dispatch__(cls, func, types, args=(), kwargs=None): args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args) kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs) - node = subgraph.create_node('call_function', func, args_node, kwargs_node) + node = subgraph.create_node("call_function", func, args_node, kwargs_node) out = super().__torch_dispatch__(func, types, args, kwargs) flop_count[phase] += flop_mapping[func](args, normalize_tuple(out)) - node.meta['phase'] = phase + node.meta["phase"] = phase # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs, # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during # `Phase.FORWARD` if phase == Phase.FORWARD: if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN: - node.meta['phase'] = Phase.PLACEHOLDER + node.meta["phase"] = Phase.PLACEHOLDER # TODO(yby): specify `saved_tensors` for backward memory estimation - node.meta['saved_tensor'] = [] + node.meta["saved_tensor"] = [] if phase == Phase.BACKWARD: - node.meta['saved_tensor'] = normalize_tuple(out) + node.meta["saved_tensor"] = normalize_tuple(out) def wrap(x): if isinstance(x, MetaTensor): @@ -219,11 +218,14 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G x = FlopTensor(x) if is_autogradable(x): x.requires_grad_(True) - x._node = subgraph.create_node('placeholder', - 'placeholder', (subgraph._root,), - name=subgraph._graph_namespace.create_name('input', x._tensor)) - x._node.meta['phase'] = Phase.PLACEHOLDER - x._node.meta['saved_tensor'] = [] + x._node = subgraph.create_node( + "placeholder", + "placeholder", + (subgraph._root,), + name=subgraph._graph_namespace.create_name("input", x._tensor), + ) + x._node.meta["phase"] = Phase.PLACEHOLDER + x._node.meta["saved_tensor"] = [] return x # Basically, we need to detach the args and kwargs from the outer graph. @@ -235,7 +237,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache: tensor = x._tensor.detach() tensor.data_ptr = x._tensor.data_ptr - x._node.meta['saved_tensor'] += [tensor] + x._node.meta["saved_tensor"] += [tensor] if not do_not_cache: cache.add(x._tensor.data_ptr()) return x @@ -284,7 +286,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G @compatibility(is_backward_compatible=True) -def profile_function(target: 'Target', device: str = 'meta') -> Callable: +def profile_function(target: "Target", device: str = "meta") -> Callable: """ Wrap a `call_function` node or `torch.nn.functional` in order to record the memory cost and FLOPs of the execution. @@ -300,7 +302,6 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: - # find the grad for parameter in args and kwargs param_size = 0 @@ -316,18 +317,18 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: # still run the profiling but discard some results regarding `target` global do_not_cache - inplace = kwargs.get('inplace', False) + inplace = kwargs.get("inplace", False) if target in OUTPUT_SAVED_OPS: do_not_cache = True if inplace: do_not_cache = True - kwargs['inplace'] = False - if device == 'meta': + kwargs["inplace"] = False + if device == "meta": out, meta = _profile_meta(func, *args, **kwargs) else: out, meta = _profile_concrete(func, *args, **kwargs) if inplace: - kwargs['inplace'] = True + kwargs["inplace"] = True meta.bwd_mem_tmp = 0 meta.bwd_mem_out = 0 do_not_cache = False @@ -341,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: @compatibility(is_backward_compatible=True) -def profile_method(target: 'Target', device: str = 'meta') -> Callable: +def profile_method(target: "Target", device: str = "meta") -> Callable: """ Wrap a `call_method` node record the memory cost and FLOPs of the execution. @@ -349,8 +350,8 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # execute the method and return the result - assert isinstance(target, str), f'{target} instance is not str.' - if device == 'meta': + assert isinstance(target, str), f"{target} instance is not str." + if device == "meta": out, meta = _profile_meta(target, *args, **kwargs) else: out, meta = _profile_concrete(target, *args, **kwargs) @@ -360,7 +361,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable: @compatibility(is_backward_compatible=True) -def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: +def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable: """ Wrap a `call_module` node or `torch.nn` in order to record the memory cost and FLOPs of the execution. @@ -376,7 +377,6 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: - # calculate parameter size param_size = parameter_size(module) @@ -384,13 +384,13 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: # still run the profiling but discard some results regarding `module`. global do_not_cache - inplace = getattr(module, 'inplace', False) + inplace = getattr(module, "inplace", False) if type(module) in OUTPUT_SAVED_MOD: do_not_cache = True if inplace: do_not_cache = True module.inplace = False - if device == 'meta': + if device == "meta": out, meta = _profile_meta(func, *args, **kwargs) else: out, meta = _profile_concrete(func, *args, **kwargs) diff --git a/colossalai/fx/profiler/shard_utils.py b/colossalai/fx/profiler/shard_utils.py index 34feefb4336ab4a7924f7023b6f887ba8610b25c..75b7c814f05f4fd6f170959da4fb5d451bfcc75a 100644 --- a/colossalai/fx/profiler/shard_utils.py +++ b/colossalai/fx/profiler/shard_utils.py @@ -59,9 +59,9 @@ def calculate_fwd_tmp(n: Node) -> int: Returns: bool: Whether the node is a ReLU-like node """ - if n.op == 'call_function': + if n.op == "call_function": return n.target in OUTPUT_SAVED_OPS - elif n.op == 'call_module': + elif n.op == "call_module": return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD return False diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py index 2ee5e5c47750c4bbc714987ce4bcae9f29e6ac71..7c14b48bdaa1315e28494cb5b1cbbe8920921a53 100644 --- a/colossalai/fx/profiler/tensor.py +++ b/colossalai/fx/profiler/tensor.py @@ -1,13 +1,13 @@ import uuid import torch -from torch.types import _bool, _device, _dtype -from torch.utils._pytree import tree_flatten, tree_map +from torch.types import _device +from torch.utils._pytree import tree_map from .._compatibility import compatibility from .constants import ALIAS_ATEN -__all__ = ['MetaTensor'] +__all__ = ["MetaTensor"] def set_data_ptr(x): @@ -43,12 +43,13 @@ class MetaTensor(torch.Tensor): storage_offset=elem.storage_offset(), dtype=elem.dtype, layout=elem.layout, - device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')), - requires_grad=elem.requires_grad) # deceive the frontend for aten selections + device=fake_device or (elem.device if elem.device.type != "meta" else torch.device("cpu")), + requires_grad=elem.requires_grad, + ) # deceive the frontend for aten selections r._tensor = elem # ...the real tensor is held as an element on the tensor. if not r._tensor.is_meta: - r._tensor = r._tensor.to(torch.device('meta')) + r._tensor = r._tensor.to(torch.device("meta")) # only tensor not on `meta` should be copied to `meta` set_data_ptr(r._tensor) return r @@ -69,15 +70,15 @@ class MetaTensor(torch.Tensor): x = x._tensor elif isinstance(x, torch.Tensor): fake_device = x.device - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return x args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) - if 'device' in kwargs: - fake_device = kwargs['device'] - kwargs['device'] = torch.device('meta') + if "device" in kwargs: + fake_device = kwargs["device"] + kwargs["device"] = torch.device("meta") # run aten for backend=CPU but actually on backend=Meta out = func(*args, **kwargs) @@ -93,7 +94,7 @@ class MetaTensor(torch.Tensor): if isinstance(x, torch.Tensor): nonlocal fake_device if not x.is_meta: - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x return tree_map(wrap, out) @@ -120,18 +121,18 @@ class MetaTensor(torch.Tensor): nonlocal fake_device if isinstance(x, str) or isinstance(x, _device): fake_device = x - return 'meta' + return "meta" return x elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) return MetaTensor(elem, fake_device=fake_device) def cpu(self, *args, **kwargs): - if self.device.type == 'cpu': + if self.device.type == "cpu": return self.to(*args, **kwargs) - return self.to(*args, device='cpu', **kwargs) + return self.to(*args, device="cpu", **kwargs) def cuda(self, device=None, non_blocking=False): if device is not None: return self.to(device=device, non_blocking=non_blocking) - return self.to(device='cuda:0', non_blocking=non_blocking) + return self.to(device="cuda:0", non_blocking=non_blocking) diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py index 7317072c6298b66280810c48536eb22b7edca7f0..887832223fd6cf263e26f4bdbb43b9a39f194302 100644 --- a/colossalai/fx/proxy.py +++ b/colossalai/fx/proxy.py @@ -1,12 +1,11 @@ -import operator -from typing import Any, List, Union +from typing import Any import torch -from torch.fx.proxy import Attribute, Proxy +from torch.fx.proxy import Proxy from colossalai.fx.tracer.meta_patch import meta_patched_function -__all__ = ['ColoProxy'] +__all__ = ["ColoProxy"] class ColoProxy(Proxy): @@ -39,11 +38,12 @@ class ColoProxy(Proxy): return self._meta_data is not None def _assert_meta_data_is_tensor(self): - assert torch.is_tensor( - self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}' + assert ( + torch.is_tensor(self._meta_data) and self._meta_data.is_meta + ), f"Meta data is not a meta tensor for {self.node.name}" def _assert_has_meta_data(self): - assert self._meta_data is not None, f'Meta data is not set for {self.node.name}' + assert self._meta_data is not None, f"Meta data is not set for {self.node.name}" def __len__(self): self._assert_has_meta_data() @@ -62,7 +62,6 @@ class ColoProxy(Proxy): return self.meta_data def __getattr__(self, k): - return ColoAttribute(self, k) def __contains__(self, key): @@ -92,7 +91,6 @@ def extract_meta(*args, **kwargs): class ColoAttribute(ColoProxy): - def __init__(self, root, attr: str): self.root = root self.attr = attr diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py index 1c5abb81d271144ab666049d3ae2868cd9568497..63a7bab654d5d821a0541976a0f86368a6bae419 100644 --- a/colossalai/fx/tracer/_meta_trace.py +++ b/colossalai/fx/tracer/_meta_trace.py @@ -39,7 +39,7 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr _tensor: torch.Tensor _node: Node - __slots__ = ['_tensor', '_node'] + __slots__ = ["_tensor", "_node"] @staticmethod def __new__(cls, tensor, fake_device=None, placeholder=False, name=None): @@ -51,22 +51,22 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr dtype=tensor.dtype, layout=tensor.layout, device=fake_device if fake_device is not None else tensor.device, - requires_grad=tensor.requires_grad) # deceive the frontend for aten selections + requires_grad=tensor.requires_grad, + ) # deceive the frontend for aten selections r._tensor = tensor if placeholder: if name is None: - name = 'input' - r._node = graph.create_node('placeholder', - 'placeholder', (graph._root,), - name=namespace.create_name(name, tensor)) + name = "input" + r._node = graph.create_node( + "placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor) + ) # ...the real tensor is held as an element on the tensor. if not r._tensor.is_meta: - r._tensor = r._tensor.to(torch.device('meta')) + r._tensor = r._tensor.to(torch.device("meta")) return r @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - def unwrap(x): nonlocal fake_device if isinstance(x, MetaProxy): @@ -75,21 +75,21 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr # assert not isinstance(x, MetaProxy) elif isinstance(x, torch.Tensor): fake_device = x.device - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return x def get_node(x): - if isinstance(x, torch.Tensor) and not hasattr(x, '_node'): - x = MetaProxy(x, placeholder=True, name='weight') - return x if not hasattr(x, '_node') else x._node + if isinstance(x, torch.Tensor) and not hasattr(x, "_node"): + x = MetaProxy(x, placeholder=True, name="weight") + return x if not hasattr(x, "_node") else x._node args_node = tree_map(get_node, args) kwargs_node = tree_map(get_node, kwargs) - node = graph.create_node('call_function', func, args_node, kwargs_node) + node = graph.create_node("call_function", func, args_node, kwargs_node) - if 'device' in kwargs: - fake_device = kwargs['device'] - kwargs['device'] = torch.device('meta') + if "device" in kwargs: + fake_device = kwargs["device"] + kwargs["device"] = torch.device("meta") args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) @@ -103,9 +103,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr if isinstance(x, torch.Tensor): nonlocal fake_device if not x.is_meta: - x = x.to(torch.device('meta')) - return MetaProxy( - x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x + x = x.to(torch.device("meta")) + return ( + MetaProxy(x, fake_device=fake_device) + if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor") + else x + ) def set_node(x): x._node = node @@ -125,9 +128,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr for tensor in normalize_tuple(out): if is_autogradable(tensor) and tensor.requires_grad: - grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance( - tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta')) - torch.autograd.backward(tensor, - MetaProxy(grad, fake_device=tensor.device, placeholder=True), - retain_graph=True) + grad = ( + torch.empty_like(tensor._tensor, device=torch.device("meta")) + if isinstance(tensor, MetaProxy) + else torch.empty_like(tensor, device=torch.device("meta")) + ) + torch.autograd.backward( + tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True + ) return graph diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py index e160497a74444fd34eea50943eb0430875fe1252..9cf1961d45ff96c987ebe81bb0897581be32947c 100644 --- a/colossalai/fx/tracer/_tracer_utils.py +++ b/colossalai/fx/tracer/_tracer_utils.py @@ -2,10 +2,10 @@ from typing import Any, List, Union import torch -from ..proxy import ColoAttribute, ColoProxy -from .meta_patch import meta_patched_function, meta_patched_module +from ..proxy import ColoProxy +from .meta_patch import meta_patched_function -__all__ = ['is_element_in_list', 'extract_meta'] +__all__ = ["is_element_in_list", "extract_meta"] def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]): @@ -21,7 +21,6 @@ def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]): def extract_meta(*args, **kwargs): - def _convert(val): if isinstance(val, ColoProxy): return val.meta_data diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py index 859a19bf6241bbbf4061e0f7564975682527b8c2..84c09109877ece5c129bdfd862d160631cec7719 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py @@ -1,7 +1,4 @@ -import operator - import torch -import torch.nn.functional as F from ...registry import bias_addition_function, bias_addition_method from .bias_addition_function import LinearBasedBiasFunc @@ -10,13 +7,12 @@ from .bias_addition_function import LinearBasedBiasFunc @bias_addition_method.register(torch.Tensor.addbmm) @bias_addition_function.register(torch.addbmm) class Addbmm(LinearBasedBiasFunc): - def extract_kwargs_from_origin_func(self): kwargs = {} - if 'beta' in self.kwargs: - kwargs['beta'] = self.kwargs['beta'] - if 'alpha' in self.kwargs: - kwargs['alpha'] = self.kwargs['alpha'] + if "beta" in self.kwargs: + kwargs["beta"] = self.kwargs["beta"] + if "alpha" in self.kwargs: + kwargs["alpha"] = self.kwargs["alpha"] return kwargs def create_non_bias_func_proxy(self, input_proxy, other_proxy): @@ -25,7 +21,7 @@ class Addbmm(LinearBasedBiasFunc): compute the main computation, such as convolution, with bias option banned. """ assert self.substitute_func == torch.bmm - node_kind = 'call_function' + node_kind = "call_function" node_target = self.substitute_func node_args = (input_proxy, other_proxy) @@ -35,10 +31,10 @@ class Addbmm(LinearBasedBiasFunc): return non_bias_func_proxy def insert_sum_node(self, input_proxy, sum_dims=0): - ''' + """ This method is used to sum the input_proxy through the sum_dims. - ''' - node_kind = 'call_function' + """ + node_kind = "call_function" node_target = torch.sum node_args = (input_proxy, sum_dims) node_kwargs = {} @@ -55,15 +51,15 @@ class Addbmm(LinearBasedBiasFunc): sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy) kwargs = self.extract_kwargs_from_origin_func() - if 'beta' in kwargs: - beta = kwargs['beta'] + if "beta" in kwargs: + beta = kwargs["beta"] # doing the multiplication with beta if it exists(temp_2 = beta * input) beta_proxy = self.create_mul_node(self.args[0], beta) else: beta_proxy = self.args[0] - if 'alpha' in kwargs: - alpha = kwargs['alpha'] + if "alpha" in kwargs: + alpha = kwargs["alpha"] # doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1) alpha_proxy = self.create_mul_node(alpha, sum_proxy) else: diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py index fe7d8d07aac941028d7c682043b5af2bcdf2537a..d087b291300594b1fdc15be4c8d6aad35aba023c 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py @@ -1,7 +1,4 @@ -import operator - import torch -import torch.nn.functional as F from ...registry import bias_addition_function, bias_addition_method from .bias_addition_function import LinearBasedBiasFunc @@ -10,17 +7,16 @@ from .bias_addition_function import LinearBasedBiasFunc @bias_addition_method.register(torch.Tensor.addmm) @bias_addition_function.register(torch.addmm) class Addmm(LinearBasedBiasFunc): - def extract_kwargs_from_origin_func(self): kwargs = {} - if 'beta' in self.kwargs: - kwargs['beta'] = self.kwargs['beta'] - if 'alpha' in self.kwargs: - kwargs['alpha'] = self.kwargs['alpha'] + if "beta" in self.kwargs: + kwargs["beta"] = self.kwargs["beta"] + if "alpha" in self.kwargs: + kwargs["alpha"] = self.kwargs["alpha"] return kwargs def transpose_other_operand_for_linear(self, other_proxy): - ''' + """ This method is used to transpose the other operand for linear function. For example: input = torch.rand(3, 4) @@ -30,8 +26,8 @@ class Addmm(LinearBasedBiasFunc): # To keep the computation graph consistent with the origin computation graph, we need to transpose the m2 # before we call the linear function. new_output = torch.linear(m1, m2.transpose(0, 1)) + input - ''' - node_kind = 'call_function' + """ + node_kind = "call_function" node_target = torch.transpose node_args = (other_proxy, 0, 1) node_kwargs = {} @@ -43,14 +39,14 @@ class Addmm(LinearBasedBiasFunc): non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy) kwargs = self.extract_kwargs_from_origin_func() - if 'beta' in kwargs: - beta = kwargs['beta'] + if "beta" in kwargs: + beta = kwargs["beta"] beta_proxy = self.create_mul_node(self.args[0], beta) else: beta_proxy = self.args[0] - if 'alpha' in kwargs: - alpha = kwargs['alpha'] + if "alpha" in kwargs: + alpha = kwargs["alpha"] alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy) else: alpha_proxy = non_bias_linear_func_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py index 8a3786332c08d9a3320ce4c7bee8221dd3d10abd..42178b7b786e95493893a6464aedba7b10e9b34e 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py @@ -29,7 +29,6 @@ class BiasAdditionFunc(ABC): to insert two more operator.mul nodes for the computation graph to compute the final result. """ - pass @abstractmethod def generate(self): @@ -50,7 +49,6 @@ class BiasAdditionFunc(ABC): %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) """ - pass def create_mul_node(self, input_proxy, coefficent): """ @@ -59,7 +57,7 @@ class BiasAdditionFunc(ABC): Therefore, we need to use this method insert two more operator.mul nodes for the computation graph to compute the final result. """ - node_kind = 'call_function' + node_kind = "call_function" node_target = operator.mul node_args = ( input_proxy, @@ -82,7 +80,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc): compute the main computation, such as convolution, with bias option banned. """ assert self.substitute_func == torch.nn.functional.linear - node_kind = 'call_function' + node_kind = "call_function" node_target = self.substitute_func node_args = (input_proxy, other_proxy) @@ -96,7 +94,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc): This method is used to create the bias_addition_proxy, the node created by this proxy will compute the sum of non_bias_func result and bias with some reshape operation if needed. """ - bias_add_node_kind = 'call_function' + bias_add_node_kind = "call_function" bias_add_node_target = operator.add bias_add_args = (non_bias_func_proxy, bias_proxy) bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py index e11ec0a364f1e5ee1445c97ae0c9b054d02bcfa2..ed060a350739ad31658264807f060595b0c5e217 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py @@ -1,6 +1,3 @@ -import operator - -import torch import torch.nn.functional as F from ...registry import bias_addition_function @@ -9,17 +6,16 @@ from .bias_addition_function import LinearBasedBiasFunc @bias_addition_function.register(F.linear) class Linear(LinearBasedBiasFunc): - def extract_kwargs_from_origin_func(self): - assert 'bias' in self.kwargs + assert "bias" in self.kwargs kwargs = {} - if 'bias' in self.kwargs: - kwargs['bias'] = self.kwargs['bias'] + if "bias" in self.kwargs: + kwargs["bias"] = self.kwargs["bias"] return kwargs def generate(self): non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1]) kwargs = self.extract_kwargs_from_origin_func() - bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias']) + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs["bias"]) return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py index 85f1553e304c9c45b2b8f1373e76e7023d452d47..19c0e21d7c17637f4ad7956ce62e35f46328930d 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py @@ -27,8 +27,8 @@ class BiasAdditionModule(ABC): Note: this function will be invoked during module initializing, you should never call this function. """ - weight_node_kind = 'get_attr' - weight_node_target = self.target + '.weight' + weight_node_kind = "get_attr" + weight_node_target = self.target + ".weight" weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {}) return weight_proxy @@ -39,8 +39,8 @@ class BiasAdditionModule(ABC): Note: this function will be invoked during module initializing, you should never call this function. """ - bias_node_kind = 'get_attr' - bias_node_target = self.target + '.bias' + bias_node_kind = "get_attr" + bias_node_target = self.target + ".bias" bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {}) return bias_proxy @@ -51,17 +51,16 @@ class BiasAdditionModule(ABC): For example: The kwargs for conv2d module is {} because the attributes like 'padding' or 'groups' are - considered during module initilizing. However, we need to consider those attributes as kwargs + considered during module initializing. However, we need to consider those attributes as kwargs in F.conv2d. """ - pass def create_non_bias_func_proxy(self, input_proxy=None): """ This method is used to create the non_bias_func proxy, the node created by this proxy will compute the main computation, such as convolution, with bias option banned. """ - node_kind = 'call_function' + node_kind = "call_function" node_target = self.substitute_func if input_proxy is None: input_proxy = self.args[0] @@ -75,7 +74,7 @@ class BiasAdditionModule(ABC): This method is used to create the bias_addition_proxy, the node created by this proxy will compute the sum of non_bias_func result and bias with some reshape operation if needed. """ - bias_add_node_kind = 'call_function' + bias_add_node_kind = "call_function" bias_add_node_target = operator.add bias_add_args = (non_bias_func_proxy, bias_proxy) bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) @@ -100,7 +99,6 @@ class BiasAdditionModule(ABC): %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) """ - pass module_to_func_dict = { diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py index 4b6c82a74f57d213ba3d8b68863053ddd1aabf9d..812a141c1eab699a780e6513ca3c625cd92f44dd 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py @@ -1,6 +1,5 @@ import torch -import torch.nn.functional as F -from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple +from torch.nn.modules.utils import _pair, _single, _triple from ...registry import bias_addition_module from .bias_addition_module import BiasAdditionModule @@ -10,17 +9,16 @@ from .bias_addition_module import BiasAdditionModule @bias_addition_module.register(torch.nn.Conv2d) @bias_addition_module.register(torch.nn.Conv3d) class BiasAdditionConv(BiasAdditionModule): - def extract_kwargs_from_mod(self): root = self.tracer.root conv_module = root.get_submodule(self.target) - kwarg_attributes = ['groups', 'dilation', 'stride'] + kwarg_attributes = ["groups", "dilation", "stride"] non_bias_kwargs = {} for attr_name in kwarg_attributes: if hasattr(conv_module, attr_name): non_bias_kwargs[attr_name] = getattr(conv_module, attr_name) if conv_module.padding_mode != "zeros": - #TODO: non zeros mode requires some extra processing for input + # TODO: non zeros mode requires some extra processing for input conv_type = type(conv_module) if conv_type == "torch.nn.Conv1d": padding_element = _single(0) @@ -28,9 +26,9 @@ class BiasAdditionConv(BiasAdditionModule): padding_element = _pair(0) elif conv_type == "torch.nn.Conv3d": padding_element = _triple(0) - non_bias_kwargs['padding'] = padding_element + non_bias_kwargs["padding"] = padding_element else: - non_bias_kwargs['padding'] = getattr(conv_module, 'padding') + non_bias_kwargs["padding"] = getattr(conv_module, "padding") return non_bias_kwargs @@ -41,11 +39,12 @@ class BiasAdditionConv(BiasAdditionModule): """ bias_shape = [1] * (dimensions - 1) bias_shape[0] = -1 - bias_reshape_node_kind = 'call_method' - bias_reshape_node_target = 'view' + bias_reshape_node_kind = "call_method" + bias_reshape_node_target = "view" bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape)) - bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target, - bias_reshape_node_args, {}) + bias_reshape_proxy = self.tracer.create_proxy( + bias_reshape_node_kind, bias_reshape_node_target, bias_reshape_node_args, {} + ) return bias_reshape_proxy def generate(self): diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py index f6f7b6ddab401a637aa2b43b4dd8d2ce9193266e..b397f009846c082a8364ac0eeafa4bf085e4cf63 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F from ...registry import bias_addition_module from .bias_addition_module import BiasAdditionModule @@ -7,7 +6,6 @@ from .bias_addition_module import BiasAdditionModule @bias_addition_module.register(torch.nn.Linear) class BiasAdditionLinear(BiasAdditionModule): - def extract_kwargs_from_mod(self): return {} diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py index 88b65b6188fa67be33f7a15b55e7fd5d32d7c4cc..e6e511b72fbbc7b0e9d7efd0af25e00935a7bea1 100644 --- a/colossalai/fx/tracer/experimental.py +++ b/colossalai/fx/tracer/experimental.py @@ -1,4 +1,3 @@ -import enum import functools import inspect import operator @@ -10,7 +9,7 @@ from torch.fx import Graph, Node, Proxy, Tracer from torch.utils._pytree import tree_map from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta -from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list +from colossalai.fx.tracer._tracer_utils import is_element_in_list from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict from colossalai.fx.tracer.registry import ( bias_addition_function, @@ -24,31 +23,45 @@ if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor Target = Union[Callable[..., Any], str] -Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types - List[Any], # actually Argument - Dict[str, Any], # actually Argument - slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing - 'Node',]] -_CScriptMethod = ['add', 'mul', 'sub', 'div'] +Argument = Optional[ + Union[ + Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types + List[Any], # actually Argument + Dict[str, Any], # actually Argument + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + "Node", + ] +] +_CScriptMethod = ["add", "mul", "sub", "div"] _TorchNewMethod = [ - "arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor", - "finfo" + "arange", + "zeros", + "zeros_like", + "ones", + "ones_like", + "full", + "full_like", + "empty", + "empty_like", + "eye", + "tensor", + "finfo", ] _TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"] def _truncate_suffix(s: str): import re - return re.sub(r'_\d+$', '', s) + + return re.sub(r"_\d+$", "", s) def default_device(): - return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") @compatibility(is_backward_compatible=False) class ColoProxy(Proxy): - def __init__(self, *args, data=None, **kwargs): super().__init__(*args, **kwargs) self._meta_data = data @@ -100,7 +113,7 @@ class ColoProxy(Proxy): return ColoAttribute(self, k, getattr(self._meta_data, k, None)) def __setitem__(self, key, value): - proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) + proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {}) proxy.meta_data = self._meta_data return proxy @@ -125,29 +138,28 @@ class ColoProxy(Proxy): @property def device(self): - proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {}) + proxy = self.tracer.create_proxy("call_function", getattr, (self, "device"), {}) proxy.meta_data = self.meta_data.device return proxy @property def dtype(self): - proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {}) + proxy = self.tracer.create_proxy("call_function", getattr, (self, "dtype"), {}) proxy.meta_data = self.meta_data.dtype return proxy def to(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs}) + return self.tracer.create_proxy("call_method", "to", (self, *args), {**kwargs}) def cpu(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs}) + return self.tracer.create_proxy("call_method", "cpu", (self, *args), {**kwargs}) def cuda(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs}) + return self.tracer.create_proxy("call_method", "cuda", (self, *args), {**kwargs}) @compatibility(is_backward_compatible=False) class ColoAttribute(ColoProxy): - def __init__(self, root, attr: str, data=None): self.root = root self.attr = attr @@ -160,11 +172,11 @@ class ColoAttribute(ColoProxy): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node return self._node def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs) def __repr__(self): return f"ColoAttribute({self.node.name}, attr={self.attr})" @@ -172,7 +184,6 @@ class ColoAttribute(ColoProxy): @compatibility(is_backward_compatible=False) class ColoTracer(Tracer): - def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs): super().__init__(*args, **kwargs) self._disable_module_getattr = False @@ -184,24 +195,28 @@ class ColoTracer(Tracer): self.inside_torch_checkpoint_func = False self.act_ckpt_region_count = 0 - def proxy(self, node: Node) -> 'ColoProxy': + def proxy(self, node: Node) -> "ColoProxy": return ColoProxy(node, self) - def create_proxy(self, - kind: str, - target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - proxy_factory_fn: Callable[[Node], 'Proxy'] = None): - + def create_proxy( + self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], "Proxy"] = None, + ): proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p - if kind == 'placeholder': - proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( - _truncate_suffix(target), None) - elif kind == 'get_attr': + if kind == "placeholder": + proxy.meta_data = ( + self.meta_args[target] + if target in self.meta_args + else self.concrete_args.get(_truncate_suffix(target), None) + ) + elif kind == "get_attr": self._disable_module_getattr = True try: attr_itr = self.root @@ -211,20 +226,21 @@ class ColoTracer(Tracer): proxy.meta_data = attr_itr finally: self._disable_module_getattr = False - elif kind == 'call_function': + elif kind == "call_function": proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_method': + elif kind == "call_method": self._disable_module_getattr = True try: - if target == '__call__': + if target == "__call__": proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) else: if target not in _TensorPropertyMethod: - proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), - **tree_map(unwrap_fn, kwargs)) + proxy._meta_data = getattr(unwrap_fn(args[0]), target)( + *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs) + ) finally: self._disable_module_getattr = False - elif kind == 'call_module': + elif kind == "call_module": mod = self.root.get_submodule(target) self._disable_module_getattr = True try: @@ -238,14 +254,15 @@ class ColoTracer(Tracer): if self.inside_torch_checkpoint_func: # annotate the activation checkpoint module - node.meta['activation_checkpoint'] = self.act_ckpt_region_count + node.meta["activation_checkpoint"] = self.act_ckpt_region_count return node - def trace(self, - root: torch.nn.Module, - concrete_args: Optional[Dict[str, torch.Tensor]] = None, - meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: - + def trace( + self, + root: torch.nn.Module, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, + meta_args: Optional[Dict[str, torch.Tensor]] = None, + ) -> Graph: if meta_args is None: meta_args = {} @@ -260,20 +277,19 @@ class ColoTracer(Tracer): # update concrete args with default values non_meta_arg_names = sig_names - meta_arg_names for k, v in sig.parameters.items(): - if k in non_meta_arg_names and \ - k not in concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default # get non concrete arg names concrete_arg_names = set(concrete_args.keys()) - non_concrete_arg_names = sig_names - concrete_arg_names + sig_names - concrete_arg_names def _check_arg_name_valid(names): success, element = is_element_in_list(names, sig_names) if not success: raise KeyError( - f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function" + ) _check_arg_name_valid(meta_arg_names) _check_arg_name_valid(concrete_arg_names) @@ -292,10 +308,9 @@ class ColoTracer(Tracer): orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction class PatchedCheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): - # signal that the current tracing occurs within activaton checkpoint part + # signal that the current tracing occurs within activation checkpoint part self.inside_torch_checkpoint_func = True out = run_function(*args) self.inside_torch_checkpoint_func = False @@ -305,7 +320,8 @@ class ColoTracer(Tracer): @staticmethod def backward(ctx: Any, *grad_outputs: Any) -> Any: raise NotImplementedError( - "We do not implement the backward pass as we only trace the forward pass.") + "We do not implement the backward pass as we only trace the forward pass." + ) # override the checkpoint function torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction @@ -356,10 +372,13 @@ class ColoTracer(Tracer): if attr_val is p: if n not in parameter_proxy_cache: kwargs = {} - if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: - kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else - lambda node: ColoProxy(self, node, n, attr_val)) - val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] + if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ColoProxy(self, node, n, attr_val) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] parameter_proxy_cache[n] = val_proxy return parameter_proxy_cache[n] return None @@ -370,8 +389,9 @@ class ColoTracer(Tracer): return maybe_buffer_proxy if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), - parameter_proxy_cache) + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) if maybe_parameter_proxy is not None: return maybe_parameter_proxy @@ -389,42 +409,41 @@ def symbolic_trace( if meta_args is not None: root.to(default_device()) wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x - graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, - concrete_args=concrete_args, - meta_args=tree_map(wrap_fn, meta_args)) + graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace( + root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args) + ) root.cpu() else: graph = Tracer().trace(root, concrete_args=concrete_args) else: from .tracer import ColoTracer as OrigColoTracer - graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, - concrete_args=concrete_args, - meta_args=meta_args) + + graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace( + root, concrete_args=concrete_args, meta_args=meta_args + ) name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ return ColoGraphModule(root, graph, name) @compatibility(is_backward_compatible=False) class _TorchTensorOverride(object): - def __init__(self, tracer: Tracer): self.overrides = {} self.tracer = tracer def __enter__(self): - def wrap_tensor_method(target): - @functools.wraps(target) def wrapper(*args, **kwargs): is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( - isinstance(p, ColoProxy) for p in kwargs.values()) + isinstance(p, ColoProxy) for p in kwargs.values() + ) if is_proxy: # if the arg is a proxy, then need to record this function called on this proxy # e.g. torch.ones(size) where size is an input proxy self.tracer._disable_module_getattr = True try: - proxy = self.tracer.create_proxy('call_function', target, args, kwargs) + proxy = self.tracer.create_proxy("call_function", target, args, kwargs) finally: self.tracer._disable_module_getattr = False return proxy @@ -446,11 +465,12 @@ class _TorchTensorOverride(object): setattr(torch, name, orig) -def meta_prop_pass(gm: ColoGraphModule, - root: torch.nn.Module, - meta_args: Optional[Dict[str, Any]] = None, - concrete_args: Optional[Dict[str, torch.Tensor]] = None): - +def meta_prop_pass( + gm: ColoGraphModule, + root: torch.nn.Module, + meta_args: Optional[Dict[str, Any]] = None, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, +): if meta_args is None: meta_args = {} @@ -465,36 +485,36 @@ def meta_prop_pass(gm: ColoGraphModule, # update concrete args with default values non_meta_arg_names = sig_names - meta_arg_names for k, v in sig.parameters.items(): - if k in non_meta_arg_names and \ - k not in concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default for node in gm.graph.nodes: - node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args, - node.kwargs) + node._meta_data = _meta_data_computing( + meta_args, concrete_args, root, node.op, node.target, node.args, node.kwargs + ) def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs): unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n - if kind == 'placeholder': + if kind == "placeholder": meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None) - elif kind == 'get_attr': + elif kind == "get_attr": attr_itr = root atoms = target.split(".") for atom in atoms: attr_itr = getattr(attr_itr, atom) meta_out = attr_itr - elif kind == 'call_function': + elif kind == "call_function": meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_method': - if target == '__call__': + elif kind == "call_method": + if target == "__call__": meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) else: if target not in _TensorPropertyMethod: - meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), - **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_module': + meta_out = getattr(unwrap_fn(args[0]), target)( + *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs) + ) + elif kind == "call_module": mod = root.get_submodule(target) meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) else: @@ -603,26 +623,30 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar if kind == "call_function": if bias_addition_function.has(target): if target == torch.nn.functional.linear: - if 'bias' in kwargs and kwargs['bias'] is not None: + if "bias" in kwargs and kwargs["bias"] is not None: function_to_substitute = func_to_func_dict[target] - handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_function.get(target)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) else: function_to_substitute = func_to_func_dict[target] - handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_function.get(target)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) elif bias_addition_function.has(target.__name__): # use name for some builtin op like @ (matmul) function_to_substitute = func_to_func_dict[target] - handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_function.get(target.__name__)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) elif kind == "call_method": method = getattr(args_metas[0].__class__, target) if bias_addition_method.has(method): function_to_substitute = method_to_func_dict[method] - handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_method.get(method)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) elif kind == "call_module": # if not hasattr(self, "orig_forward"): @@ -631,8 +655,9 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar mod_type = type(mod) if bias_addition_module.has(mod_type) and mod.bias is not None: function_to_substitute = module_to_func_dict[mod_type] - handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_module.get(mod_type)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) if handle is not None: handle.generate() diff --git a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py index 12c42514895e61777f0dbb1c348206af7d0ac5ec..75d7b18a067c0248f7236190f3b1e72e4cf18ede 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py @@ -5,4 +5,4 @@ from ...registry import meta_patched_function @meta_patched_function.register(torch.nn.functional.relu) def torch_nn_func_relu(input, inplace=False): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py index 042b92c5847a4ed0d78d4acf068a0a5030fa7089..3475f22e3b194559d29d1820f9768542e4e7e858 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py @@ -4,7 +4,7 @@ from ...registry import meta_patched_function @meta_patched_function.register(torch.matmul) -@meta_patched_function.register('matmul') # for built-in op @ +@meta_patched_function.register("matmul") # for built-in op @ def torch_matmul(input, other, *, out=None): # copied from huggingface.utils.fx d1 = input.dim() @@ -44,8 +44,8 @@ def torch_matmul(input, other, *, out=None): @meta_patched_function.register(torch.abs) def torch_abs(input, *, out=None): - assert out is None, 'out is not supported yet' - return torch.empty(input.shape, device='meta') + assert out is None, "out is not supported yet" + return torch.empty(input.shape, device="meta") @meta_patched_function.register(torch.bmm) @@ -89,7 +89,7 @@ def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): @meta_patched_function.register(torch.var_mean) def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None): - assert out is None, 'saving to out is not supported yet' - var = torch.empty(1).squeeze(0).to('meta') - mean = torch.empty(1).squeeze(0).to('meta') + assert out is None, "saving to out is not supported yet" + var = torch.empty(1).squeeze(0).to("meta") + mean = torch.empty(1).squeeze(0).to("meta") return var, mean diff --git a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py index 8500e5c82508195ca3ec8ebc99a33ad2a1b946ad..26daf32a2afc9d3f9367e0e599ce9d34ba7417dc 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py @@ -8,7 +8,6 @@ from ...registry import meta_patched_function def _ntuple(n, name="parse"): - def parse(x): if isinstance(x, collections.abc.Iterable): return tuple(x) @@ -24,21 +23,21 @@ _triple = _ntuple(3, "_triple") def _extract_kwargs(kwargs): - if 'stride' in kwargs: - stride = kwargs['stride'] + if "stride" in kwargs: + stride = kwargs["stride"] else: stride = 1 # TODO: process str type padding - if 'padding' in kwargs: - padding = kwargs['padding'] + if "padding" in kwargs: + padding = kwargs["padding"] else: padding = 0 - if 'dilation' in kwargs: - dilation = kwargs['dilation'] + if "dilation" in kwargs: + dilation = kwargs["dilation"] else: dilation = 1 - if 'output_padding' in kwargs: - output_padding = kwargs['output_padding'] + if "output_padding" in kwargs: + output_padding = kwargs["output_padding"] else: output_padding = 0 @@ -61,7 +60,7 @@ def torch_nn_functional_conv1d(input, weight, **kwargs): c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv2d) @@ -82,7 +81,7 @@ def torch_nn_functional_conv2d(input, weight, **kwargs): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv3d) @@ -105,7 +104,7 @@ def torch_nn_functional_conv3d(input, weight, **kwargs): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv_transpose1d) @@ -120,13 +119,14 @@ def torch_nn_functional_convtranspose1d(input, weight, **kwargs): kernel_size = weight.shape[2:] l_in = input.shape[-1] c_out = weight.shape[1] - l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + - output_padding[0] + 1) + l_out = math.floor( + (l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv_transpose2d) @@ -141,16 +141,18 @@ def torch_nn_functional_convtranspose2d(input, weight, **kwargs): kernel_size = weight.shape[2:] h_in, w_in = input.shape[-2:] c_out = weight.shape[1] - h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + - output_padding[0] + 1) - w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + - output_padding[1] + 1) + h_out = math.floor( + (h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1 + ) + w_out = math.floor( + (w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv_transpose3d) @@ -165,16 +167,19 @@ def torch_nn_functional_convtranspose3d(input, weight, **kwargs): kernel_size = weight.shape[2:] d_in, h_in, w_in = input.shape[-3:] c_out = weight.shape[1] - d_out = math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + - output_padding[0] + 1) - h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + - output_padding[1] + 1) - w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + - output_padding[2] + 1) + d_out = math.floor( + (d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1 + ) + h_out = math.floor( + (h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1 + ) + w_out = math.floor( + (w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + output_padding[2] + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py index 6d8d864ea29acd648f7f7097821c248655b0191e..27a79f18590a8e192909fe867b67391941bd44eb 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py @@ -4,11 +4,7 @@ from ...registry import meta_patched_function @meta_patched_function.register(torch.nn.functional.embedding) -def torch_nn_functional_embedding(input, - weight, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False): +def torch_nn_functional_embedding( + input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False +): return torch.empty(*input.shape, weight.shape[-1], device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py index e9e7eda6159c88ca3a82320c841360df87540c82..8a62149908306a8ca61dcc40a98bf0cf0b7a6152 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py @@ -5,16 +5,11 @@ from ...registry import meta_patched_function @meta_patched_function.register(torch.nn.functional.layer_norm) def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") @meta_patched_function.register(torch.nn.functional.batch_norm) -def torch_nn_func_batchnorm(input, - running_mean, - running_var, - weight=None, - bias=None, - training=False, - momentum=0.1, - eps=1e-05): - return torch.empty(input.shape, device='meta') +def torch_nn_func_batchnorm( + input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05 +): + return torch.empty(input.shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py index 4c171cb1099119de54b70b03097e1781880f2624..7642934a409bcafb57370267969b8a384db079b0 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py @@ -19,9 +19,9 @@ def operator_getitem(a, b): return t def _slice_convert(slice_obj): - attrs = {'start': slice_obj.start, 'stop': slice_obj.stop, 'step': slice_obj.step} + attrs = {"start": slice_obj.start, "stop": slice_obj.stop, "step": slice_obj.step} new_attrs = _slice_attr_convert(attrs) - attr_dict_to_tuple = (new_attrs['start'], new_attrs['stop'], new_attrs['step']) + attr_dict_to_tuple = (new_attrs["start"], new_attrs["stop"], new_attrs["step"]) return slice(*attr_dict_to_tuple) def _slice_attr_convert(attrs): diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py index b14ff10ce1377055ed7f3ab3025ee7c05c6a1657..c61e1c4dc9e1131a587a79a00eed740737c8f000 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py @@ -105,14 +105,15 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None): shapes = [t.shape for t in tensors] shape = list(shapes[0]) concatenated_dim = sum(shape[dim] for shape in shapes) - final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:] + final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :] return torch.empty(final_shape, device="meta") @meta_patched_function.register(torch.repeat_interleave) def torch_repeat_interleave(input, repeats, dim=None, output_size=None): - assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \ - "Argument 'repeats' should be of type 'torch.Tensor' or 'int'" + assert isinstance(repeats, int) or isinstance( + repeats, torch.Tensor + ), "Argument 'repeats' should be of type 'torch.Tensor' or 'int'" shape = list(input.shape) if dim is not None else [input.numel()] dim = dim if dim is not None else 0 @@ -132,36 +133,36 @@ def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None) @meta_patched_function.register(torch.roll) def torch_roll(input, shifts, dims=None): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") @meta_patched_function.register(torch.full) def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): - assert out is None, 'assigning result to out is not supported yet' - return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad) + assert out is None, "assigning result to out is not supported yet" + return torch.empty(size, device="meta", dtype=dtype, layout=layout, requires_grad=requires_grad) @meta_patched_function.register(torch.max) def torch_max(input, dim=None, keepdim=False, *, out=None): - assert out is None, 'assigning value to out is not supported yet' + assert out is None, "assigning value to out is not supported yet" if dim is not None: if isinstance(dim, int): shape = list(input.shape) shape.pop(dim) if keepdim: shape.insert(dim, 1) - return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape, - device='meta', - dtype=input.dtype) + return torch.empty(shape, device="meta", dtype=input.dtype), torch.empty( + shape, device="meta", dtype=input.dtype + ) elif isinstance(dim, torch.Tensor): # when dim is a 0D or 1D tensor, it will maintain the same shape num_dims = dim.dim() if num_dims in [0, 1]: - return torch.empty_like(input, device='meta') + return torch.empty_like(input, device="meta") else: raise ValueError(f"Expected dim to a 0D or 1D tensor but got {num_dims} dimensions") else: - return torch.empty([], device='meta', dtype=input.dtype) + return torch.empty([], device="meta", dtype=input.dtype) @meta_patched_function.register(torch.Tensor.cpu) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py index e28e52585fffc193473a7c8270c103919cc63e0d..3f40ec2a67ee61963c80f15ce6a271ac70d2772b 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py @@ -4,4 +4,4 @@ from .embedding import * from .linear import * from .normalization import * from .pooling import * -from .rnn import * \ No newline at end of file +from .rnn import * diff --git a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py index d03da6588c1cbf56403dccc5989f4a4987b7e2a9..aa2ede187d37b832bd43324204a098c3176174a6 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py @@ -10,4 +10,4 @@ from ...registry import meta_patched_module @meta_patched_module.register(torch.nn.ReLU6) @meta_patched_module.register(torch.nn.PReLU) def torch_nn_non_linear_act(self, input): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py index cf9f3487aac9f31ff799e3245132c40d643de64f..35173a68a0be6edefef8bb45b951733e34fc655e 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py @@ -11,13 +11,14 @@ def torch_nn_conv1d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d l_in = input.shape[-1] c_out = self.out_channels - l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + l_out = math.floor( + (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.Conv2d) @@ -26,16 +27,18 @@ def torch_nn_conv2d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d h_in, w_in = input.shape[-2:] c_out = self.out_channels - h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + h_out = math.floor( + (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.Conv3d) @@ -44,19 +47,22 @@ def torch_nn_conv3d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d d_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) - w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * - (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) + d_out = math.floor( + (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + h_out = math.floor( + (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.ConvTranspose1d) @@ -65,13 +71,18 @@ def torch_nn_convtranspose1d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html l_in = input.shape[-1] c_out = self.out_channels - l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + l_out = math.floor( + (l_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.ConvTranspose2d) @@ -80,16 +91,26 @@ def torch_nn_convtranspose2d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html h_in, w_in = input.shape[-2:] c_out = self.out_channels - h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + h_out = math.floor( + (h_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.ConvTranspose3d) @@ -98,16 +119,31 @@ def torch_nn_convtranspose3d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html d_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) - w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * - (self.kernel_size[2] - 1) + self.output_padding[2] + 1) + d_out = math.floor( + (d_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + h_out = math.floor( + (h_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[2] + - 2 * self.padding[2] + + self.dilation[2] * (self.kernel_size[2] - 1) + + self.output_padding[2] + + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py index 999e33b17c1c7b442d2a6db73f957be4413f1fa1..f28647e9caa576cfae44438a287c57bf54efefe1 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py @@ -6,4 +6,4 @@ from ...registry import meta_patched_module @meta_patched_module.register(torch.nn.Embedding) def torch_nn_embedding(self, input): result_shape = input.shape + (self.embedding_dim,) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/linear.py b/colossalai/fx/tracer/meta_patch/patched_module/linear.py index 56f13bf97532e26770a0be7a4226ee69a2124ee5..97e6b0e96e831f7b0af3155edbae73ec04603ce0 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/linear.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/linear.py @@ -6,5 +6,7 @@ from ...registry import meta_patched_module @meta_patched_module.register(torch.nn.Linear) def torch_nn_linear(self, input): last_dim = input.shape[-1] - assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch' + assert ( + last_dim == self.in_features + ), f"Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch" return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py index c21ff64cf3dec9baf357771fa0d15b341b413ac1..198e72e342b13d32610a16e2f71a410c70f395b5 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py @@ -23,6 +23,7 @@ def torch_nn_normalize(self, input): try: import apex + meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py index 7ce23fbf7ac9368f4ec8496b252494e779fb5015..450586d02f8f294730412be0ba0718b0146421bf 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py @@ -8,7 +8,7 @@ from ...registry import meta_patched_module @meta_patched_module.register(torch.nn.AvgPool1d) def torch_nn_avgpool1d(self, input): num_dim = input.dim() - assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' + assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions" l_in = input.shape[-1] @@ -25,13 +25,13 @@ def torch_nn_avgpool1d(self, input): l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) result_shape = tuple(input.shape[:-1]) + (l_out,) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AvgPool2d) def torch_nn_avgpool2d(self, input): num_dim = input.dim() - assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' + assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions" h_in, w_in = input.shape[-2:] @@ -52,13 +52,13 @@ def torch_nn_avgpool2d(self, input): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AvgPool3d) def torch_nn_avgpool3d(self, input): num_dim = input.dim() - assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' + assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions" d_in, h_in, w_in = input.shape[-3:] @@ -81,13 +81,13 @@ def torch_nn_avgpool3d(self, input): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.MaxPool1d) def torch_nn_maxpool1d(self, input): num_dim = input.dim() - assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' + assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions" l_in = input.shape[-1] @@ -105,13 +105,13 @@ def torch_nn_maxpool1d(self, input): l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) result_shape = tuple(input.shape[:-1]) + (l_out,) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.MaxPool2d) def torch_nn_maxpool2d(self, input): num_dim = input.dim() - assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' + assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions" h_in, w_in = input.shape[-2:] @@ -133,13 +133,13 @@ def torch_nn_maxpool2d(self, input): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.MaxPool3d) def torch_nn_maxpool3d(self, input): num_dim = input.dim() - assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' + assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions" d_in, h_in, w_in = input.shape[-3:] @@ -163,7 +163,7 @@ def torch_nn_maxpool3d(self, input): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AdaptiveAvgPool1d) @@ -175,7 +175,7 @@ def torch_nn_adapative_pooling_1d(self, input): else: output_size = self.output_size result_shape = tuple(input.shape[:-1]) + output_size - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AdaptiveAvgPool2d) @@ -187,7 +187,7 @@ def torch_nn_adapative_pooling_2d(self, input): else: output_size = self.output_size result_shape = tuple(input.shape[:-2]) + output_size - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AdaptiveAvgPool3d) @@ -199,4 +199,4 @@ def torch_nn_adapative_pooling_3d(self, input): else: output_size = self.output_size result_shape = tuple(input.shape[:-3]) + output_size - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py index ee15ca34162e83612eb179e0cff066d9f06faf36..bfb7ed17118604acef1958344635118b87bbf185 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from ...registry import meta_patched_module @@ -8,9 +6,11 @@ from ...registry import meta_patched_module @meta_patched_module.register(torch.nn.GRU) @meta_patched_module.register(torch.nn.RNN) def torch_nn_rnn(self, input, hx): - assert input.shape[ - -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch' - assert hx.shape[ - -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch' + assert ( + input.shape[-1] == self.input_size + ), f"Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch" + assert ( + hx.shape[-1] == self.hidden_size + ), f"Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch" d = 2 if self.bidirectional else 1 return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx diff --git a/colossalai/fx/tracer/registry.py b/colossalai/fx/tracer/registry.py index 12fc6de73d4435dea8ec58fa50b93a6070fd6254..80b3868bb4feaadbd96d448ce32870bc7e3dad75 100644 --- a/colossalai/fx/tracer/registry.py +++ b/colossalai/fx/tracer/registry.py @@ -1,11 +1,9 @@ class PatchRegistry: - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): self.store[source] = func return func @@ -21,8 +19,8 @@ class PatchRegistry: return source in self.store -meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution') -meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution') -bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition') -bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition') -bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition') +meta_patched_function = PatchRegistry(name="patched_functions_for_meta_execution") +meta_patched_module = PatchRegistry(name="patched_modules_for_meta_execution") +bias_addition_function = PatchRegistry(name="patched_function_for_bias_addition") +bias_addition_module = PatchRegistry(name="patched_module_for_bias_addition") +bias_addition_method = PatchRegistry(name="patched_method_for_bias_addition") diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py index 1ae31f9589756d4552f6c2247e0b71679625832b..d9cb587b5d39165c8cc86231c12a399ab588a533 100644 --- a/colossalai/fx/tracer/tracer.py +++ b/colossalai/fx/tracer/tracer.py @@ -29,7 +29,7 @@ from .registry import ( meta_patched_module, ) -__all__ = ['ColoTracer'] +__all__ = ["ColoTracer"] class TracerType(enum.Enum): @@ -92,7 +92,7 @@ class ColoTracer(Tracer): return proxy # if graph is traced for auto parallelism module, some extra node will be added during - # graph construction to deal with the compatability between bias addition and all reduce. + # graph construction to deal with the compatibility between bias addition and all reduce. # if no extra manipulation is applied, we just pass the origin arguments to create_proxy function # to create node on computation graph @@ -103,7 +103,7 @@ class ColoTracer(Tracer): if kind == "call_function": if bias_addition_function.has(target): if target == torch.nn.functional.linear: - if 'bias' in kwargs and kwargs['bias'] is not None: + if "bias" in kwargs and kwargs["bias"] is not None: function_to_substitute = func_to_func_dict[target] handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute) else: @@ -160,22 +160,27 @@ class ColoTracer(Tracer): if n not in parameter_proxy_cache: kwargs = {} if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: - kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else - lambda node: ParameterProxy(self, node, n, attr_val)) - val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ParameterProxy(self, node, n, attr_val) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] parameter_proxy_cache[n] = val_proxy return parameter_proxy_cache[n] return None if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), - parameter_proxy_cache) + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) if maybe_parameter_proxy is not None: return maybe_parameter_proxy if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): - maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), - parameter_proxy_cache) + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) if maybe_buffer_proxy is not None: return maybe_buffer_proxy @@ -190,7 +195,7 @@ class ColoTracer(Tracer): # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched, # we should treat it as leaf module as well if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name): - return self.create_proxy('call_module', module_qualified_name, args, kwargs) + return self.create_proxy("call_module", module_qualified_name, args, kwargs) else: return forward(*args, **kwargs) @@ -208,10 +213,9 @@ class ColoTracer(Tracer): self.proxy_cls = ColoProxy self.tracer_type = TracerType.META else: - raise ValueError(f"Unrecognised tracer type {tracer_type}") + raise ValueError(f"Unrecognized tracer type {tracer_type}") def _meta_data_computing(self, kind, target, args, kwargs): - if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta: meta_out = self.meta_args[target] return meta_out @@ -235,8 +239,9 @@ class ColoTracer(Tracer): # Therefore, I need to record the nn.parameter.Parameter attribute for the operation # added by the bias addition manipulation following the get_attr node. convert_to_parameter = False - if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0], - torch.nn.parameter.Parameter): + if target in (torch.transpose, torch.reshape) and isinstance( + args_metas[0], torch.nn.parameter.Parameter + ): convert_to_parameter = True # fetch patched function if meta_patched_function.has(target): @@ -309,10 +314,12 @@ class ColoTracer(Tracer): return meta_out - def trace(self, - root: nn.Module, - concrete_args: Optional[Dict[str, Tensor]] = None, - meta_args: Optional[Dict[str, Tensor]] = None) -> Graph: + def trace( + self, + root: nn.Module, + concrete_args: Optional[Dict[str, Tensor]] = None, + meta_args: Optional[Dict[str, Tensor]] = None, + ) -> Graph: """ Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow. @@ -341,9 +348,7 @@ class ColoTracer(Tracer): # update concrete args with default values non_meta_arg_names = sig_names - meta_arg_names for k, v in sig.parameters.items(): - if k in non_meta_arg_names and \ - k not in concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default # get non concrete arg names @@ -354,7 +359,8 @@ class ColoTracer(Tracer): success, element = is_element_in_list(names, sig_names) if not success: raise KeyError( - f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function" + ) _check_arg_name_valid(meta_arg_names) _check_arg_name_valid(concrete_arg_names) @@ -363,11 +369,13 @@ class ColoTracer(Tracer): def _check_kwargs(kwargs, should_be_meta: bool): for k, v in kwargs.items(): if not should_be_meta: - assert not torch.is_tensor(v) or not v.is_meta, \ - f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer' + assert ( + not torch.is_tensor(v) or not v.is_meta + ), f"Expected the {k} not to be a meta tensor, please check the args passed to the tracer" else: - assert v.is_meta == should_be_meta, \ - f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer' + assert ( + v.is_meta == should_be_meta + ), f"Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer" _check_kwargs(concrete_args, should_be_meta=False) _check_kwargs(meta_args, should_be_meta=True) @@ -442,10 +450,9 @@ class ColoTracer(Tracer): orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction class PatchedCheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): - # signal that the current tracing occurs within activaton checkpoint part + # signal that the current tracing occurs within activation checkpoint part self.inside_torch_checkpoint_func = True out = run_function(*args) self.inside_torch_checkpoint_func = False @@ -455,7 +462,8 @@ class ColoTracer(Tracer): @staticmethod def backward(ctx: Any, *grad_outputs: Any) -> Any: raise NotImplementedError( - "We do not implement the backward pass as we only trace the forward pass.") + "We do not implement the backward pass as we only trace the forward pass." + ) # override the checkpoint function torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction @@ -470,12 +478,11 @@ class ColoTracer(Tracer): if self.inside_torch_checkpoint_func: # annotate the activation checkpoint module - node.meta['activation_checkpoint'] = self.act_ckpt_region_count + node.meta["activation_checkpoint"] = self.act_ckpt_region_count return node def wrap_tensor_constructor_method(target): - def look_for_proxy(*args, **kwargs): # find in pos vars for arg in args: @@ -518,12 +525,10 @@ def wrap_tensor_constructor_method(target): for method in magic_methods: def _scope(method): - def impl(*args, **kwargs): - tracer = args[0].tracer target = getattr(operator, method) - proxy = tracer.create_proxy('call_function', target, args, kwargs) + proxy = tracer.create_proxy("call_function", target, args, kwargs) if not isinstance(proxy, ColoProxy): meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs) proxy = ColoProxy(proxy.node) @@ -542,7 +547,7 @@ def _define_reflectable(orig_method_name): def impl(self, rhs): target = getattr(operator, orig_method_name) - proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {}) + proxy = self.tracer.create_proxy("call_function", target, (rhs, self), {}) if not isinstance(proxy, ColoProxy): meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {}) proxy = ColoProxy(proxy.node) diff --git a/colossalai/global_variables.py b/colossalai/global_variables.py deleted file mode 100644 index 61b31965e2e63d2119bfadba0d49478537c31fa7..0000000000000000000000000000000000000000 --- a/colossalai/global_variables.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Optional - - -class TensorParallelEnv(object): - _instance = None - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = object.__new__(cls, *args, **kwargs) - return cls._instance - - def __init__(self, *args, **kwargs): - self.load(*args, **kwargs) - - def load(self, - mode: Optional[str] = None, - vocab_parallel: bool = False, - parallel_input_1d: bool = False, - summa_dim: int = None, - tesseract_dim: int = None, - tesseract_dep: int = None, - depth_3d: int = None, - input_group_3d=None, - weight_group_3d=None, - output_group_3d=None, - input_x_weight_group_3d=None, - output_x_weight_group_3d=None): - self.mode = mode - self.vocab_parallel = vocab_parallel - self.parallel_input_1d = parallel_input_1d - self.summa_dim = summa_dim - self.tesseract_dim = tesseract_dim - self.tesseract_dep = tesseract_dep - self.depth_3d = depth_3d - self.input_group_3d = input_group_3d - self.weight_group_3d = weight_group_3d - self.output_group_3d = output_group_3d - self.input_x_weight_group_3d = input_x_weight_group_3d - self.output_x_weight_group_3d = output_x_weight_group_3d - - def save(self): - return dict(mode=self.mode, - vocab_parallel=self.vocab_parallel, - parallel_input_1d=self.parallel_input_1d, - summa_dim=self.summa_dim, - tesseract_dim=self.tesseract_dim, - tesseract_dep=self.tesseract_dep, - depth_3d=self.depth_3d, - input_group_3d=self.input_group_3d, - weight_group_3d=self.weight_group_3d, - output_group_3d=self.output_group_3d, - input_x_weight_group_3d=self.input_x_weight_group_3d, - output_x_weight_group_3d=self.output_x_weight_group_3d) - - -tensor_parallel_env = TensorParallelEnv() diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9a965dc982a4c2c24eea5eacc86e85999f205d7a --- /dev/null +++ b/colossalai/inference/README.md @@ -0,0 +1,117 @@ +# 🚀 Colossal-Inference + +## Table of contents + +## Introduction + +`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users. + +## Design + +Colossal Inference is composed of two main components: + +1. High performance kernels and ops: which are inspired from existing libraries and modified correspondingly. +2. Efficient memory management mechanism:which includes the key-value cache manager, allowing for zero memory waste during inference. + 1. `cache manager`: serves as a memory manager to help manage the key-value cache, it integrates functions such as memory allocation, indexing and release. + 2. `batch_infer_info`: holds all essential elements of a batch inference, which is updated every batch. +3. High-level inference engine combined with `Shardformer`: it allows our inference framework to easily invoke and utilize various parallel methods. + 1. `engine.TPInferEngine`: it is a high level interface that integrates with shardformer, especially for multi-card (tensor parallel) inference: + 2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods for llama inference. (in this case : llama) + 3. `policies.llama.LlamaModelInferPolicy` : contains the policies for `llama` models, which is used to call `shardformer` and segmentate the model forward in tensor parallelism way. + +## Pipeline of inference: + +In this section we discuss how the colossal inference works and integrates with the `Shardformer` . The details can be found in our codes. + +![Colossal-Inference](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Colossal-inference.png) + +## Roadmap of our implementation + +- [x] Design cache manager and batch infer state +- [x] Design TpInference engine to integrates with `Shardformer` +- [x] Register corresponding high-performance `kernel` and `ops` +- [x] Design policies and forwards (e.g. `Llama` and `Bloom`) + - [x] policy + - [x] context forward + - [x] token forward +- [ ] Replace the kernels with `faster-transformer` in token-forward stage +- [ ] Support all models + - [x] Llama + - [x] Bloom + - [ ] Chatglm2 +- [ ] Benchmarking for all models + +## Get started + +### Installation + +```bash +pip install -e . +``` + +### Requirements + +dependencies + +```bash +pytorch= 1.13.1 (gpu) +cuda>= 11.6 +transformers= 4.30.2 +triton==2.0.0.dev20221202 +# for install vllm, please use this branch to install https://github.com/tiandiao123/vllm/tree/setup_branch +vllm +# for install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c +flash-attention +``` + +### Docker + +You can use docker run to use docker container to set-up environment + +``` +# env: python==3.8, cuda 11.6, pytorch == 1.13.1 triton==2.0.0.dev20221202, vllm kernels support, flash-attention-2 kernels support +docker pull hpcaitech/colossalai-inference:v2 +docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash + +``` + +### Dive into fast-inference! + +example files are in + +```bash +cd colossalai.examples +python xx +``` + +## Performance + +### environment: + +We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `colossal-inference` and original `hugging-face torch fp16`. + +For various models, experiments were conducted using multiple batch sizes under the consistent model configuration of `7 billion(7b)` parameters, `1024` input length, and 128 output length. The obtained results are as follows (due to time constraints, the evaluation has currently been performed solely on the `A100` single GPU performance; multi-GPU performance will be addressed in the future): + +### Single GPU Performance: + +Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on average values of context-forward and decoding forward process, which means we combine both of processes to calculate token generation times. We are actively developing new features and methods to furthur optimize the performance of LLM models. Please stay tuned. + +#### Llama + +| batch_size | 8 | 16 | 32 | +| :---------------------: | :----: | :----: | :----: | +| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 | +| colossal-inference | 326.4 | 582.72 | 816.64 | + +![llama](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-llama7b.png) + +### Bloom + +| batch_size | 8 | 16 | 32 | +| :---------------------: | :----: | :----: | :----: | +| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 | +| colossal-inference | 323.28 | 538.52 | 611.64 | + +![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-bloom7b.png) + +The results of more models are coming soon! diff --git a/tests/test_layers/test_2p5d/checks_2p5d/__init__.py b/colossalai/inference/__init__.py similarity index 100% rename from tests/test_layers/test_2p5d/checks_2p5d/__init__.py rename to colossalai/inference/__init__.py diff --git a/colossalai/inference/quant/gptq/__init__.py b/colossalai/inference/quant/gptq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c035f397923a77d3becfc60d7c554456e545554a --- /dev/null +++ b/colossalai/inference/quant/gptq/__init__.py @@ -0,0 +1,4 @@ +from .cai_gptq import HAS_AUTO_GPTQ + +if HAS_AUTO_GPTQ: + from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear diff --git a/colossalai/inference/quant/gptq/cai_gptq/__init__.py b/colossalai/inference/quant/gptq/cai_gptq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed76293bd81ccf1e44ac5044e96335d3efa5dbd --- /dev/null +++ b/colossalai/inference/quant/gptq/cai_gptq/__init__.py @@ -0,0 +1,14 @@ +import warnings + +HAS_AUTO_GPTQ = False +try: + import auto_gptq + + HAS_AUTO_GPTQ = True +except ImportError: + warnings.warn("please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ") + HAS_AUTO_GPTQ = False + +if HAS_AUTO_GPTQ: + from .cai_quant_linear import CaiQuantLinear, ColCaiQuantLinear, RowCaiQuantLinear + from .gptq_op import CaiGPTQLinearOp diff --git a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..ca12c34ed958506c769a560d020c3f077f43c935 --- /dev/null +++ b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py @@ -0,0 +1,354 @@ +# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ + +import math +import warnings +from typing import List, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import ParallelModule + +from .gptq_op import CaiGPTQLinearOp + +HAS_GPTQ_CUDA = False +try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True +except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + + +class CaiQuantLinear(nn.Module): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + super().__init__() + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize if groupsize != -1 else infeatures + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer( + 'qzeros', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + if row_split: + self.register_buffer( + 'g_idx', + torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], + dtype=torch.int32)) + else: + self.register_buffer('g_idx', + torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) + + if bias: + self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) + else: + self.bias = None + + self.gptq_linear = CaiGPTQLinearOp(groupsize, bits) + + self.q4 = None + self.empty_tensor = torch.empty((1, 1), device="meta") + self.tp_size = tp_size + self.tp_rank = tp_rank + self.row_split = row_split + + def pack(self, linear, scales, zeros, g_idx=None): + + g_idx = g_idx.clone() if g_idx is not None else torch.tensor( + [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + half_scales = scales.clone().half() + # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape) + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + wn = 8 + pbits = 32 + ptype = torch.int32 + unsign_type = np.uint32 + sign_type = np.int32 + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, + None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(unsign_type) + qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type) + + i = 0 + row = 0 + + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (pbits // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += pbits // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qweight = qweight.astype(sign_type) + qweight1 = torch.from_numpy(qweight) + qweight1 = qweight1.contiguous() #.to("cuda") + self.qweight.data.copy_(qweight1) + + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) + zeros -= 1 + zeros = zeros.numpy().astype(unsign_type) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (pbits // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += pbits // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qzeros = qzeros.astype(sign_type) + qzeros = torch.from_numpy(qzeros) + qzeros = qzeros + self.qzeros.data.copy_(qzeros) + + if torch.equal(self.g_idx.to(g_idx.device), g_idx): + self.g_idx = None + else: + self.g_idx = g_idx + + def init_q4(self): + assert self.qweight.device.type == "cuda" + self.q4_width = self.qweight.shape[1] + if self.g_idx is not None: + if self.row_split and torch.equal( + self.g_idx, + torch.tensor( + [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): + self.g_idx = None + elif torch.equal( + self.g_idx, + torch.tensor([i // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): + self.g_idx = None + + if self.g_idx is not None: + g_idx = self.g_idx.to("cpu") + else: + g_idx = self.empty_tensor + + self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device()) + torch.cuda.synchronize() + + def forward(self, x): + outshape = x.shape[:-1] + (self.outfeatures,) + + if HAS_GPTQ_CUDA and self.bits == 4: + + if self.q4 is None: + self.init_q4() + + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device) + gptq_cuda.q4_matmul(x.half(), self.q4, output) + if self.bias is not None and (not self.row_split or self.tp_size == 1): + output.add_(self.bias) + else: + if self.bias is not None and (not self.row_split or self.tp_size == 1): + bias = self.bias + else: + bias = None + output = self.gptq_linear( + x, + self.qweight, + self.scales, + self.qzeros, + g_idx=self.g_idx, + bias=bias, + ) + return output.view(outshape) + + +def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1): + + qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1) + qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1) + scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1) + g_idx = gptq_linear.g_idx + if gptq_linear.bias is not None: + bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1) + + cai_split_out_features = cai_linear.outfeatures // split_num + zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num + + for i in range(split_num): + cai_linear.qweight[:, i * cai_split_out_features:(i + 1) * + cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + cai_linear.qzeros[:, i * zero_split_block:(i + 1) * + zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block] + cai_linear.scales[:, i * cai_split_out_features:(i + 1) * + cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + if cai_linear.bias is not None: + cai_linear.bias[i * cai_split_out_features:(i + 1) * + cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + + cai_linear.g_idx.copy_(g_idx) + + +def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): + + qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0) + qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0) + scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0) + g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0) + + cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num + zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num + idx_split_features = cai_linear.infeatures // split_num + + for i in range(split_num): + cai_linear.qweight[i * cai_split_in_features:(i + 1) * + cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) * + cai_split_in_features, :] + cai_linear.qzeros[i * zero_split_block:(i + 1) * + zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) * + zero_split_block, :] + cai_linear.scales[i * zero_split_block:(i + 1) * + zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) * + zero_split_block, :] + cai_linear.g_idx[i * idx_split_features:(i + 1) * + idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) * + idx_split_features] + if cai_linear.bias is not None: + cai_linear.bias.copy_(gptq_linear.bias) + + +class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + + super().__init__(bits, + groupsize, + infeatures, + outfeatures, + bias, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=row_split) + self.process_group = None + + @staticmethod + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + tp_rank = dist.get_rank(process_group) + + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = RowCaiQuantLinear(module.bits, + module.group_size, + module.in_features // tp_size, + module.out_features, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=True) + linear_1d.process_group = process_group + + split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) + return linear_1d + + def forward(self, x): + output = super().forward(x) + if self.tp_size > 1: + dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group) + if self.bias is not None: + output.add_(self.bias) + return output + + +class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + + super().__init__(bits, + groupsize, + infeatures, + outfeatures, + bias, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=row_split) + self.process_group = None + + @staticmethod + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + tp_rank = dist.get_rank(process_group) + + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = ColCaiQuantLinear(module.bits, + module.group_size, + module.in_features, + module.out_features // tp_size, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank) + linear_1d.process_group = process_group + + split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) + return linear_1d diff --git a/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py b/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a8902eb35cd0737fcda3451d54e6f55a3e313b4b --- /dev/null +++ b/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py @@ -0,0 +1,58 @@ +import torch + +from colossalai.kernel.triton import gptq_fused_linear_triton + + +class CaiGPTQLinearOp(torch.nn.Module): + def __init__(self, gptq_group_size, gptq_quant_bits): + super(CaiGPTQLinearOp, self).__init__() + self.group_size = gptq_group_size + self.bits = gptq_quant_bits + self.maxq = 2**self.bits - 1 + self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device()) + + def forward( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zeros: torch.Tensor, + g_idx: torch.Tensor = None, + act_type=0, + bias: torch.Tensor = None, + residual: torch.Tensor = None, + qkv_fused=False, + ): + add_bias = True + if bias is None: + bias = self.empty_tensor + add_bias = False + + add_residual = True + if residual is None: + residual = self.empty_tensor + add_residual = False + x = input.view(-1, input.shape[-1]) + + out = gptq_fused_linear_triton( + x, + weight, + weight_scales, + weight_zeros, + bias, + residual, + self.bits, + self.maxq, + self.group_size, + qkv_fused, + add_bias, + add_residual, + act_type=act_type, + g_idx=g_idx, + ) + if qkv_fused: + out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1]) + else: + out = out.view(input.shape[0], input.shape[1], weight.shape[-1]) + + return out diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..112b920ba158144a18a94b5664df7063555a701c --- /dev/null +++ b/colossalai/inference/tensor_parallel/__init__.py @@ -0,0 +1,4 @@ +from .engine import TPInferEngine +from .kvcache_manager import MemoryManager + +__all__ = ["MemoryManager", "TPInferEngine"] diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py new file mode 100644 index 0000000000000000000000000000000000000000..ac185f1b6529376a77901327cf85e2008a95932a --- /dev/null +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -0,0 +1,56 @@ +# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later +from dataclasses import dataclass + +import torch + +from .kvcache_manager import MemoryManager + + +@dataclass +class BatchInferState: + r""" + Information to be passed and used for a batch of inputs during + a single model forward + """ + batch_size: int + max_len_in_batch: int + + cache_manager: MemoryManager = None + + block_loc: torch.Tensor = None + start_loc: torch.Tensor = None + seq_len: torch.Tensor = None + past_key_values_len: int = None + + is_context_stage: bool = False + context_mem_index: torch.Tensor = None + decode_is_contiguous: bool = None + decode_mem_start: int = None + decode_mem_end: int = None + decode_mem_index: torch.Tensor = None + decode_layer_id: int = None + + device: torch.device = torch.device("cuda") + + @property + def total_token_num(self): + # return self.batch_size * self.max_len_in_batch + assert self.seq_len is not None and self.seq_len.size(0) > 0 + return int(torch.sum(self.seq_len)) + + def set_cache_manager(self, manager: MemoryManager): + self.cache_manager = manager + + @staticmethod + def init_block_loc( + b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor + ): + """in-place update block loc mapping based on the sequence length of the inputs in current bath""" + start_index = 0 + seq_len_numpy = seq_len.cpu().numpy() + for i, cur_seq_len in enumerate(seq_len_numpy): + b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[ + start_index : start_index + cur_seq_len + ] + start_index += cur_seq_len + return diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..d5ef37fee420927848cd4034257b619f510c2ecf --- /dev/null +++ b/colossalai/inference/tensor_parallel/engine.py @@ -0,0 +1,380 @@ +from typing import Any, Callable, List, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from transformers import BloomForCausalLM, LlamaForCausalLM +from transformers.generation import GenerationConfig +from transformers.generation.stopping_criteria import StoppingCriteriaList +from transformers.tokenization_utils_base import BatchEncoding + +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.auto_policy import get_autopolicy + +from .batch_infer_state import BatchInferState +from .kvcache_manager import MemoryManager + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + +_supported_models = [ + "LlamaForCausalLM", + "LlamaModel", + "BloomForCausalLM", + "ChatGLMModel", + "ChatGLMForConditionalGeneration", +] + + +class TPInferEngine: + """Engine class for tensor parallel inference. + + Args: + model (Module): original model, e.g. huggingface CausalLM + shard_config (ShardConfig): The config for sharding original model + max_batch_size (int): maximum batch size + max_input_len (int): maximum input length of sequence + max_output_len (int): maximum output length of output tokens + dtype (torch.dtype): datatype used to init KV cache space + device (str): device the KV cache of engine to be initialized on + + Examples: + >>> # define model and shard config for your inference + >>> model = ... + >>> generate_kwargs = ... + >>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) + >>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + >>> outputs = infer_engine.generate(input_ids, **generate_kwargs) + """ + + def __init__( + self, + model: nn.Module, + shard_config: ShardConfig, + max_batch_size: int, + max_input_len: int, + max_output_len: int, + dtype: torch.dtype = torch.float16, + device: str = "cuda", + ) -> None: + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) + + # Constraints relatable with specs of devices and model + # This may change into an optional arg in the future + assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" + assert self.max_input_len + self.max_output_len <= 4096, "Max length exceeds the constraint" + + self.dtype = dtype + + self.head_dim = model.config.hidden_size // model.config.num_attention_heads + self.head_num = model.config.num_attention_heads + num_hidden_layers = ( + model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers + ) + self.layer_num = num_hidden_layers + self.multi_query_group_num = ( + model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0 + ) + + self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config + self.cache_manager = None + + self.max_dq_buffer_size = 1 + self.max_inner_outer_dim = 1 + self.gptq_temp_state_buffer = None + self.gptq_temp_dq_buffer = None + self.bits = -1 + self.use_act_order = False + + self.shard_config = shard_config + self.model = None + # optimize the original model by sharding with ShardFormer + self._optimize_model(model=model.to(device)) + + def _init_manager(self) -> None: + assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" + assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" + self.head_num //= self.tp_size # update sharded number of heads + if self.multi_query_group_num: + # NOTE the logic of MQA tensor parallelism should be specified. + assert ( + self.multi_query_group_num % self.tp_size == 0 + ), f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}" + self.cache_manager = MemoryManager( + self.max_total_token_num, + self.dtype, + self.multi_query_group_num // self.tp_size, + self.head_dim, + self.layer_num, + ) + else: + self.cache_manager = MemoryManager( + self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num + ) + + def _post_init_gptq_buffer(self, model: nn.Module) -> None: + from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear + HAS_GPTQ_CUDA = False + try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True + except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + + for name, submodule in model.named_modules(): + if isinstance(submodule, CaiQuantLinear): + self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8) + + if self.use_act_order: + self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures, + submodule.outfeatures) + self.bits = submodule.bits + if not (HAS_GPTQ_CUDA and self.bits == 4): + return + + max_input_len = 1 + if self.use_act_order: + max_input_len = self.max_input_len + # The temp_state buffer is required to reorder X in the act-order case. + # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. + self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim), + dtype=torch.float16, + device=torch.cuda.current_device()) + self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size), + dtype=torch.float16, + device=torch.cuda.current_device()) + + gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, + self.gptq_temp_dq_buffer) + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + torch.cuda.empty_cache() + + def _optimize_model(self, model: nn.Module) -> None: + """ + Optimize the original model by sharding with ShardFormer. + In further generation, use the sharded model instead of original model. + """ + # NOTE we will change to use an inference config later with additional attrs we want + assert self.shard_config.inference_only is True + shardformer = ShardFormer(shard_config=self.shard_config) + self._prepare_with_shard_config(shard_config=self.shard_config) + self._shard_model_by(shardformer, model) + + def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig: + """Prepare the engine with a given ShardConfig. + + Args: + shard_config (ShardConfig): shard config given to specify settings of the engine. + If not provided, a default ShardConfig with tp size 1 will be created. + """ + self.tp_size = 1 + if shard_config is None: + shard_config = ShardConfig( + tensor_parallel_process_group=None, + pipeline_stage_manager=None, + enable_tensor_parallelism=False, + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + inference_only=True, + ) + else: + shard_config.inference_only = True + shard_config.pipeline_stage_manager = None + if shard_config.enable_tensor_parallelism: + self.tp_size = shard_config.tensor_parallel_size + self._init_manager() + + return shard_config + + def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: + """Shard original model by the given ShardFormer and store the sharded model.""" + assert ( + self.tp_size == shardformer.shard_config.tensor_parallel_size + ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" + model_name = model.__class__.__name__ + assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." + policy = get_autopolicy(model, inference_only=True) + self.model, _ = shardformer.optimize(model, policy) + + if self.shard_config.inference_gptq: + self._post_init_gptq_buffer(model) + + self.model = self.model.cuda() + + @property + def supported_models(self) -> List[str]: + return _supported_models + + def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor: + """Generate token sequence. + + Args: + input_tokens: could be one of the following types + 1. BatchEncoding or dict (e.g. tokenizer batch_encode) + 2. list of input token ids (e.g. appended result of tokenizer encode) + 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + Returns: + torch.Tensor: The returned sequence is given inputs + generated_tokens. + """ + if isinstance(input_tokens, torch.Tensor): + input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool)) + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].cuda() + if "max_new_tokens" not in generate_kwargs: + generate_kwargs.update(max_new_tokens=self.max_output_len) + + return self._generate_by_set_infer_state(input_tokens, **generate_kwargs) + + def prepare_batch_state(self, inputs) -> BatchInferState: + """ + Create and prepare BatchInferState used for inference during model forwrad, + by processing each sequence of the given inputs. + + Args: + inputs: should be one of the following types + 1. BatchEncoding or dict (e.g. tokenizer batch_encode) + 2. list of input token ids (e.g. appended result of tokenizer encode) + 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve + the actual length (e.g. number of tokens) of each input without attention mask + Hence, for torch.Tensor with shape [bs, l] where bs > 1, we will assume + all the inputs in the batch has the maximum length l + Returns: + BatchInferState: the states for the current batch during inference + """ + if not isinstance(inputs, (BatchEncoding, dict, list, torch.Tensor)): + raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state") + + input_ids_list = None + attention_mask = None + + if isinstance(inputs, (BatchEncoding, dict)): + input_ids_list = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + else: + input_ids_list = inputs + if isinstance(input_ids_list[0], int): # for a single input + input_ids_list = [input_ids_list] + attention_mask = [attention_mask] if attention_mask is not None else attention_mask + + batch_size = len(input_ids_list) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + + max_len_in_batch = -1 + if isinstance(inputs, (BatchEncoding, dict)): + for i, attn_mask in enumerate(attention_mask): + curr_seq_len = len(attn_mask) + # if isinstance(attn_mask, torch.Tensor): + # curr_seq_len = int(torch.sum(attn_mask)) + # else: + # curr_seq_len = int(sum(attn_mask)) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + else: + length = max(len(input_id) for input_id in input_ids_list) + for i, input_ids in enumerate(input_ids_list): + curr_seq_len = length + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") + batch_infer_state = BatchInferState(batch_size, max_len_in_batch) + batch_infer_state.seq_len = seq_lengths.to("cuda") + batch_infer_state.start_loc = seq_start_indexes.to("cuda") + batch_infer_state.block_loc = block_loc + batch_infer_state.decode_layer_id = 0 + batch_infer_state.past_key_values_len = 0 + batch_infer_state.is_context_stage = True + batch_infer_state.set_cache_manager(self.cache_manager) + return batch_infer_state + + @torch.no_grad() + def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor: + """ + Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate + + Args: + inputs: should be one of the following types + 1. BatchEncoding or dict (e.g. tokenizer batch_encode) + 2. list of input token ids (e.g. appended result of tokenizer encode) + 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + """ + + # for testing, always use sharded model + assert self.model is not None, "sharded model does not exist" + + batch_infer_state = self.prepare_batch_state(input_tokens) + assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit" + + # set BatchInferState for the current batch as attr to model + # NOTE this is not a preferable way to pass BatchInferState during inference + # we might want to rewrite generate function (e.g. _generate_by_pass_infer_state) + # and pass BatchInferState via model forward + model = self.model + if isinstance(model, LlamaForCausalLM): + model = self.model.model + elif isinstance(model, BloomForCausalLM): + model = self.model.transformer + setattr(model, "infer_state", batch_infer_state) + + outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False) + + # NOTE In future development, we're going to let the scheduler to handle the cache, + # instead of freeing space explicitly at the end of generation + self.cache_manager.free_all() + + return outputs + + # TODO might want to implement the func that generates output tokens by passing BatchInferState + # as an arg into model.forward. + # It requires rewriting model generate and replacing model forward. + @torch.no_grad() + def _generate_by_pass_infer_state( + self, + input_tokens, + max_out_length: int, + generation_config: Optional[GenerationConfig] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + **model_kwargs, + ) -> torch.Tensor: + raise NotImplementedError("generate by passing BatchInferState is not implemented.") + + # might want to use in rewritten generate method: use after model.forward + # BatchInferState is created and kept during generation + # after each iter of model forward, we should update BatchInferState + def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: + batch_size = infer_state.batch_size + device = infer_state.start_loc.device + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) + infer_state.seq_len += 1 + + # might want to create a sequence pool + # add a single request/sequence/input text at a time and record its length + # In other words, store the actual length of input tokens representing a single input text + # E.g. "Introduce landmarks in Beijing" + # => add request + # => record token length and other necessary information to be used + # => engine hold all these necessary information until `generate` (or other name) is called, + # => put information already recorded in batchinferstate and pass it to model forward + # => clear records in engine + def add_request(): + raise NotImplementedError() diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e74a3a491a7bb87b09d8a93d29d67eea0593d749 --- /dev/null +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -0,0 +1,104 @@ +# Adapted from lightllm/common/mem_manager.py +# of the ModelTC/lightllm GitHub repository +# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py + +import torch +from transformers.utils import logging + + +class MemoryManager: + r""" + Manage token block indexes and allocate physical memory for key and value cache + + Args: + size: maximum token number used as the size of key and value buffer + dtype: data type of cached key and value + head_num: number of heads the memory manager is responsible for + head_dim: embedded size per head + layer_num: the number of layers in the model + device: device used to store the key and value cache + """ + + def __init__( + self, + size: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + layer_num: int, + device: torch.device = torch.device("cuda"), + ): + self.logger = logging.get_logger(__name__) + self.available_size = size + self.past_key_values_length = 0 + self._init_mem_states(size, device) + self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) + + def _init_mem_states(self, size, device): + """Initialize tensors used to manage memory states""" + self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) + self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) + self.indexes = torch.arange(0, size, dtype=torch.long, device=device) + + def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): + """Initialize key buffer and value buffer on specified device""" + self.key_buffer = [ + torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + ] + self.value_buffer = [ + torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + ] + + @torch.no_grad() + def alloc(self, required_size): + """allocate space of required_size by providing indexes representing available physical spaces""" + if required_size > self.available_size: + self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") + return None + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) + select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) + select_index = self.indexes[select_index] + self.mem_state[select_index] = 0 + self.available_size -= len(select_index) + return select_index + + @torch.no_grad() + def alloc_contiguous(self, required_size): + """allocate contiguous space of required_size""" + if required_size > self.available_size: + self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") + return None + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) + sum_size = len(self.mem_cum_sum) + loc_sums = ( + self.mem_cum_sum[required_size - 1 :] + - self.mem_cum_sum[0 : sum_size - required_size + 1] + + self.mem_state[0 : sum_size - required_size + 1] + ) + can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size] + if can_used_loc.shape[0] == 0: + self.logger.info( + f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}" + ) + return None + start_loc = can_used_loc[0] + select_index = self.indexes[start_loc : start_loc + required_size] + self.mem_state[select_index] = 0 + self.available_size -= len(select_index) + start = start_loc.item() + end = start + required_size + return select_index, start, end + + @torch.no_grad() + def free(self, free_index): + """free memory by updating memory states based on given indexes""" + self.available_size += free_index.shape[0] + self.mem_state[free_index] = 1 + + @torch.no_grad() + def free_all(self): + """free all memory by updating memory states""" + self.available_size = len(self.mem_state) + self.mem_state[:] = 1 + self.past_key_values_length = 0 + self.logger.info("freed all space of memory manager") diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4662368b17b42098b572777ed4d3b216d15d3983 --- /dev/null +++ b/colossalai/inference/tensor_parallel/modeling/__init__.py @@ -0,0 +1,5 @@ +from .bloom import BloomInferenceForwards +from .chatglm2 import ChatGLM2InferenceForwards +from .llama import LlamaInferenceForwards + +__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards", "ChatGLM2InferenceForwards"] diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e476c313253835232311c81e258c7a169579e189 --- /dev/null +++ b/colossalai/inference/tensor_parallel/modeling/_utils.py @@ -0,0 +1,67 @@ +""" +Utils for model inference +""" +import os + +import torch + +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + + +def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + """ + This function copies the key and value cache to the memory cache + Args: + layer_id : id of current layer + key_buffer : key cache + value_buffer : value cache + context_mem_index : index of memory cache in kv cache manager + mem_manager : cache manager + """ + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + + +def init_to_get_rotary(self, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + self : Model that holds the rotary positional embedding + base : calculation arg + use_elem : activated when using chatglm-based models + """ + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None)) + + if ntk_alpha is not None: + ntk_alpha = float(ntk_alpha) + assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1" + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula + + n_elem = self.config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py new file mode 100644 index 0000000000000000000000000000000000000000..27a26caabefa96aef4e6ca8a82429fa881f0dd6e --- /dev/null +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -0,0 +1,540 @@ +import math +import warnings +from typing import Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch.nn import CrossEntropyLoss +from torch.nn import functional as F +from transformers.models.bloom.modeling_bloom import ( + BaseModelOutputWithPastAndCrossAttentions, + BloomAttention, + BloomBlock, + BloomForCausalLM, + BloomModel, + CausalLMOutputWithCrossAttentions, +) +from transformers.utils import logging + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd + + +def generate_alibi(n_head, dtype=torch.float16): + """ + This method is adapted from `_generate_alibi` function + in `lightllm/models/bloom/layer_weights/transformer_layer_weight.py` + of the ModelTC/lightllm GitHub repository. + This method is originally the `build_alibi_tensor` function + in `transformers/models/bloom/modeling_bloom.py` + of the huggingface/transformers GitHub repository. + """ + + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + return [start * start**i for i in range(n)] + + def get_slopes(n): + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2) + slopes_double = get_slopes(2 * closest_power_of_2) + slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2] + return slopes_combined + + slopes = get_slopes(n_head) + return torch.tensor(slopes, dtype=dtype) + + +class BloomInferenceForwards: + """ + This class serves a micro library for bloom inference forwards. + We intend to replace the forward methods for BloomForCausalLM, BloomModel, BloomBlock, and BloomAttention, + as well as prepare_inputs_for_generation method for BloomForCausalLM. + For future improvement, we might want to skip replacing methods for BloomForCausalLM, + and call BloomModel.forward iteratively in TpInferEngine + """ + + @staticmethod + def bloom_model_forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: Optional[BatchInferState] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # still need to keep past_key_values to fit original forward flow + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # NOTE determine if BatchInferState is passed in via arg + # if not, get the attr binded to the model + # We might wantto remove setattr later + if infer_state is None: + assert hasattr(self, "infer_state") + infer_state = self.infer_state + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + # if self.cache_manager.past_key_values_length > 0: + if infer_state.cache_manager.past_key_values_length > 0: + # update the past key values length in cache manager, + # NOTE use BatchInferState.past_key_values_length instead the one in cache manager + past_key_values_length = infer_state.cache_manager.past_key_values_length + seq_length_with_past = seq_length_with_past + past_key_values_length + + # infer_state.cache_manager = self.cache_manager + + if use_cache and seq_length != 1: + # prefill stage + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + BatchInferState.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + # NOTE revise: we might want to store a single 1D alibi(length is #heads) in model, + # or store to BatchInferState to prevent re-calculating + # When we have multiple process group (e.g. dp together with tp), we need to pass the pg to here + # alibi = generate_alibi(self.num_heads).contiguous().cuda() + tp_size = dist.get_world_size() + curr_tp_rank = dist.get_rank() + alibi = ( + generate_alibi(self.num_heads * tp_size) + .contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads] + .cuda() + ) + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + # NOTE: currently our KV cache manager does not handle this condition + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + infer_state=infer_state, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # update indices of kv cache block + # NOT READY FOR PRIME TIME + # might want to remove this part, instead, better to pass the BatchInferState from model forward, + # and update these information in engine.generate after model foward called + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + infer_state.decode_layer_id = 0 + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, # should always be (None, None, ..., None) + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + @staticmethod + def bloom_for_causal_lm_forward( + self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: Optional[BatchInferState] = None, + **deprecated_arguments, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = BloomInferenceForwards.bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def bloom_for_causal_lm_prepare_inputs_for_generation( + self: BloomForCausalLM, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> dict: + # only last token for input_ids if past is not None + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + + # NOTE we won't use past key values here + # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed + # if past_key_values[0][0].shape[0] == input_ids.shape[0]: + # past_key_values = self._convert_to_bloom_cache(past_key_values) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def bloom_block_forward( + self: BloomBlock, + hidden_states: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + infer_state: Optional[BatchInferState] = None, + ): + # hidden_states: [batch_size, seq_length, hidden_size] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Layer norm post the self attention. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self attention. + attn_outputs = self.self_attention( + layernorm_output, + residual, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + infer_state=infer_state, + ) + + attention_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + layernorm_output = self.post_attention_layernorm(attention_output) + + # Get residual + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # MLP. + output = self.mlp(layernorm_output, residual) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + @staticmethod + def bloom_attention_forward( + self: BloomAttention, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + infer_state: Optional[BatchInferState] = None, + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + batch_size, q_length, H, D_HEAD = query_layer.shape + k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + + mem_manager = infer_state.cache_manager + layer_id = infer_state.decode_layer_id + + if layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_length # += 1 + + if infer_state.is_context_stage: + # context process + max_input_len = q_length + b_start_loc = infer_state.start_loc + b_seq_len = infer_state.seq_len[:batch_size] + q = query_layer.reshape(-1, H, D_HEAD) + + copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id]) + + # output = self.output[:batch_size*q_length, :, :] + output = torch.empty_like(q) + + bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) + + context_layer = output.view(batch_size, q_length, H * D_HEAD) + else: + # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + # need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD) + assert q_length == 1, "for non-context process, we only support q_length == 1" + q = query_layer.reshape(-1, H, D_HEAD) + + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(k) + cache_v.copy_(v) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head] + copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) + + b_start_loc = infer_state.start_loc + b_loc = infer_state.block_loc + b_seq_len = infer_state.seq_len + output = torch.empty_like(q) + token_attention_fwd( + q, + mem_manager.key_buffer[layer_id], + mem_manager.value_buffer[layer_id], + output, + b_loc, + b_start_loc, + b_seq_len, + infer_state.cache_manager.past_key_values_length, + alibi, + ) + + context_layer = output.view(batch_size, q_length, H * D_HEAD) + + # update layer id + infer_state.decode_layer_id += 1 + + # NOTE: always set present as none for now, instead of returning past key value to the next decoding, + # we create the past key value pair from the cache manager + present = None + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + # dropout is not required here during inference + output_tensor = residual + output_tensor + + outputs = (output_tensor, present) + assert output_attentions is False, "we do not support output_attentions at this time" + + return outputs diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b1bc601f436ac2c62079c924f6ac1851482ae10 --- /dev/null +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -0,0 +1,540 @@ +import os +from typing import Optional, Tuple + +import torch +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd +from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards +from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, + ChatGLMModel, + GLMBlock, + GLMTransformer, + SelfAttention, + split_tensor_along_last_dim, +) + +from ._utils import copy_kv_to_mem_cache + + +# This func is same as Llama model init_to_get_rotary, we should move them into _utils.py +def _init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + try: + ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula + except: + pass + n_elem = self.config.head_dim_ // 2 + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + + +def get_masks(self, input_ids, past_length, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + if past_length: + full_attention_mask = torch.cat( + ( + torch.ones(batch_size, seq_length, past_length, device=input_ids.device), + full_attention_mask, + ), + dim=-1, + ) + + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + +class ChatGLM2InferenceForwards: + """ + This class holds forwards for Chatglm2 inference. + We intend to replace the forward methods for ChatGLMModel, ChatGLMEecoderLayer, and ChatGLMAttention. + """ + + @staticmethod + def chatglm_for_conditional_generation_forward( + self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + infer_state = self.infer_state + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + seq_length_with_past = seq_length + past_key_values_length + infer_state.seq_length_with_past = seq_length_with_past + + # prefill stage at first + if use_cache and seq_length != 1: + infer_state.is_context_stage = True + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + # related to rotary embedding + if infer_state.is_context_stage: + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def chatglm_model_forward( + self: ChatGLMModel, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = get_masks( + self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask + ) + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + infer_state=infer_state, + ) + + # update indices + # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 + infer_state.cache_manager.past_key_values_length += seq_length + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + @staticmethod + def chatglm_encoder_forward( + self: GLMTransformer, + hidden_states, + attention_mask, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + infer_state: Optional[BatchInferState] = None, + ): + hidden_states = hidden_states.transpose(0, 1).contiguous() + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + + infer_state.decode_layer_id = 0 + for index in range(self.num_layers): + layer = self.layers[index] + + layer_ret = layer( + hidden_states, + attention_mask, + kv_cache=kv_caches[index], + use_cache=use_cache, + infer_state=infer_state, + ) + + infer_state.decode_layer_id += 1 + + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + hidden_states = hidden_states.transpose(0, 1).contiguous() + + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + @staticmethod + def chatglm_glmblock_forward( + self: GLMBlock, + hidden_states, + attention_mask, + kv_cache=None, + use_cache=True, + infer_state: Optional[BatchInferState] = None, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + kv_cache=kv_cache, + use_cache=use_cache, + infer_state=infer_state, + ) + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + return output, kv_cache + + @staticmethod + def chatglm_flash_attn_kvcache_forward( + self: SelfAttention, + hidden_states, + attention_mask, + kv_cache=None, + use_cache=True, + infer_state: Optional[BatchInferState] = None, + ): + assert use_cache is True, "use_cache should be set to True using this chatglm attention" + # hidden_states: original :[sq, b, h] --> this [b, sq, h] + batch_size = hidden_states.shape[0] + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + cos, sin = infer_state.position_cos, infer_state.position_sin + + Llama2Forwards.rotary_emb_fwd( + query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin + ) + if self.multi_query_attention: + Llama2Forwards.rotary_emb_fwd( + key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), + cos, + sin, + ) + else: + Llama2Forwards.rotary_emb_fwd( + key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), + cos, + sin, + ) + + # reshape q k v to [bsz*sql, num_heads, head_dim] 2*1 ,32/2 ,128 + query_layer = query_layer.reshape( + -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ) + key_layer = key_layer.reshape( + -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head + ) + value_layer = value_layer.reshape( + -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head + ) + if infer_state.is_context_stage: + # first token generation: + # copy key and value calculated in current step to memory manager + + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_layer, + value_layer, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + + attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) + + # NOTE: no bug in context attn fwd (del it ) + llama2_context_attn_fwd( + query_layer, + key_layer, + value_layer, + attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), + infer_state.start_loc, + infer_state.seq_len, + infer_state.seq_length_with_past, + ) + + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_layer) + cache_v.copy_(value_layer) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_layer, + value_layer, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # second token and follows + attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + : infer_state.decode_mem_end, :, : + ] + + # ================================== + # core attention computation is replaced by triton kernel + # ================================== + Llama2TokenAttentionForwards.token_attn( + query_layer, + cache_k, + cache_v, + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.max_len_in_batch, + infer_state.other_kv_index, + ) + + # print('after attention',torch.isnan(attn_output).any()) + + # ================= + # Output:[b,sq, h] + # ================= + + output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size) + return output, kv_cache diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..a7661cee1128d010765cf51a92796e606b2a21d1 --- /dev/null +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -0,0 +1,369 @@ +from typing import List, Optional, Tuple + +import torch +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd + +from ._utils import copy_kv_to_mem_cache + +try: + from vllm import layernorm_ops, pos_encoding_ops + + rms_norm = layernorm_ops.rms_norm + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True +except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print( + "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" + ) + HAS_VLLM_KERNERL = False + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaInferenceForwards: + """ + This class holds forwards for llama inference. + We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM. + """ + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + batch_size = input_ids.shape[0] # input_ids.shape[0] + + infer_state = self.infer_state + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if use_cache and seq_length != 1: + # NOTE assume prefill stage + # allocate memory block + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if infer_state.is_context_stage: + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + infer_state.decode_layer_id = 0 + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] if past_key_values is not None else None + # NOTE: modify here for passing args to decoder layer + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + infer_state.decode_layer_id += 1 + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + hidden_states = self.norm(hidden_states) + next_cache = next_decoder_cache if use_cache else None + + # update indices + # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + @staticmethod + def llama_decoder_layer_forward( + self: LlamaDecoderLayer, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + @staticmethod + def llama_flash_attn_kvcache_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + assert use_cache is True, "use_cache should be set to True using this llama attention" + + bsz, q_len, _ = hidden_states.size() + + # NOTE might think about better way to handle transposed k and v + # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] + # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + + # NOTE might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len + + cos, sin = infer_state.position_cos, infer_state.position_sin + # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) + + rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.reshape(-1, self.num_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_heads, self.head_dim) + + if infer_state.is_context_stage: + # first token generation + + # copy key and value calculated in current step to memory manager + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + + attn_output = torch.empty_like(query_states) + + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # second token and follows + # kv = torch.stack((key_states, value_states), dim=2) + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_states) + + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + # return past_key_value as None + return attn_output, None, None + + +def get_llama_vllm_rmsnorm_forward(): + if HAS_VLLM_KERNERL: + + def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + + return out + + return _vllm_rmsnorm_forward + else: + return None diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..776c4e85056542a1a35dbbf23650bbc01e5b8ffd --- /dev/null +++ b/colossalai/inference/tensor_parallel/policies/__init__.py @@ -0,0 +1,5 @@ +from .bloom import BloomModelInferPolicy +from .chatglm2 import ChatGLM2InferPolicy +from .llama import LlamaModelInferPolicy + +__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy", "ChatGLM2InferPolicy"] diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6df20970005f184dfc96d420b8dfe829e2e176 --- /dev/null +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -0,0 +1,99 @@ +from functools import partial + +import torch +from torch.nn import LayerNorm + +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription +from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy + +from ..modeling.bloom import BloomInferenceForwards + +try: + from colossalai.kernel.triton import layer_norm + + HAS_TRITON_NORM = True +except: + print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton") + HAS_TRITON_NORM = False + + +def get_triton_layernorm_forward(): + if HAS_TRITON_NORM: + + def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor): + return layer_norm(hidden_states, self.weight.data, self.bias, self.eps) + + return _triton_layernorm_forward + else: + return None + + +class BloomModelInferPolicy(BloomForCausalLMPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel + + policy = super().module_policy() + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 3}), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}), + ]) + # NOTE set inference mode to shard config + self.shard_config._infer() + + method_replacement = { + "forward": BloomInferenceForwards.bloom_for_causal_lm_forward, + "prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation, + } + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=BloomForCausalLM + ) + + method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel) + + method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock) + + method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=BloomAttention + ) + + if HAS_TRITON_NORM: + infer_method = get_triton_layernorm_forward() + method_replacement = {"forward": partial(infer_method)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LayerNorm + ) + + return policy diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py new file mode 100644 index 0000000000000000000000000000000000000000..90f8b4fd2d7eb8fdb6a3fd52c275bffea5c3ab72 --- /dev/null +++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py @@ -0,0 +1,74 @@ +from functools import partial + +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, + ChatGLMModel, + GLMBlock, + GLMTransformer, + SelfAttention, +) + +# import colossalai +from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy + +from ..modeling._utils import init_to_get_rotary +from ..modeling.chatglm2 import ChatGLM2InferenceForwards + +try: + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +class ChatGLM2InferPolicy(ChatGLMModelPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + self.shard_config._infer() + + model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward + method_replacement = {"forward": model_infer_forward} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel) + + encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward + method_replacement = {"forward": encoder_infer_forward} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=GLMTransformer + ) + + encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward + method_replacement = {"forward": encoder_layer_infer_forward} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock) + + attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward + method_replacement = {"forward": attn_infer_forward} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=SelfAttention + ) + + # for rmsnorm and others, we need to check the shape + return policy + + def postprocess(self): + init_to_get_rotary(self.model) + return self.model + + +class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward + method_replacement = {"forward": partial(model_infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration + ) + return policy + + def postprocess(self): + return super().postprocess() diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..507c1203dd6bd108e522560738ae604cd66060e3 --- /dev/null +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -0,0 +1,124 @@ +from functools import partial + +import torch +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm + +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription + +# import colossalai +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + +from ..modeling._utils import init_to_get_rotary +from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward + +try: + from colossalai.kernel.triton import rmsnorm_forward + + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +def get_triton_rmsnorm_forward(): + if HAS_TRITON_RMSNORM: + + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) + + return _triton_rmsnorm_forward + else: + return None + + +class LlamaModelInferPolicy(LlamaForCausalLMPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + ], + ) + + self.shard_config._infer() + + infer_forward = LlamaInferenceForwards.llama_model_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaDecoderLayer + ) + + infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaAttention + ) + + infer_forward = None + if HAS_TRITON_RMSNORM: + infer_forward = get_triton_rmsnorm_forward() + else: + # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123 + infer_forward = get_llama_vllm_rmsnorm_forward() + + if infer_forward is not None: + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaRMSNorm + ) + + return policy + + def postprocess(self): + init_to_get_rotary(self.model.model) + return self.model diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 5d3f3e5530cbb71a6e88cd3726615c1b6b61a164..aac57d34a2c1f1eebeb95802937672c95819a963 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -1,69 +1,30 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import argparse import os -import pprint +import warnings from pathlib import Path -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Union import torch -import torch.nn as nn -from torch.nn.modules.loss import _Loss -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim.lr_scheduler import _LRScheduler -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader +import torch.distributed as dist -from colossalai.amp import AMP_TYPE, convert_to_amp -from colossalai.amp.naive_amp import NaiveAMPModel -from colossalai.builder.builder import build_gradient_handler -from colossalai.context import Config, ConfigException, ParallelMode -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.core import global_context as gpc -from colossalai.engine import Engine -from colossalai.engine.gradient_accumulation import accumulate_gradient -from colossalai.engine.schedule import ( - InterleavedPipelineSchedule, - NonPipelineSchedule, - PipelineSchedule, - get_tensor_shape, -) +from colossalai.context import Config from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer -from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param -from colossalai.utils.moe import sync_moe_model_param -from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2 -from colossalai.zero.legacy.gemini.ophooks import BaseOpHook - - -def get_default_parser(): - """Reads user command line and uses an argument parser to parse the input arguments. - Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. - - Returns: - Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser. - """ - parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, help='path to the config file') - parser.add_argument('--host', type=str, help='the master address for distributed training') - parser.add_argument('--port', type=int, help='the master port for distributed training') - parser.add_argument('--world_size', type=int, help='world size for distributed training') - parser.add_argument('--rank', type=int, help='rank for the default process group') - parser.add_argument('--local_rank', type=int, help='local rank on the node') - parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication') - return parser - - -def launch(config: Union[str, Path, Config, Dict], - rank: int, - world_size: int, - host: str, - port: int, - backend: str = 'nccl', - local_rank: int = None, - seed: int = 1024, - verbose: bool = True): +from colossalai.utils import set_device, set_seed + + +def launch( + config: Union[str, Path, Config, Dict], + rank: int, + world_size: int, + host: str, + port: int, + backend: str = "nccl", + local_rank: int = None, + seed: int = 1024, + verbose: bool = True, +): """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input arguments are not given. Then initialize and set distributed environment by calling global_context's functions. @@ -83,48 +44,33 @@ def launch(config: Union[str, Path, Config, Dict], Raises: Exception: Raise exception when config type is wrong """ - gpc.verbose = verbose - - # set config - assert isinstance(config, (Config, str, Path, dict)), \ - f'expected argument config to be Config, str or Path, but got {type(config)}' - if not isinstance(config, Config) and isinstance(config, dict): - config = Config(config) - if isinstance(config, (str, Path)): - config = Config.from_file(config) - gpc.load_config(config) + if rank == 0: + warnings.warn("`config` is deprecated and will be removed soon.") # init default process group - gpc.init_global_dist(rank, world_size, backend, host, port) - - # init process groups for different parallel modes from config - gpc.init_parallel_groups() + init_method = f"tcp://[{host}]:{port}" + dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # set cuda device if torch.cuda.is_available(): # if local rank is not given, calculate automatically - gpc.set_device(local_rank) + set_device(local_rank) - # set the number of processes running on the same node - gpc.detect_num_processes_on_current_node() - - gpc.set_seed(seed) + set_seed(seed) if verbose: logger = get_dist_logger() - logger.info( - f'Distributed environment is initialized, ' - f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, ' - f'tensor parallel size: {gpc.tensor_parallel_size}', - ranks=[0]) + logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0]) -def launch_from_slurm(config: Union[str, Path, Config, Dict], - host: str, - port: int, - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): +def launch_from_slurm( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables set by SLURM @@ -137,29 +83,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['SLURM_PROCID']) - world_size = int(os.environ['SLURM_NPROCS']) + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NPROCS"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" ) - launch(config=config, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def launch_from_openmpi(config: Union[str, Path, Config, Dict], - host: str, - port: int, - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + launch( + config=config, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_openmpi( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables set by OpenMPI @@ -172,29 +122,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['OMPI_COMM_WORLD_RANK']) - local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) - world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" ) - launch(config=config, - local_rank=local_rank, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def launch_from_torch(config: Union[str, Path, Config, Dict], - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_torch( + config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True +): """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size from the environment variables set by PyTorch @@ -205,266 +156,24 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - host = os.environ['MASTER_ADDR'] - port = int(os.environ['MASTER_PORT']) + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + host = os.environ["MASTER_ADDR"] + port = int(os.environ["MASTER_PORT"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" ) - launch(config=config, - local_rank=local_rank, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def initialize(model: nn.Module, - optimizer: Optimizer, - criterion: Optional[_Loss] = None, - train_dataloader: Optional[Iterable] = None, - test_dataloader: Optional[Iterable] = None, - lr_scheduler: Optional[_LRScheduler] = None, - ophooks: Optional[List[BaseOpHook]] = None, - verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: - """Core function to wrap the essential training components with our functionality based on the config which is - loaded into gpc.config. - - Args: - model (:class:`torch.nn.Module` or Callbale): Your model instance or a function to build the model. - optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`): - Your optimizer instance. - criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance. - train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training. - test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. - lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional. - verbose (bool, optional): Whether to print logs. - - Returns: - Tuple (engine, train_dataloader, test_dataloader, lr_scheduler): - A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)`` - where only ``engine`` could not be None. - """ - # get logger - logger = get_dist_logger() - gpc.verbose = verbose - - # get config from gpc - config = gpc.config - - # print config - if verbose: - logger.info( - f"\n========== Your Config ========\n" - f"{pprint.pformat(gpc.config)}\n" - f"================================\n", - ranks=[0]) - - # cudnn - cudnn_benchmark = config.get('cudnn_benchmark', False) - cudnn_deterministic = config.get('cudnn_deterministic', False) - torch.backends.cudnn.benchmark = cudnn_benchmark - torch.backends.cudnn.deterministic = cudnn_deterministic - if verbose: - logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0]) - - # zero - use_zero = hasattr(gpc.config, 'zero') - if use_zero: - zero_cfg = gpc.config.get('zero', None) - if zero_cfg is not None: - cfg_ = zero_cfg.copy() - else: - cfg_ = {} - optimizer_config = zero_cfg.get('optimizer_config', None) - model_config = zero_cfg.get('model_config', None) - model, optimizer = convert_to_zero_v2(model, - optimizer, - model_config=model_config, - optimizer_config=optimizer_config) - - logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) - else: - if isinstance(model, nn.Module): - # first sync model across dp ranks - model.to(get_current_device()) - elif isinstance(model, Callable): - model = model().to(get_current_device()) - - # optimizer maybe a optimizer_cls - if isinstance(optimizer, Callable): - optimizer = optimizer(model.parameters()) - logger.warning("Initializing an non ZeRO model with optimizer class") - - if not use_zero: - if is_using_sequence(): - sync_model_param(model, ParallelMode.SEQUENCE_DP) - elif MOE_CONTEXT.is_initialized: - sync_moe_model_param(model) - elif is_using_ddp(): - sync_model_param(model, ParallelMode.DATA) - else: - logger.warning( - "The parameters of models is not automatically synchronized.\n" - "Please make sure that all parameters are the same in data parallel group.", - ranks=[0]) - - # check amp and zero - fp16_cfg = gpc.config.get('fp16', None) - - if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero: - raise ConfigException( - "It is not allowed to set fp16 and zero configuration in your config file at the same time") - - # clip grad norm - clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0) - - # initialize amp - amp_mode = None - if fp16_cfg is not None and fp16_cfg.mode is not None: - cfg_ = fp16_cfg.copy() - amp_mode = cfg_.pop('mode') - if is_using_pp(): - assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently' - if amp_mode == AMP_TYPE.NAIVE: - cfg_['clip_grad_norm'] = clip_grad_norm - model, optimizer, criterion = convert_to_amp(model=model, - optimizer=optimizer, - criterion=criterion, - mode=amp_mode, - amp_config=cfg_) - - # get torch ddp config - torch_ddp_cfg = gpc.config.get('torch_ddp', dict()) - - # gradient handler - gradient_handler_cfg = gpc.config.get('gradient_handler', None) - if gradient_handler_cfg is None: - # if gradient handler is not specified in the configuration file, - # check in the following order - # 1. if optimizer is ZERO, then use zero grad handler - # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp - # 3. if using pipeline and dp size larger than 1, use data parallel grad handler - if isinstance(optimizer, ShardedOptimizerV2): - gradient_handler_cfg = [dict(type='ZeROGradientHandler')] - if verbose: - logger.info( - "Training with zero is detected, ZeROGradientHandler is automatically " - "added even though not specified in the configuration", - ranks=[0]) - elif is_using_ddp() and MOE_CONTEXT.is_initialized: - gradient_handler_cfg = [dict(type='MoeGradientHandler')] - if verbose: - logger.info( - "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically " - "added even though not specified in the configuration", - ranks=[0]) - elif is_using_sequence(): - model = DDP(model, - process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), - device_ids=[torch.cuda.current_device()], - **torch_ddp_cfg) - if verbose: - logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', - ranks=[0]) - elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE: - model = DDP(model, - process_group=gpc.get_group(ParallelMode.DATA), - device_ids=[torch.cuda.current_device()], - **torch_ddp_cfg) - if verbose: - logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0]) - elif is_using_ddp(): - gradient_handler_cfg = [dict(type='DataParallelGradientHandler')] - if verbose: - logger.info( - "Data parallel training is detected when using pipeline parallel, " - "DataParallelGradientHandler is automatically " - "added even though not specified in the configuration", - ranks=[0]) - # add pipeline parallel gradient handler, if pipeline shared module is detected - for param in model.parameters(): - if getattr(param, 'pipeline_shared_module_pg', None) is not None: - if gradient_handler_cfg is None: - gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')] - else: - gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler')) - if verbose: - logger.info( - "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically " - "added even though not specified in the configuration", - ranks=[0]) - break - else: - if not isinstance(gradient_handler_cfg, list): - raise ConfigException( - f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}" - ) - - # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time - # to avoid duplicated buffer synchronization - if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel): - model.module.sync_buffer = False - - # initialize schedule for engine - if is_using_pp(): - tensor_shape = get_tensor_shape() - use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks') - if gpc.is_initialized(ParallelMode.PARALLEL_1D): - scatter_gather = True - else: - scatter_gather = False - if use_interleaved: - if isinstance(model, nn.Sequential): - model = nn.ModuleList([model]) - schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, - gpc.config.model.num_chunks, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather) - else: - schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather) - else: - schedule = NonPipelineSchedule() - - if gradient_handler_cfg is None: - gradient_handlers = None - if verbose and not isinstance(model, DDP): - logger.warning( - "No PyTorch DDP or gradient handler is set up, please make sure you do not need " - "to all-reduce the gradients after a training step.", - ranks=[0]) - else: - gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg] - - # check if optimizer is ColossalaiOptimizer - if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)): - optimizer = ColossalaiOptimizer(optim=optimizer) - - # gradient accumulation - grad_accum_size = gpc.config.get('gradient_accumulation', None) - if grad_accum_size is not None: - optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient( - model=model, - optimizer=optimizer, - dataloader=train_dataloader, - accumulate_size=grad_accum_size, - gradient_handlers=gradient_handlers, - lr_scheduler=lr_scheduler) - engine = Engine(model=model, - optimizer=optimizer, - criterion=criterion, - gradient_handlers=gradient_handlers, - clip_grad_norm=clip_grad_norm, - ophook_list=ophooks, - schedule=schedule) - - return engine, train_dataloader, test_dataloader, lr_scheduler + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py index 8c658e375146ae4f6f31f62a9e913ed16fbb0714..98b21c9c02c170dfce5076acf9ea0f99e03cb6c1 100644 --- a/colossalai/interface/__init__.py +++ b/colossalai/interface/__init__.py @@ -1,4 +1,4 @@ -from .model import ModelWrapper +from .model import AMPModelMixin, ModelWrapper from .optimizer import OptimizerWrapper -__all__ = ['OptimizerWrapper', 'ModelWrapper'] +__all__ = ["OptimizerWrapper", "ModelWrapper", "AMPModelMixin"] diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py index a067d7671ce7eaaa174aced664ec16461cd03034..58df09b853eeaa1bcba79c1116a820822785d084 100644 --- a/colossalai/interface/model.py +++ b/colossalai/interface/model.py @@ -23,3 +23,12 @@ class ModelWrapper(nn.Module): def forward(self, *args, **kwargs): return self.module(*args, **kwargs) + + +class AMPModelMixin: + """This mixin class defines the interface for AMP training.""" + + def update_master_params(self): + """ + Update the master parameters for AMP training. + """ diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index dd9acab17584a452ad430094ab7fed4b9e272efb..95d11087bececd9b5bce5f50edb7059a22e13b7b 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -1,5 +1,6 @@ from typing import Union +import torch import torch.nn as nn from torch import Tensor from torch.optim import Optimizer @@ -21,7 +22,7 @@ class OptimizerWrapper: params = [] for group in self.param_groups: - params += group['params'] + params += group["params"] return params @property @@ -53,6 +54,9 @@ class OptimizerWrapper: """ loss.backward(*args, **kwargs) + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + torch.autograd.backward(tensor, grad) + def state_dict(self): """ Returns the optimizer state. @@ -78,12 +82,14 @@ class OptimizerWrapper: """ nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs) - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> Tensor: + def clip_grad_by_norm( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = False, + *args, + **kwargs, + ) -> Tensor: """ Clips gradient norm of an iterable of parameters. @@ -109,7 +115,8 @@ class OptimizerWrapper: loss (Tensor): The loss to be scaled. """ raise NotImplementedError( - "The method scale_loss is only available for optimizers with mixed precision training") + "The method scale_loss is only available for optimizers with mixed precision training" + ) def unscale_grad(self): """ @@ -118,4 +125,11 @@ class OptimizerWrapper: Note: Only available for optimizers with mixed precision training. """ raise NotImplementedError( - "The method unscale_grad is only available for optimizers with mixed precision training") + "The method unscale_grad is only available for optimizers with mixed precision training" + ) + + def unwrap(self): + """ + Unwrap the optimizer for checkpoint saving/loading. + """ + return self.optim diff --git a/colossalai/interface/pretrained.py b/colossalai/interface/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6bc10cd1323b38ae86f304645a41026333f6f3 --- /dev/null +++ b/colossalai/interface/pretrained.py @@ -0,0 +1,16 @@ +from typing import Optional + +from torch.nn import Module + +__all__ = [ + "get_pretrained_path", + "set_pretrained_path", +] + + +def get_pretrained_path(model: Module) -> Optional[str]: + return getattr(model, "_pretrained", None) + + +def set_pretrained_path(model: Module, path: str) -> None: + setattr(model, "_pretrained", path) diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py index 1d5a6ce495bdec0fe02475ed0bb3b3b67bb86b3c..f8a974b5fb26dbb55d1952d1339c18bff52c3cb5 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/colossalai/kernel/cuda_native/__init__.py @@ -1,5 +1,13 @@ from .layer_norm import MixedFusedLayerNorm as LayerNorm +from .mha.mha import ColoAttention from .multihead_attention import MultiHeadAttention -from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax +from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax -__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax'] +__all__ = [ + "LayerNorm", + "MultiHeadAttention", + "FusedScaleMaskSoftmax", + "ScaledUpperTriangMaskedSoftmax", + "ColoAttention", + "AttnMaskType", +] diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/colossalai/kernel/cuda_native/csrc/compat.h index 00066dc95475296168c799904dc595ed435d2b0a..a62beef91a8a55bb75a17ae0194381048340f688 100644 --- a/colossalai/kernel/cuda_native/csrc/compat.h +++ b/colossalai/kernel/cuda_native/csrc/compat.h @@ -7,4 +7,4 @@ #define DATA_PTR data_ptr #else #define DATA_PTR data -#endif \ No newline at end of file +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu new file mode 100644 index 0000000000000000000000000000000000000000..2b1b366b1c02203805990d468feddeaa24a703ef --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu @@ -0,0 +1,63 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "column_remap.cuh" +#include "util.cuh" + +const int SHUF_BLOCKSIZE_X = 256; +const int SHUF_BLOCKSIZE_Y = 16; + +__global__ void column_remap_kernel +( + const half* __restrict__ x, + half* __restrict__ x_new, + const int x_width, + const int x_height, + const uint32_t* x_map +) +{ + int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; + if (x_column >= x_width) return; + //if (x_row >= x_height) return; + + int x_stride = x_width; + int x_idx = x_row * x_stride + x_column; + + int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); + int x_idx_end = x_row_end * x_stride + x_column; + + int s_column = x_map[x_column]; + int s_idx = x_row * x_stride + s_column; + + while (x_idx < x_idx_end) + { + x_new[x_idx] = x[s_idx]; + x_idx += x_stride; + s_idx += x_stride; + } +} + +// Remap columns in x to correspond to sequential group index before matmul +// +// perform x -> seq_x such that seq_x @ seq_w == x @ w + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +) +{ + dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); + + dim3 blocks + ( + (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, + (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, + 1 + ); + + column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0364e38c4779717d560a04822c777f9d572dd2a7 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh @@ -0,0 +1,19 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _column_remap_cuh +#define _column_remap_cuh + +#include +#include +#include + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +); + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c5258813e147554e033eaf9a80c27dd694a50961 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh @@ -0,0 +1,58 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_compat_cuh +#define _cuda_compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu new file mode 100644 index 0000000000000000000000000000000000000000..4416027c8387a17726f061519f49cd181843ac1d --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu @@ -0,0 +1,75 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#define _cuda_buffers_cu +#include "cuda_buffers.cuh" + +CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; +// __constant__ half2 q4_table[16][256]; +// half2 q4_table_host[16][256]; +// bool q4_table_init = false; + +CudaBuffers::CudaBuffers +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +) : + device(_device), + temp_state_size(_temp_state_size), + temp_state(_temp_state), + temp_dq(_temp_dq) +{ + cudaSetDevice(_device); + + cudaStreamCreate(&alt_stream_1); + cudaStreamCreate(&alt_stream_2); + cudaStreamCreate(&alt_stream_3); + cudaEventCreate(&alt_stream_1_done); + cudaEventCreate(&alt_stream_2_done); + cudaEventCreate(&alt_stream_3_done); +} + +CudaBuffers::~CudaBuffers() +{ + cudaStreamDestroy(alt_stream_1); + cudaStreamDestroy(alt_stream_2); + cudaStreamDestroy(alt_stream_3); + cudaEventDestroy(alt_stream_1_done); + cudaEventDestroy(alt_stream_2_done); + cudaEventDestroy(alt_stream_3_done); +} + +CudaBuffers* get_buffers(const int device_index) +{ + return g_buffers[device_index]; +} + +void prepare_buffers_cuda +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +) +{ + CudaBuffers* buffers = new CudaBuffers + ( + _device, + _temp_state_size, + _temp_state, + _temp_dq + ); + + g_buffers[_device] = buffers; +} + +void cleanup_buffers_cuda() +{ + for (int i = 0; i < CUDA_MAX_DEVICES; i++) + { + if (!g_buffers[i]) continue; + delete g_buffers[i]; + g_buffers[i] = NULL; + } +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0bf2057c665cbfad46461d437ef5fe44b02f8f2e --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh @@ -0,0 +1,55 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_buffers_cuh +#define _cuda_buffers_cuh + +#include +#include +#include +#include + +const int CUDA_MAX_DEVICES = 16; + +// #ifndef _cuda_buffers_cu +// extern __constant__ half2 q4_table[16][256]; +// #endif + +class CudaBuffers +{ +public: + int device; + + half* temp_state; // [max_hidden_rows * intermediate_size] + int temp_state_size; + half* temp_dq; // size of largest quant tensor * 8 + + cudaStream_t alt_stream_1; + cudaStream_t alt_stream_2; + cudaStream_t alt_stream_3; + cudaEvent_t alt_stream_1_done; + cudaEvent_t alt_stream_2_done; + cudaEvent_t alt_stream_3_done; + + CudaBuffers + ( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq + ); + ~CudaBuffers(); +}; + +CudaBuffers* get_buffers(const int device_index); + +void prepare_buffers_cuda +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +); + +void cleanup_buffers_cuda(); + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5cd2e8553ef6bee59162956dfed4a9a26227a4ce --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh @@ -0,0 +1,49 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _hip_compat_cuh +#define _hip_compat_cuh + +// Workaround for a bug in hipamd, backported from upstream. +__device__ __forceinline__ __half __compat_hrcp(__half x) { + return __half_raw{ + static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; +} + +__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { + return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)), + static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))}; +} + +#define hrcp __compat_hrcp +#define h2rcp __compat_h2rcp + +// Workaround for hipify_python using rocblas instead of hipblas. +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* beta, + half* CP, + int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} + +#define rocblas_handle hipblasHandle_t +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_get_stream hipblasGetStream +#define rocblas_set_stream hipblasSetStream +#define rocblas_hgemm __compat_hipblasHgemm + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bcc0e43901de85c3e361c01765ffab5b15de4da3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp @@ -0,0 +1,254 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include +#include +#include +#include +#include +#include +#include "util.cuh" +#include "tuning.h" +#include "cuda_buffers.cuh" +#include "q4_matrix.cuh" +#include "q4_matmul.cuh" +#include "column_remap.cuh" + +// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a +// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of +// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. + +void check_cuda(cudaError_t ret) +{ + switch (ret) + { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); \ + printf(" **** %s\n", cudaGetErrorString(ret)); \ + TORCH_CHECK(false, "CUDA error"); \ + break; + } +} + +// Some decluttering macros + +#define STRINGIFY_(__x) #__x +#define STRINGIFY(__x) STRINGIFY_(__x) +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) +#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ +do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ + TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ +} while(0) + +#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ +do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ +} while(0) + +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) +{ + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; +} + + +// Tuning parameters + +ExLlamaTuning tuningParams; + +void set_tuning_params +( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2 +) +{ + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; +} + + +// Release all unmanaged objects allocated by the extension + +void cleanup() +{ + cleanup_buffers_cuda(); + g_q4_free_matrices(); +} + + +// Prepare buffers for forward pass + +void prepare_buffers +( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq +) +{ + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + + prepare_buffers_cuda + ( + device_index, + // buffer size used for sanity checks + temp_state.numel(), + (half*) temp_state.data_ptr(), + (half*) temp_dq.data_ptr() + ); +} + + +// Create Q4Matrix, return handle + +uintptr_t make_q4 +( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device +) +{ + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + + int width = qweight.size(1); + int height = qweight.size(0) * 8; + int groups = qzeros.size(0); + + Q4Matrix* m = new Q4Matrix + ( + height, + width, + groups, + + (uint32_t*) qweight.data_ptr(), + (uint32_t*) qzeros.data_ptr(), + (half*) scales.data_ptr(), + g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), + + device + ); + + g_q4_keep_matrix(m); + return reinterpret_cast (m); +} + + +// Matmul half @ quant -> half + +void q4_matmul +( + torch::Tensor x, + uintptr_t w, + torch::Tensor out +) +{ + Q4Matrix* wm = reinterpret_cast (w); + + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) + { + q4_matmul_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr() + ); + } + else + { + q4_matmul_recons_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + at::cuda::getCurrentCUDABlasHandle() + ); + } +} + + +// Remap columns in half tensor + +void column_remap +( + torch::Tensor x, + torch::Tensor x_new, + torch::Tensor x_map +) +{ + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + TORCH_CHECK_BUFFER_SIZE(x_new, height * width); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda + ( + (half*) x.data_ptr(), + (half*) x_new.data_ptr(), + height, + width, + (uint32_t*) x_map.data_ptr() + ); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2fd5ab0b36cd0dd67c9b6081740bbed12711ee09 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh @@ -0,0 +1,294 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _matrix_cuh +#define _matrix_cuh + +#include +#include + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } +}; + +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +}; + +// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale + +__device__ __forceinline__ half2 dot_product_8 +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + +// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) +// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; +// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; +// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; + + half2 tmp = __hmul2(*h_ptr++, v_01); + tmp = __hfma2(*h_ptr++, v_23, tmp); + tmp = __hfma2(*h_ptr++, v_45, tmp); + tmp = __hfma2(*h_ptr++, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half* h_ptr = h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(*h_ptr++, v_0); + tmp = __hfma(*h_ptr++, v_1, tmp); + tmp = __hfma(*h_ptr++, v_2, tmp); + tmp = __hfma(*h_ptr++, v_3, tmp); + tmp = __hfma(*h_ptr++, v_4, tmp); + tmp = __hfma(*h_ptr++, v_5, tmp); + tmp = __hfma(*h_ptr++, v_6, tmp); + tmp = __hfma(*h_ptr++, v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map + +__device__ __forceinline__ half2 dot_product_8_x_map +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + + half h_0 = h_ptr[*x_map_ptr++]; + half h_1 = h_ptr[*x_map_ptr++]; + half h_2 = h_ptr[*x_map_ptr++]; + half h_3 = h_ptr[*x_map_ptr++]; + half h_4 = h_ptr[*x_map_ptr++]; + half h_5 = h_ptr[*x_map_ptr++]; + half h_6 = h_ptr[*x_map_ptr++]; + half h_7 = h_ptr[*x_map_ptr++]; + + half2 h_01 = __halves2half2(h_0, h_1); + half2 h_23 = __halves2half2(h_2, h_3); + half2 h_45 = __halves2half2(h_4, h_5); + half2 h_67 = __halves2half2(h_6, h_7); + + half2 tmp = __hmul2(h_01, v_01); + tmp = __hfma2(h_23, v_23, tmp); + tmp = __hfma2(h_45, v_45, tmp); + tmp = __hfma2(h_67, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_x_map_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); + tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu new file mode 100644 index 0000000000000000000000000000000000000000..f47daeb0e8771a08a5dae1886fa571c9063fbb30 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu @@ -0,0 +1,260 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "q4_matmul.cuh" +#include "column_remap.cuh" +#include "util.cuh" +#include "matrix.cuh" +#include "cu_compat.cuh" +#include "cuda_buffers.cuh" +#if defined(USE_ROCM) +#include "hip_compat.cuh" +#endif + +const int THREADS_X = 32; // Block size and thread count along columns in w and out +const int THREADS_Y = 1; // Block size and thread count along rows in x and out + +typedef void (*fp_q4_matmul_kernel) +( + const half*, + const uint32_t*, + half*, + const half*, + const uint32_t*, + const int, + const int, + const int, + const int, + const int, + const uint32_t*, + bool +); + +template +__global__ void q4_matmul_kernel +( + const half* __restrict__ x, + const uint32_t* __restrict__ w, + half* __restrict__ out, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int dim, + const int width, + const int groupsize, + const int block_size_z, + const uint32_t* __restrict__ x_map, + bool no_zero +) +{ + // Start of block + + int x_column = block_size_z * blockIdx.z; + int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); + + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + int x_row = THREADS_Y * blockIdx.y + threadIdx.y; + + int iterations = (x_column_end - x_column) / 8; + + // Views + + MatrixView_half x_(x, height, dim); + MatrixView_half w_scales_(w_scales, dim / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); + MatrixView_q4_column w_(w, dim, width); + MatrixView_half_rw out_(out, height, width); + + // Zero output + + if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) + { + *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; + __syncthreads(); + } + + // Loop over part of x row (and w column) + + half2 acc = {}; + half acc_h = {}; + + if constexpr (use_groupsize) + { + // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this + // could be slightly faster + + for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) + { + if constexpr (use_half2) + { + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + else + { + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + } + } + else + { + // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache + + for (int k = x_column; k < x_column + iterations * 8; k += 8) + { + if constexpr (use_half2) + { + int group = k / groupsize; + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + else + { + int group = k / groupsize; + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + } + } + + // Add to block result + + if constexpr (use_half2) + { + half result = __hadd(__low2half(acc), __high2half(acc)); + atomicAdd(out_.item_ptr(x_row, w_column), result); + } + else + { + atomicAdd(out_.item_ptr(x_row, w_column), acc_h); + } +} + +fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) +{ + // + if (tuningParams->matmul_no_half2) { + if (block_size_z % groupsize == 0) { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } else { + if (block_size_z % groupsize == 0) + { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } +}; + +// Compute y = x @ w + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + + uint32_t* x_map = w->cuda_x_map; + const half* x_mapped = x; + if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) + { + CudaBuffers* buffers = get_buffers(w->device); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + x_map = NULL; + } + + int block_size_z; + if (w->width == 4096) block_size_z = 384; // 7B + else if (w->width == 11008) block_size_z = 256; + else if (w->width == 5120) block_size_z = 384; // 13B + else if (w->width == 13824) block_size_z = 256; + else if (w->width == 6656) block_size_z = 256; // 33B + else if (w->width == 17920) block_size_z = 128; + else block_size_z = 256; + + //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); + + dim3 threads(THREADS_X, THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height + threads.y - 1) / threads.y, + (dim + block_size_z - 1) / block_size_z + ); + + fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); + + kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); +} + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + const cublasHandle_t handle, + bool no_zero +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + CudaBuffers* buffers = get_buffers(w->device); + + const half* x_mapped = x; + if (w->cuda_x_map) + { + TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small"); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + } + + w->reconstruct(buffers->temp_dq); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 + const float alpha = 1.0f; + const float beta = no_zero ? 1.0f : 0.0f; + cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, + x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); +#else + const half alpha = __float2half(1.0f); + const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); + cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); +#endif +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh new file mode 100644 index 0000000000000000000000000000000000000000..09f3e1a633628fe1584c65b1db61ab1d3570f3da --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh @@ -0,0 +1,43 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matmul_cuh +#define _q4_matmul_cuh + +#include +#include +#include +#include +#include + +#include "q4_matrix.cuh" +#include "tuning.h" + +// Workaround for hipify_python using rocblas instead of hipblas. +#if defined(USE_ROCM) +#include +#define rocblas_handle hipblasHandle_t +#endif + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero = false, + cudaStream_t alt_stream = NULL +); + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + const cublasHandle_t handle, + bool no_zero = false +); + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu new file mode 100644 index 0000000000000000000000000000000000000000..9c61143f565e0b4b75fc68ebbfec93a692389931 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu @@ -0,0 +1,225 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "q4_matrix.cuh" +#include +#include "util.cuh" +#include "matrix.cuh" + +using namespace std; + +const int UNSHUF_BLOCKSIZE_X = 64; + +const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column +const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows + +vector g_q4_matrices; + +void g_q4_keep_matrix(Q4Matrix* m) +{ + g_q4_matrices.push_back(m); +} + +void g_q4_free_matrices() +{ + for (const auto& m : g_q4_matrices) delete m; + g_q4_matrices.clear(); +} + +Q4Matrix::Q4Matrix +( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device +) : + height(_height), + width(_width), + groups(_groups), + device(_device) +{ + cudaSetDevice(device); + + cuda_qweight = _qweight; + cuda_qzeros = _qzeros; + cuda_scales = _scales; + + groupsize = height / groups; + + if (_g_idx) make_sequential(_g_idx); +} + +Q4Matrix::~Q4Matrix() +{ +} + +// Make sequential + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint32_t* __restrict__ x_map, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + + int w_new2_row = blockIdx.y; + + int x_map_idx = w_new2_row << 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = x_map[x_map_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) +{ + uint32_t* cuda_new_qweight = NULL; + cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Move to CUDA + + cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); + dim3 blocks + ( + (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2), + height / 8, + 1 + ); + + make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); + + // Replace qweights + + cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); +} + +__global__ void reconstruct_kernel +( + const uint32_t* __restrict__ w, + half* __restrict__ out, // (y) + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int width, + const int groupsize +) +{ + // Start of block + + int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; + int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; + if (column >= width) return; + + // Views + + MatrixView_q4_column w_(w, height, width); + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, height / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); + + // Groupsize version + + int group = row / groupsize; + + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + + uint32_t w_read = w_.item_uint32_t(row, column); + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int s = 0; s < 32; s += 4) + { + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + *out_ptr = w_item; out_ptr += out_.width; + } +} + +void Q4Matrix::reconstruct(half* out) +{ + dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height / 8 + threads.y - 1) / threads.y, + 1 + ); + + reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh new file mode 100644 index 0000000000000000000000000000000000000000..50cb72a41518593f6be4dded4f03c61772ef2ef9 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh @@ -0,0 +1,53 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matrix_cuh +#define _q4_matrix_cuh + +#include +#include +#include + +class Q4Matrix +{ +public: + + int device; + + int height; + int width; + int groups; + int groupsize; + + uint32_t* cuda_qweight = NULL; + uint32_t* cuda_qzeros = NULL; + half* cuda_scales = NULL; + uint32_t* cuda_x_map = NULL; + + Q4Matrix + ( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device + ); + + ~Q4Matrix(); + + void reconstruct(half* out); + +private: + + void make_sequential(const uint32_t* cpu_g_idx); + +}; + +void g_q4_keep_matrix(Q4Matrix* m); +void g_q4_free_matrices(); + +#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h new file mode 100644 index 0000000000000000000000000000000000000000..e413b8a96c11afd7793128d27145325d26c63715 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h @@ -0,0 +1,12 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _tuning_h +#define _tuning_h + +struct ExLlamaTuning { + int matmul_recons_thd; + bool matmul_fused_remap; + bool matmul_no_half2; +}; + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7b397573214b2b1f1af50a82320f61aabee5c8f1 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh @@ -0,0 +1,33 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include + +#if defined(USE_ROCM) +#define cudaUnspecified hipErrorUnknown +#else +#define cudaUnspecified cudaErrorApiFailureBase +#endif + +// React to failure on return code != cudaSuccess + +#define _cuda_check(fn) \ +do { \ + {_cuda_err = fn;} \ + if (_cuda_err != cudaSuccess) goto _cuda_fail; \ +} while(false) + +// React to failure on return code == 0 + +#define _alloc_check(fn) \ +do { \ + if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ + else _cuda_err = cudaSuccess; \ +} while(false) + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu index 26efa2ad6f31632a4e7ceddd06745b067759bb43..9a6a8ebc39839228630cb12967a7b2a8048b2b6b 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu @@ -1,7 +1,6 @@ #include #include - #include "cuda_util.h" /* GPU function guard */ diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu index a39a6dae0f7fb6968e6ee65fde8db4bbc5d61ab0..ce0b017f12e1e25d262cf377320966ca5044eb04 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu @@ -1,1002 +1,1002 @@ -#include -#include - -#include "kernels.h" - -#include - - -namespace cg = cooperative_groups; - -curandStatePhilox4_32_10_t *curandstate; - -/** - * @brief element-wise activation function on device, like Relu, Gelu - * - * @tparam enum class ActivationType, kRelu, kGelu - * @tparam input type - * @param any shape of float and __half2 - * @return same shape and type with input - */ -template -__forceinline__ __device__ T activation_kernel(T x); - -template <> -__device__ float activation_kernel(float x) { - float cdf = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template <> -__device__ __half2 -activation_kernel(__half2 val) { - __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); - - tmp.x = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); - return __hmul2(val, __float22half2_rn(tmp)); -} - -template <> -__device__ float activation_kernel(float x) { - return fmaxf(x, 0); -} - -template <> -__device__ __half2 -activation_kernel(__half2 x) { - return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), - fmaxf(0.f, __half2float(x.y))); -} - -/** - * @brief element-wise activation backward function on device - * - * @tparam enum class ActivationType - * @tparam input type - * @param any shape of float and __half2 - * @return same shape of input - */ -template -__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * (dg1 + dg2 + dg3); -} - -template <> -__device__ __half activation_bwd_kernel( - __half grad, __half x_half) { - float x = __half2float(x_half); - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * __float2half(dg1 + dg2 + dg3); -} - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - return x > 0.f ? grad : 0.f; -} - -template <> -__device__ __half -activation_bwd_kernel(__half grad, __half x) { - const __half half_zero = __float2half(0.f); - return x > half_zero ? grad : half_zero; -} - -template <> -__device__ __half2 activation_bwd_kernel( - __half2 grad2, __half2 x_half2) { - const __half half_zero = __float2half(0.f); - return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, - x_half2.y > half_zero ? grad2.y : half_zero); -} - -/** - * @brief init curand states in global memory - * - * @thread grid_dim * block*dim to suuport any size of states - * @param state persistant curand states - * @param seed seed to init states - * @return void - */ -__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, - int seed) { - /* Each thread gets same seed, a different sequence - number, no offset */ - int id = threadIdx.x + blockIdx.x * blockDim.x; - curand_init(seed, id, 0, &state[id]); -} - -void launch_curand_init(int total_count, int dim, cudaStream_t stream) { - cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); - int grid_dim = total_count >> 9; - curand_init_kernel<<>>( - curandstate, std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); -} - -/** - * @brief element-wise dropout, store dropped position in mask, it's not - * in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out any size of float and __half - * @param in same with out - * @param mask uint8 type, same size with out - * @param seed seed to curand - * @return void - */ -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - float *__restrict__ out, - const float *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - - float4 input4 = data4[i]; - float4 res4; - res4.x = input4.x * scale * m[0]; - res4.y = input4.y * scale * m[1]; - res4.z = input4.z * scale * m[2]; - res4.w = input4.w * scale * m[3]; - out4[i] = res4; -} - -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - __half *__restrict__ out, - const __half *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - outs_float4[i] = out_float4; -} - -/** - * @brief element-wise dropout backward with dropout mask, it's - * not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param in any size of float and __half - * @param mask uint8 type, same size with in - * @return void - */ -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - float *out, const float *in, - const uint8_t *__restrict__ mask) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *in4 = reinterpret_cast(in); - const uint32_t *mask4 = reinterpret_cast(mask); - - uint32_t *m4 = reinterpret_cast(m); - m4[0] = mask4[i]; - - float4 input4 = in4[i]; - float4 res4; - res4.x = input4.x * scale * static_cast(m[0]); - res4.y = input4.y * scale * static_cast(m[1]); - res4.z = input4.z * scale * static_cast(m[2]); - res4.w = input4.w * scale * static_cast(m[3]); - out4[i] = res4; -} - -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - __half *out, const __half *in, - const uint8_t *__restrict__ mask) { - const __half scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - float4 *out4 = reinterpret_cast(out); - const float4 *vals_float4 = reinterpret_cast(in); - const uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - uint64_t *m8 = reinterpret_cast(m); - m8[0] = mask8[i]; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - out4[i] = out_float4; -} - -template <> -void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, - int total_count, float ratio, cudaStream_t stream, - bool backward) { - int grid_dim = total_count >> 12; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -template <> -void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, - int total_count, float ratio, - cudaStream_t stream, bool backward) { - int grid_dim = total_count >> 13; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -/** - * @brief fused bias, dropout, and residual at the end of Attention and FFN, - * store dropped position in mask, it's not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param residual [batch_size, seq_len, hidden_size], float and __half - * @param seed seed to curand - * @param hidden_size hidden size - * @return void - */ -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const float *__restrict__ residual, - const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 output4; - - output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; - output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; - output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; - output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; - - out4[i] = output4; -} - -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const __half *__restrict__ residual, - const int seed, const int hidden_size) { - const __half scale = 1. / (1. - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = static_cast(rand.x > ratio); - m[5] = static_cast(rand.y > ratio); - m[6] = static_cast(rand.z > ratio); - m[7] = static_cast(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = m8[0]; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - const __half2 *res_half2 = reinterpret_cast(&res4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = - __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); - out_half2[1] = - __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); - out_half2[2] = - __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); - out_half2[3] = - __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_res_bias(float *out, const float *vals, - uint8_t *mask, const float *bias, - const float *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 12; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, - uint8_t *mask, const __half *bias, - const __half *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 13; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias and dropout backward at the end of Attention and FFN - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, float *__restrict__ in_grad, - float *__restrict__ bias_grad, const float *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - // every block generate 8 bias result - __shared__ float tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - float val = out_grad[idx]; - val *= scale * static_cast(mask[idx]); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - float sum = 0; - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, __half *__restrict__ in_grad, - __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); - __shared__ __half2 tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); - const __half2 *out_grad2 = reinterpret_cast(out_grad); - __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - __half2 local_sum = __float2half2_rn(0.f); - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - __half2 val = out_grad2[idx]; - __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); - val *= scale * m2; - local_sum += val; - in_grad2[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - __half2 sum = __float2half2_rn(0.f); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad2[pos] = tile[0][threadIdx.x]; - } -} - -template -void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template <> -void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, - const __half *out_grad, const uint8_t *mask, - int row_size, int dim, float ratio, - cudaStream_t stream) { - dim >>= 1; - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, - const float *out_grad, - const uint8_t *mask, int row_size, - int dim, float ratio, - cudaStream_t stream); - -/** - * @brief fused bias, activation, and dropout at the end of first ffn - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @tparam act_type activation function, like kRelu, kGelu - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param seed seed to curand - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 output4; - - output4.x = - activation_kernel(input4.x + b4.x) * scale * m[0]; - output4.y = - activation_kernel(input4.y + b4.y) * scale * m[1]; - output4.z = - activation_kernel(input4.z + b4.z) * scale * m[2]; - output4.w = - activation_kernel(input4.w + b4.w) * scale * m[3]; - - out4[i] = output4; -} - -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2( - activation_kernel(__hadd2(val_half2[0], b_half2[0])), - scale_mask_1); - out_half2[1] = __hmul2( - activation_kernel(__hadd2(val_half2[1], b_half2[1])), - scale_mask_2); - out_half2[2] = __hmul2( - activation_kernel(__hadd2(val_half2[2], b_half2[2])), - scale_mask_3); - out_half2[3] = __hmul2( - activation_kernel(__hadd2(val_half2[3], b_half2[3])), - scale_mask_4); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias, activation, and dropout backward - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @tparam act_type kRelu - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_bwd_kernel( - const int row_size, const float ratio, T *in_grad, - T *__restrict__ bias_grad, const T *__restrict__ input, - const T *__restrict__ bias, const T *out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - - int stride = hidden_size * WARP_SIZE; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - if (col_idx < hidden_size) { - for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { - float val = out_grad[idx]; - float in = input[idx]; - float b = bias[idx % hidden_size]; - val = activation_bwd_kernel( - val * scale * static_cast(mask[idx]), in + b); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - float sum = tile[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; - __syncthreads(); - - if (threadIdx.y == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -// @brief fused bias, activation, and dropout backward -// It is deprecated for precision reason. Keep it for future optimization. -// -// template -// __global__ void ls_dropout_act_bias_bwd_kernel( -// const int row_size, const float ratio, __half * in_grad, -// __half *__restrict__ bias_grad, const __half *__restrict__ input, const -// __half *__restrict__ bias, const __half * out_grad, const uint8_t -// *__restrict__ mask, const int hidden_size) { -// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); -// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; - -// cg::thread_block b = cg::this_thread_block(); -// cg::thread_block_tile g = cg::tiled_partition(b); - -// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); -// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); -// const __half2 *out_grad2 = reinterpret_cast(out_grad); -// const __half2 *input2 = reinterpret_cast(input); -// const __half2 *bias2 = reinterpret_cast(bias); - -// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - -// int stride = hidden_size * WARP_SIZE; -// __half2 local_sum = __float2half2_rn(0.f); - -// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); -// if (col_idx < hidden_size) { -// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { -// __half2 val = out_grad2[idx]; -// __half2 in2 = input2[idx]; -// __half2 b2 = bias2[idx % hidden_size ]; -// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); -// val = activation_bwd_kernel(val * scale -// * -// m2, -// in2+b2); -// local_sum += val; -// in_grad2[idx] = val; -// idx += stride; -// } -// } - -// tile[threadIdx.x][threadIdx.y] = local_sum; -// __syncthreads(); -// __half2 sum = tile[threadIdx.y][threadIdx.x]; -// __syncthreads(); - -// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - -// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; -// __syncthreads(); - -// if (threadIdx.y == 0) { -// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); -// bias_grad2[pos] = tile[0][threadIdx.x]; -// } -// } - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - ls_dropout_act_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); -} - -// template <> -// void launch_ls_dropout_act_bias_bwd( -// __half *in_grad, __half *bias_grad,const __half *input, const __half -// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int -// dim, float ratio, cudaStream_t stream) { -// dim >>= 1; -// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); -// dim3 block_dim(WARP_SIZE, WARP_SIZE); -// ls_dropout_act_bias_bwd_kernel -// <<>>(row_size, ratio, in_grad, -// bias_grad, -// input, bias,out_grad, mask, dim); -// } - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); +#include +#include + +#include "kernels.h" + +#include + + +namespace cg = cooperative_groups; + +curandStatePhilox4_32_10_t *curandstate; + +/** + * @brief element-wise activation function on device, like Relu, Gelu + * + * @tparam enum class ActivationType, kRelu, kGelu + * @tparam input type + * @param any shape of float and __half2 + * @return same shape and type with input + */ +template +__forceinline__ __device__ T activation_kernel(T x); + +template <> +__device__ float activation_kernel(float x) { + float cdf = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template <> +__device__ __half2 +activation_kernel(__half2 val) { + __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); +} + +template <> +__device__ float activation_kernel(float x) { + return fmaxf(x, 0); +} + +template <> +__device__ __half2 +activation_kernel(__half2 x) { + return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), + fmaxf(0.f, __half2float(x.y))); +} + +/** + * @brief element-wise activation backward function on device + * + * @tparam enum class ActivationType + * @tparam input type + * @param any shape of float and __half2 + * @return same shape of input + */ +template +__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * (dg1 + dg2 + dg3); +} + +template <> +__device__ __half activation_bwd_kernel( + __half grad, __half x_half) { + float x = __half2float(x_half); + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * __float2half(dg1 + dg2 + dg3); +} + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + return x > 0.f ? grad : 0.f; +} + +template <> +__device__ __half +activation_bwd_kernel(__half grad, __half x) { + const __half half_zero = __float2half(0.f); + return x > half_zero ? grad : half_zero; +} + +template <> +__device__ __half2 activation_bwd_kernel( + __half2 grad2, __half2 x_half2) { + const __half half_zero = __float2half(0.f); + return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, + x_half2.y > half_zero ? grad2.y : half_zero); +} + +/** + * @brief init curand states in global memory + * + * @thread grid_dim * block*dim to suuport any size of states + * @param state persistant curand states + * @param seed seed to init states + * @return void + */ +__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, + int seed) { + /* Each thread gets same seed, a different sequence + number, no offset */ + int id = threadIdx.x + blockIdx.x * blockDim.x; + curand_init(seed, id, 0, &state[id]); +} + +void launch_curand_init(int total_count, int dim, cudaStream_t stream) { + cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); + int grid_dim = total_count >> 9; + curand_init_kernel<<>>( + curandstate, std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); +} + +/** + * @brief element-wise dropout, store dropped position in mask, it's not + * in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out any size of float and __half + * @param in same with out + * @param mask uint8 type, same size with out + * @param seed seed to curand + * @return void + */ +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + float *__restrict__ out, + const float *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + + float4 input4 = data4[i]; + float4 res4; + res4.x = input4.x * scale * m[0]; + res4.y = input4.y * scale * m[1]; + res4.z = input4.z * scale * m[2]; + res4.w = input4.w * scale * m[3]; + out4[i] = res4; +} + +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + __half *__restrict__ out, + const __half *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + outs_float4[i] = out_float4; +} + +/** + * @brief element-wise dropout backward with dropout mask, it's + * not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param in any size of float and __half + * @param mask uint8 type, same size with in + * @return void + */ +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + float *out, const float *in, + const uint8_t *__restrict__ mask) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *in4 = reinterpret_cast(in); + const uint32_t *mask4 = reinterpret_cast(mask); + + uint32_t *m4 = reinterpret_cast(m); + m4[0] = mask4[i]; + + float4 input4 = in4[i]; + float4 res4; + res4.x = input4.x * scale * static_cast(m[0]); + res4.y = input4.y * scale * static_cast(m[1]); + res4.z = input4.z * scale * static_cast(m[2]); + res4.w = input4.w * scale * static_cast(m[3]); + out4[i] = res4; +} + +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + __half *out, const __half *in, + const uint8_t *__restrict__ mask) { + const __half scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + float4 *out4 = reinterpret_cast(out); + const float4 *vals_float4 = reinterpret_cast(in); + const uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + uint64_t *m8 = reinterpret_cast(m); + m8[0] = mask8[i]; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + out4[i] = out_float4; +} + +template <> +void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, + int total_count, float ratio, cudaStream_t stream, + bool backward) { + int grid_dim = total_count >> 12; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +template <> +void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, + int total_count, float ratio, + cudaStream_t stream, bool backward) { + int grid_dim = total_count >> 13; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +/** + * @brief fused bias, dropout, and residual at the end of Attention and FFN, + * store dropped position in mask, it's not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param residual [batch_size, seq_len, hidden_size], float and __half + * @param seed seed to curand + * @param hidden_size hidden size + * @return void + */ +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const float *__restrict__ residual, + const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 output4; + + output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; + output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; + output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; + output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; + + out4[i] = output4; +} + +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const __half *__restrict__ residual, + const int seed, const int hidden_size) { + const __half scale = 1. / (1. - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = static_cast(rand.x > ratio); + m[5] = static_cast(rand.y > ratio); + m[6] = static_cast(rand.z > ratio); + m[7] = static_cast(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = m8[0]; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + const __half2 *res_half2 = reinterpret_cast(&res4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = + __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); + out_half2[1] = + __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); + out_half2[2] = + __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); + out_half2[3] = + __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_res_bias(float *out, const float *vals, + uint8_t *mask, const float *bias, + const float *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 12; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, + uint8_t *mask, const __half *bias, + const __half *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 13; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias and dropout backward at the end of Attention and FFN + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, float *__restrict__ in_grad, + float *__restrict__ bias_grad, const float *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + // every block generate 8 bias result + __shared__ float tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + float val = out_grad[idx]; + val *= scale * static_cast(mask[idx]); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + float sum = 0; + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, __half *__restrict__ in_grad, + __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); + __shared__ __half2 tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); + const __half2 *out_grad2 = reinterpret_cast(out_grad); + __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + __half2 local_sum = __float2half2_rn(0.f); + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + __half2 val = out_grad2[idx]; + __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); + val *= scale * m2; + local_sum += val; + in_grad2[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + __half2 sum = __float2half2_rn(0.f); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad2[pos] = tile[0][threadIdx.x]; + } +} + +template +void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template <> +void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, + const __half *out_grad, const uint8_t *mask, + int row_size, int dim, float ratio, + cudaStream_t stream) { + dim >>= 1; + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, + const float *out_grad, + const uint8_t *mask, int row_size, + int dim, float ratio, + cudaStream_t stream); + +/** + * @brief fused bias, activation, and dropout at the end of first ffn + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @tparam act_type activation function, like kRelu, kGelu + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param seed seed to curand + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 output4; + + output4.x = + activation_kernel(input4.x + b4.x) * scale * m[0]; + output4.y = + activation_kernel(input4.y + b4.y) * scale * m[1]; + output4.z = + activation_kernel(input4.z + b4.z) * scale * m[2]; + output4.w = + activation_kernel(input4.w + b4.w) * scale * m[3]; + + out4[i] = output4; +} + +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2( + activation_kernel(__hadd2(val_half2[0], b_half2[0])), + scale_mask_1); + out_half2[1] = __hmul2( + activation_kernel(__hadd2(val_half2[1], b_half2[1])), + scale_mask_2); + out_half2[2] = __hmul2( + activation_kernel(__hadd2(val_half2[2], b_half2[2])), + scale_mask_3); + out_half2[3] = __hmul2( + activation_kernel(__hadd2(val_half2[3], b_half2[3])), + scale_mask_4); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias, activation, and dropout backward + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @tparam act_type kRelu + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_bwd_kernel( + const int row_size, const float ratio, T *in_grad, + T *__restrict__ bias_grad, const T *__restrict__ input, + const T *__restrict__ bias, const T *out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + + int stride = hidden_size * WARP_SIZE; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + if (col_idx < hidden_size) { + for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { + float val = out_grad[idx]; + float in = input[idx]; + float b = bias[idx % hidden_size]; + val = activation_bwd_kernel( + val * scale * static_cast(mask[idx]), in + b); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + float sum = tile[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; + __syncthreads(); + + if (threadIdx.y == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +// @brief fused bias, activation, and dropout backward +// It is deprecated for precision reason. Keep it for future optimization. +// +// template +// __global__ void ls_dropout_act_bias_bwd_kernel( +// const int row_size, const float ratio, __half * in_grad, +// __half *__restrict__ bias_grad, const __half *__restrict__ input, const +// __half *__restrict__ bias, const __half * out_grad, const uint8_t +// *__restrict__ mask, const int hidden_size) { +// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); +// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; + +// cg::thread_block b = cg::this_thread_block(); +// cg::thread_block_tile g = cg::tiled_partition(b); + +// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); +// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); +// const __half2 *out_grad2 = reinterpret_cast(out_grad); +// const __half2 *input2 = reinterpret_cast(input); +// const __half2 *bias2 = reinterpret_cast(bias); + +// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + +// int stride = hidden_size * WARP_SIZE; +// __half2 local_sum = __float2half2_rn(0.f); + +// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); +// if (col_idx < hidden_size) { +// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { +// __half2 val = out_grad2[idx]; +// __half2 in2 = input2[idx]; +// __half2 b2 = bias2[idx % hidden_size ]; +// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); +// val = activation_bwd_kernel(val * scale +// * +// m2, +// in2+b2); +// local_sum += val; +// in_grad2[idx] = val; +// idx += stride; +// } +// } + +// tile[threadIdx.x][threadIdx.y] = local_sum; +// __syncthreads(); +// __half2 sum = tile[threadIdx.y][threadIdx.x]; +// __syncthreads(); + +// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + +// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; +// __syncthreads(); + +// if (threadIdx.y == 0) { +// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); +// bias_grad2[pos] = tile[0][threadIdx.x]; +// } +// } + +template +void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, + const T *bias, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + ls_dropout_act_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); +} + +// template <> +// void launch_ls_dropout_act_bias_bwd( +// __half *in_grad, __half *bias_grad,const __half *input, const __half +// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int +// dim, float ratio, cudaStream_t stream) { +// dim >>= 1; +// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); +// dim3 block_dim(WARP_SIZE, WARP_SIZE); +// ls_dropout_act_bias_bwd_kernel +// <<>>(row_size, ratio, in_grad, +// bias_grad, +// input, bias,out_grad, mask, dim); +// } + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu index bc90c54c0a004a3a847968a0d62c7ad8a999dcb3..625b02cd25d92f8dbd856cc8a09d72c38102f1af 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu @@ -1,232 +1,232 @@ -#include - -#include "kernels.h" - -namespace cg = cooperative_groups; - -/** -@brief: fuse_transpose_bias -Calculate the sum of elements in each column of the matrix. - -@thread -gridDim.x = ceil(cols / WARP_SIZE) -blockDim.x = WARP_SIZE -blockDim.y = WARP_SIZE - -@param -inp: [rows, cols] -out: [cols] -rows: the number of rows in the matrix -cols: the number of cols in the matrix -*/ -template -__global__ void column_sum_reduce(const T *__restrict__ inp, - T *__restrict__ out, int rows, int cols) { - __shared__ float tile[WARP_SIZE][WARP_SIZE]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - int y_stride = cols * WARP_SIZE; - float localSum = 0; - - // Loop across matrix row - // TODO: optimize to log complexity - if (idx < cols) { - int offset = flat_2dim(threadIdx.y, idx, cols); - for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { - localSum += (float)inp[offset]; - offset += y_stride; - } - } - - // The sum of a row in tile is equal to the sum of a col in original matrix - tile[threadIdx.x][threadIdx.y] = localSum; - - __syncthreads(); - - // Sum the shared buffer. - // The change of threadIdx.x is continuous - float sum = tile[threadIdx.y][threadIdx.x]; - - __syncthreads(); - - // Calculate the sum of a row in tile - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); - if (pos < cols) out[pos] = sum; - } -} - -// [r, c] -> [c] -template <> -void launch_fuse_transpose_bias_kernel(const float *inp, float *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce - <<>>(inp, out, rows, cols); -} - -template <> -void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce<__half> - <<>>(inp, out, rows, cols); -} - -/** -@brief: fused_add2 -Add two matrix inp1 and inp2 to out. - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -inp1: [batch_size, seq_len, hidden_dim] -inp2: [batch_size, seq_len, hidden_dim] -out: [batch_size, seq_len, hidden_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -*/ -template -__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, - int hidden_dim); - -template <> -__global__ void fused_add2_kernel(float *out, const float *inp1, - const float *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - val.x = vinp1.x + vinp2.x; - val.y = vinp1.y + vinp2.y; - val.z = vinp1.z + vinp2.z; - val.w = vinp1.w + vinp2.w; - out_4[offset + i] = val; - } -} - -template <> -__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, - const __half *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); - __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); - __half2 *h2_val = reinterpret_cast<__half2 *>(&val); - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); - h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); - h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); - h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); - out_4[offset + i] = val; - } -} - -//[b, s, h] -> [b, s, h] -template <> -void launch_fused_add2(float *out, const float *inp1, const float *inp2, - int batch_size, int seq_len, int hidden_dim, - cudaStream_t &stream) { - hidden_dim >>= 2; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template <> -void launch_fused_add2<__half>(__half *out, const __half *inp1, - const __half *inp2, int batch_size, int seq_len, - int hidden_dim, cudaStream_t &stream) { - hidden_dim >>= 3; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template -__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, - int sz0, int sz2, int sz1_1, int sz1_2) { - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); - if (idx >= nele) { - return; - } - float4 *dst_ptr = (float4 *)output + idx; - int idx2 = idx % sz2; - idx = idx / sz2; - int idx1 = idx % (sz1_1 + sz1_2); - int idx0 = idx / (sz1_1 + sz1_2); - float4 *src_ptr = nullptr; - int sz1 = 0; - if (idx1 < sz1_1) { - sz1 = sz1_1; - src_ptr = (float4 *)inp1; - } else { - idx1 -= sz1_1; - sz1 = sz1_2; - src_ptr = (float4 *)inp2; - } - src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); - dst_ptr[0] = src_ptr[0]; -} - -template <> -void launch_concat3_dim1(const float *inp1, const float *inp2, - float *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 2; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} - -template <> -void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, - __half *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 3; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} +#include + +#include "kernels.h" + +namespace cg = cooperative_groups; + +/** +@brief: fuse_transpose_bias +Calculate the sum of elements in each column of the matrix. + +@thread +gridDim.x = ceil(cols / WARP_SIZE) +blockDim.x = WARP_SIZE +blockDim.y = WARP_SIZE + +@param +inp: [rows, cols] +out: [cols] +rows: the number of rows in the matrix +cols: the number of cols in the matrix +*/ +template +__global__ void column_sum_reduce(const T *__restrict__ inp, + T *__restrict__ out, int rows, int cols) { + __shared__ float tile[WARP_SIZE][WARP_SIZE]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + int y_stride = cols * WARP_SIZE; + float localSum = 0; + + // Loop across matrix row + // TODO: optimize to log complexity + if (idx < cols) { + int offset = flat_2dim(threadIdx.y, idx, cols); + for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { + localSum += (float)inp[offset]; + offset += y_stride; + } + } + + // The sum of a row in tile is equal to the sum of a col in original matrix + tile[threadIdx.x][threadIdx.y] = localSum; + + __syncthreads(); + + // Sum the shared buffer. + // The change of threadIdx.x is continuous + float sum = tile[threadIdx.y][threadIdx.x]; + + __syncthreads(); + + // Calculate the sum of a row in tile + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); + if (pos < cols) out[pos] = sum; + } +} + +// [r, c] -> [c] +template <> +void launch_fuse_transpose_bias_kernel(const float *inp, float *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce + <<>>(inp, out, rows, cols); +} + +template <> +void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce<__half> + <<>>(inp, out, rows, cols); +} + +/** +@brief: fused_add2 +Add two matrix inp1 and inp2 to out. + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +inp1: [batch_size, seq_len, hidden_dim] +inp2: [batch_size, seq_len, hidden_dim] +out: [batch_size, seq_len, hidden_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +*/ +template +__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, + int hidden_dim); + +template <> +__global__ void fused_add2_kernel(float *out, const float *inp1, + const float *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + val.x = vinp1.x + vinp2.x; + val.y = vinp1.y + vinp2.y; + val.z = vinp1.z + vinp2.z; + val.w = vinp1.w + vinp2.w; + out_4[offset + i] = val; + } +} + +template <> +__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, + const __half *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); + __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); + __half2 *h2_val = reinterpret_cast<__half2 *>(&val); + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); + h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); + h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); + h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); + out_4[offset + i] = val; + } +} + +//[b, s, h] -> [b, s, h] +template <> +void launch_fused_add2(float *out, const float *inp1, const float *inp2, + int batch_size, int seq_len, int hidden_dim, + cudaStream_t &stream) { + hidden_dim >>= 2; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template <> +void launch_fused_add2<__half>(__half *out, const __half *inp1, + const __half *inp2, int batch_size, int seq_len, + int hidden_dim, cudaStream_t &stream) { + hidden_dim >>= 3; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template +__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, + int sz0, int sz2, int sz1_1, int sz1_2) { + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); + if (idx >= nele) { + return; + } + float4 *dst_ptr = (float4 *)output + idx; + int idx2 = idx % sz2; + idx = idx / sz2; + int idx1 = idx % (sz1_1 + sz1_2); + int idx0 = idx / (sz1_1 + sz1_2); + float4 *src_ptr = nullptr; + int sz1 = 0; + if (idx1 < sz1_1) { + sz1 = sz1_1; + src_ptr = (float4 *)inp1; + } else { + idx1 -= sz1_1; + sz1 = sz1_2; + src_ptr = (float4 *)inp2; + } + src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); + dst_ptr[0] = src_ptr[0]; +} + +template <> +void launch_concat3_dim1(const float *inp1, const float *inp2, + float *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 2; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} + +template <> +void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, + __half *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 3; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h index 563a7fe284a305adb65965c365cc600878e7085c..025fbf3f8f15cef0ecb72e827bdade248a58d72f 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h @@ -1,96 +1,96 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -template -class Dropout { - public: - struct Config { - float ratio; - bool training; - - Config(float r) : ratio(r), training(true) {} - float RATIO() const { return training ? ratio : 0.0; } - }; - - Dropout(const Config &config, size_t max_ele_num) - : _config(config), _mask(nullptr) { - _mask = cuda_malloc(max_ele_num); - } - - virtual ~Dropout() { cuda_free(_mask); } - - // after attention softmax - void dropout(T *output, const T *input, int count, cudaStream_t stream, - bool bwd = false) { - launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, - bwd); - } - - void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { - launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), - stream, true); - } - - // transformer layer's postprocessing dropout, after attn or ffn module, - // before residual add. - void bias_dropout_residual(T *output, const T *input, const T *residual, - const T *bias, int rows, int cols, - cudaStream_t stream) { - launch_ls_dropout_res_bias(output, input, _mask, bias, residual, - rows * cols, cols, _config.RATIO(), stream); - } - - void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, - int rows, int cols, cudaStream_t stream) { - launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, - _config.RATIO(), stream); - } - - // dropout inside ffn. - void bias_act_dropout(T *output, const T *input, const T *bias, int rows, - int cols, std::string activation_fn, - cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, - const T *bias, int rows, int cols, - std::string activation_fn, cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - bool HasDropout() const { return _config.RATIO() > 0.0; } - - void SetTrainingMode(bool training) { _config.training = training; } - - private: - uint8_t *_mask; - Config _config; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +template +class Dropout { + public: + struct Config { + float ratio; + bool training; + + Config(float r) : ratio(r), training(true) {} + float RATIO() const { return training ? ratio : 0.0; } + }; + + Dropout(const Config &config, size_t max_ele_num) + : _config(config), _mask(nullptr) { + _mask = cuda_malloc(max_ele_num); + } + + virtual ~Dropout() { cuda_free(_mask); } + + // after attention softmax + void dropout(T *output, const T *input, int count, cudaStream_t stream, + bool bwd = false) { + launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, + bwd); + } + + void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { + launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), + stream, true); + } + + // transformer layer's postprocessing dropout, after attn or ffn module, + // before residual add. + void bias_dropout_residual(T *output, const T *input, const T *residual, + const T *bias, int rows, int cols, + cudaStream_t stream) { + launch_ls_dropout_res_bias(output, input, _mask, bias, residual, + rows * cols, cols, _config.RATIO(), stream); + } + + void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, + int rows, int cols, cudaStream_t stream) { + launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, + _config.RATIO(), stream); + } + + // dropout inside ffn. + void bias_act_dropout(T *output, const T *input, const T *bias, int rows, + int cols, std::string activation_fn, + cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, + const T *bias, int rows, int cols, + std::string activation_fn, cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + bool HasDropout() const { return _config.RATIO() > 0.0; } + + void SetTrainingMode(bool training) { _config.training = training; } + + private: + uint8_t *_mask; + Config _config; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h index fbb9c5465c24e98165f72b212854f51211e4e8a1..735e1363cc463044affef23f5698e77a89f540e2 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h @@ -3,10 +3,11 @@ #include #include #include -#include #include #include +#include + #define MAX_THREADS 1024 #define WARP_SIZE 32 @@ -132,8 +133,9 @@ __forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, } /* Convert 4-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int -flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) { +__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3, + int id4, int dim2, int dim3, + int dim4) { // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; int res = id4; @@ -201,9 +203,9 @@ __forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, } /* Convert vector index to 6-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, - int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) { +__forceinline__ __host__ __device__ void decompose_6dim( + int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0, + int *id1, int *id2, int *id3, int *id4, int *id5) { *id5 = src % dim5; src /= dim5; @@ -221,9 +223,11 @@ decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, } /* Convert vector index to 5-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0, - int *id1, int *id2, int *id3, int *id4) { +__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1, + int dim2, int dim3, + int dim4, int *id0, + int *id1, int *id2, + int *id3, int *id4) { *id4 = src % dim4; src /= dim4; @@ -253,8 +257,9 @@ __forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, } /* Convert vector index to 3-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) { +__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, + int dim2, int *id0, + int *id1, int *id2) { *id2 = src % dim2; src /= dim2; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h index ded5c0fdcbeee6f02d29544a577021f2407976d7..a7767e187ffc4951d0e737ec73f58292ca82351c 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h @@ -1,64 +1,65 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template class Normalize_Layer { -public: - struct Config { - uint32_t hidden_dim; - bool use_mean; - Config(uint32_t hidden_dim, bool use_mean = false) - : hidden_dim(hidden_dim), use_mean(use_mean) {} - }; - - Normalize_Layer(Config config, size_t max_rows) - : config_(config), vars_(nullptr), means_(nullptr) { - vars_ = cuda_malloc(max_rows); - if (config_.use_mean) { - means_ = cuda_malloc(max_rows); - } - } - - ~Normalize_Layer() { - cuda_free(vars_); - cuda_free(means_); - } - - void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, - int batch_size, cudaStream_t stream) { - launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, - config_.hidden_dim, stream); - } - - /* - residual_grad, inp_or_out, betta should be treated carefully. - inp_or_out = input if use_mean else output - residual_grad, betta can be nullptr. - residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln - betta are only used to compute xhat, - (use_mean == false) ^ (betta == nullptr) should be true - */ - void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, int batch_size, cudaStream_t stream[2]) { - launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, - inp_or_out, gamma, betta, vars_, means_, batch_size, - config_.hidden_dim, stream); - } - - inline bool use_mean() const { return config_.use_mean; } - -private: - Config config_; - T *vars_; - T *means_; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template +class Normalize_Layer { + public: + struct Config { + uint32_t hidden_dim; + bool use_mean; + Config(uint32_t hidden_dim, bool use_mean = false) + : hidden_dim(hidden_dim), use_mean(use_mean) {} + }; + + Normalize_Layer(Config config, size_t max_rows) + : config_(config), vars_(nullptr), means_(nullptr) { + vars_ = cuda_malloc(max_rows); + if (config_.use_mean) { + means_ = cuda_malloc(max_rows); + } + } + + ~Normalize_Layer() { + cuda_free(vars_); + cuda_free(means_); + } + + void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, + int batch_size, cudaStream_t stream) { + launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, + config_.hidden_dim, stream); + } + + /* + residual_grad, inp_or_out, betta should be treated carefully. + inp_or_out = input if use_mean else output + residual_grad, betta can be nullptr. + residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln + betta are only used to compute xhat, + (use_mean == false) ^ (betta == nullptr) should be true + */ + void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, const T *gamma, + const T *betta, int batch_size, cudaStream_t stream[2]) { + launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, + inp_or_out, gamma, betta, vars_, means_, batch_size, + config_.hidden_dim, stream); + } + + inline bool use_mean() const { return config_.use_mean; } + + private: + Config config_; + T *vars_; + T *means_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h index ec447ad84c54614839ef14838e23898191e97129..b917abaf0336a8399ce5900da03c94fb80eb54b5 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h @@ -1,42 +1,42 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Softmax { - public: - struct Config { - size_t nhead; - Config(size_t nhead) : nhead(nhead) {} - }; - - Softmax(Config config) : config_(config) {} - - ~Softmax() {} - - void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, - int to_len, cudaStream_t &stream, bool mask_future = true) { - launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, - to_len, mask_future, stream); - } - - void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, - int to_len, cudaStream_t stream) { - launch_attn_softmax_bw(out_grad, soft_out, - batch_size * config_.nhead * from_len, to_len, - stream); - } - - void reset_size(size_t nhead) { config_.nhead = nhead; } - - private: - Config config_; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template +class Softmax { + public: + struct Config { + size_t nhead; + Config(size_t nhead) : nhead(nhead) {} + }; + + Softmax(Config config) : config_(config) {} + + ~Softmax() {} + + void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, + int to_len, cudaStream_t &stream, bool mask_future = true) { + launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, + to_len, mask_future, stream); + } + + void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, + int to_len, cudaStream_t stream) { + launch_attn_softmax_bw(out_grad, soft_out, + batch_size * config_.nhead * from_len, to_len, + stream); + } + + void reset_size(size_t nhead) { config_.nhead = nhead; } + + private: + Config config_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu index 3e61d4e35832cb9b2f90ebc9b64d8398f6d6b9e6..e2f1869b165e783e36ac53a4c2e0d7672cabc01d 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu @@ -1,1169 +1,1172 @@ -#include "block_reduce.h" -#include "kernels.h" -#include - -namespace cg = cooperative_groups; -const float LN_EPSILON = 1e-8f; -#define TILE_DIM 32 - -template __forceinline__ __device__ T add_eps(T x) { - return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); -} - -/** -@brief: ker_layer_norm -Standard layer normalization. -It will not only output the layer norm result, - but also outputs variance. - may also output means, depends on whether - the means argument is nullptr - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -ln_res: [batch_size* seq_len, hidden_size], ln result. -vars: [batch_size* seq_len], variance per token -means: [batch_size* seq_len], means per token, can be nullput -inp: [batch_size * seq_len, hidden_size], ln input. -scale: [hidden_size], ln scale -bias: [hidden_size], ln bias -*/ -template -__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val = inp_f4[idx]; - l_sum += val.x + val.y + val.z + val.w; - l_square_sum += - val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 4.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 vscale = __ldg((const float4 *)scale + idx); - float4 vbias = __ldg((const float4 *)bias + idx); - float4 val = inp_f4[idx]; - val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; - val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; - val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; - val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; - output_f4[idx] = val; - } -} - -template <> -__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, - __half *means, const __half *inp, - const __half *scale, const __half *bias, - int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 val_f2 = __half22float2(val_h2[i]); - l_sum += val_f2.x + val_f2.y; - l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; - } - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 8.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - // load scale, bias, input - float4 scale_f4 = __ldg((const float4 *)scale + idx); - __half2 *scale_h2 = (__half2 *)(&scale_f4); - float4 bias_f4 = __ldg((const float4 *)bias + idx); - __half2 *bias_h2 = (__half2 *)(&bias_f4); - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); - -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 scale_f2 = __half22float2(scale_h2[i]); - float2 bias_f2 = __half22float2(bias_h2[i]); - float2 val_f2 = __half22float2(val_h2[i]); - val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; - val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; - val_h2[i] = __float22half2_rn(val_f2); - } - output_f4[idx] = val_f4; - } -} - -// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; -// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x -// * val_f2_1.x + val_f2_1.y * val_f2_1.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 2; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_h2[i] = __float22half2_rn(val_f2); -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// } -// } - -// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// float4 val_f4_2 = inp_f4[idx+2]; -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + -// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * -// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x -// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + -// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + -// val_f2_3.y * val_f2_3.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 4; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); -// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); -// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); -// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); -// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); -// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); -// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// float4 val_f4_2 = inp_f4[idx+2]; -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); -// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); -// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * -// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var -// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * -// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) -// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = -// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); -// val_h2_2[i] = __float22half2_rn(val_f2_2); -// val_h2_3[i] = __float22half2_rn(val_f2_3); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// output_f4[idx+2] = val_f4_2; -// output_f4[idx+3] = val_f4_3; -// } -// } - -template <> -void launch_layer_norm(float *ln_res, float *vars, float *means, - const float *inp, const float *scale, - const float *bias, int batch_size, int hidden_dim, - cudaStream_t stream) { - if (hidden_dim % 4 != 0) { - throw std::runtime_error("violate hidden_dim % 4 = 0"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); -} - -template <> -void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, - const __half *inp, const __half *scale, - const __half *bias, int batch_size, - int hidden_dim, cudaStream_t stream) { - if (hidden_dim % 8 != 0) { - throw std::runtime_error("violate hidden_dim % 8 = 0"); - } - hidden_dim >>= 3; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<__half><<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); - // if (hidden_dim % 8 != 0) { - // throw std::runtime_error("violate hidden_dim % 8 = 0"); - // } - // hidden_dim >>= 3; - - // if (hidden_dim * 8 < 8192) { - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm<__half><<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { - // hidden_dim >>= 1; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x2<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { - // hidden_dim >>= 2; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x4<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else { - // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - // } -} - -/** -@brief: ker_ln_bw_dgamma_dbetta -Layer norm backword kernel, compute the gradient of gamma and betta. -dbetta = sum(dout, dim=0) -dgamma = sum(xhat * dout, dim=0) -xhat = (input - mean) * rsqrt(var) or - (output - betta) / gamma - - -@thread -gridDim.x = hidden_size / 32 -blockDim.x = 32 -blockDim.y = 32 - -@param -gamma_grad: [hidden_size], gradient of gamma -betta_grad: [hidden_size], gradient of betta -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat, maybe nullptr -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat, maybe nullptr -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -(gamma && betta) ^ (vars && means) should be true -*/ -template -__global__ void -ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, const T *out_grad, - const T *inp_or_out, const T *gamma, const T *betta, - const T *vars, const T *means, int rows, int width) { - __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - // Loop across inp height - float dbetta = 0; - float dgamma = 0; - float dout, val; - if (idx < width) { - if (means == nullptr) { - float vbetta = (float)betta[idx]; - float vgamma = (float)gamma[idx]; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is output - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - vbetta) / add_eps(vgamma) * dout); - offset += y_stride; - } - } else { - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is input - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - (float)means[r]) * - rsqrtf((float)vars[r] + LN_EPSILON) * dout); - offset += y_stride; - } - } - } - - // Sum the shared buffer. - betta_buffer[threadIdx.x][threadIdx.y] = dbetta; - gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; - __syncthreads(); - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (threadIdx.x == 0 && idx < width) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -/** -@brief: ker_ln_bw_dinp -Layer norm backword kernel, compute the gradient of input. -dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) - * rsqrt(var) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dxhat = dout * gamma - - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, - usually appear in pre-layer-norm for transformer layer, maybe nullptr -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat and dxhat -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat and dinp -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -*/ -template -__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, - const T *gamma, const T *betta, const T *vars, - const T *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - float4 dxhat, xhat; - float var_rsqrt; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - dxhat = ((const float4 *)out_grad)[offset]; - float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; - dxhat.x *= vgamma.x; - dxhat.y *= vgamma.y; - dxhat.z *= vgamma.z; - dxhat.w *= vgamma.w; - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - xhat = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); - xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); - xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); - xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; - xhat.x = (xhat.x - fmean) * var_rsqrt; - xhat.y = (xhat.y - fmean) * var_rsqrt; - xhat.z = (xhat.z - fmean) * var_rsqrt; - xhat.w = (xhat.w - fmean) * var_rsqrt; - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - float reduce_val[2] = {0.f, 0.f}; - if (threadIdx.x < hidden_dim) { - reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; - reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + - dxhat.w * xhat.w; - } - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - dxhat.x += dresidual.x; - dxhat.y += dresidual.y; - dxhat.z += dresidual.z; - dxhat.w += dresidual.w; - } - ((float4 *)inp_grad)[offset] = dxhat; -} - -template <> -__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, - int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - - float2 dxhat[4], xhat[4]; - float var_rsqrt; - float4 vtmp; - __half2 *tmp_h2; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vbetta = __half22float2(betta_h2[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; -} - -__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float var_rsqrt; - float4 vtmp, vtmp_1; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 2; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; -} - -__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float2 dxhat_2[4], xhat_2[4]; - float2 dxhat_3[4], xhat_3[4]; - float var_rsqrt; - float4 vtmp, vtmp_1, vtmp_2, vtmp_3; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - __half2 *tmp_h2_2; - __half2 *tmp_h2_3; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - vtmp_2 = ((const float4 *)out_grad)[offset + 2]; - vtmp_3 = ((const float4 *)out_grad)[offset + 3]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); - tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; - float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; - float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); - __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); - __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vdout_2 = __half22float2(tmp_h2_2[i]); - float2 vdout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - dxhat_2[i].x = vdout_2.x * vgamma_2.x; - dxhat_2[i].y = vdout_2.y * vgamma_2.y; - dxhat_3[i].x = vdout_3.x * vgamma_3.x; - dxhat_3[i].y = vdout_3.y * vgamma_3.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + - dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + - dxhat_3[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; - vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; - float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; - float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); - __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); - __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vout_2 = __half22float2(tmp_h2_2[i]); - float2 vout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - float2 vbetta_2 = __half22float2(betta_h2_2[i]); - float2 vbetta_3 = __half22float2(betta_h2_3[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); - xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); - xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - float2 vinp_2 = __half22float2(tmp_h2_2[i]); - float2 vinp_3 = __half22float2(tmp_h2_3[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; - xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; - xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; - float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); - __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); - __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_2[2 * i])); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_3[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; - ((float4 *)inp_grad)[offset + 2] = vtmp_2; - ((float4 *)inp_grad)[offset + 3] = vtmp_3; -} - -/** -Layer norm backword, - compute the gradient of gamma, betta and input. -dbetta = sum(dout, dim=0) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dgamma = sum(xhat * dout, dim=0) -dxhat = dout * gamma -dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) - * rsqrt(var) - -residual_grad, means, betta can be nullptr. -residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln -means and betta are only used to compute xhat, - (means == nullptr) ^ (betta == nullptr) should be true -*/ -template <> -void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, - const float *out_grad, const float *residual_grad, - const float *inp_or_out, const float *gamma, - const float *betta, const float *vars, - const float *means, int batch, int hidden_dim, - cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 4 != 0 || hidden_dim > 4096) { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, - hidden_dim); -} - -template <> -void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, - __half *inp_grad, const __half *out_grad, - const __half *residual_grad, const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, int batch, - int hidden_dim, cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<__half><<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 8 != 0) { - throw std::runtime_error("hidden_dim % 8 != 0"); - } - hidden_dim >>= 3; - - if (hidden_dim * 8 <= 8192) { - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { - hidden_dim >>= 1; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x2<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x4<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - } -} +#include + +#include "block_reduce.h" +#include "kernels.h" + +namespace cg = cooperative_groups; +const float LN_EPSILON = 1e-8f; +#define TILE_DIM 32 + +template +__forceinline__ __device__ T add_eps(T x) { + return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); +} + +/** +@brief: ker_layer_norm +Standard layer normalization. +It will not only output the layer norm result, + but also outputs variance. + may also output means, depends on whether + the means argument is nullptr + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +ln_res: [batch_size* seq_len, hidden_size], ln result. +vars: [batch_size* seq_len], variance per token +means: [batch_size* seq_len], means per token, can be nullput +inp: [batch_size * seq_len, hidden_size], ln input. +scale: [hidden_size], ln scale +bias: [hidden_size], ln bias +*/ +template +__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, + const T *scale, const T *bias, int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val = inp_f4[idx]; + l_sum += val.x + val.y + val.z + val.w; + l_square_sum += + val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 4.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 vscale = __ldg((const float4 *)scale + idx); + float4 vbias = __ldg((const float4 *)bias + idx); + float4 val = inp_f4[idx]; + val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; + val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; + val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; + val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; + output_f4[idx] = val; + } +} + +template <> +__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, + __half *means, const __half *inp, + const __half *scale, const __half *bias, + int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 val_f2 = __half22float2(val_h2[i]); + l_sum += val_f2.x + val_f2.y; + l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; + } + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 8.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + // load scale, bias, input + float4 scale_f4 = __ldg((const float4 *)scale + idx); + __half2 *scale_h2 = (__half2 *)(&scale_f4); + float4 bias_f4 = __ldg((const float4 *)bias + idx); + __half2 *bias_h2 = (__half2 *)(&bias_f4); + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); + +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 scale_f2 = __half22float2(scale_h2[i]); + float2 bias_f2 = __half22float2(bias_h2[i]); + float2 val_f2 = __half22float2(val_h2[i]); + val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; + val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; + val_h2[i] = __float22half2_rn(val_f2); + } + output_f4[idx] = val_f4; + } +} + +// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; +// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x +// * val_f2_1.x + val_f2_1.y * val_f2_1.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 2; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_h2[i] = __float22half2_rn(val_f2); +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// } +// } + +// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// float4 val_f4_2 = inp_f4[idx+2]; +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + +// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * +// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x +// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + +// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + +// val_f2_3.y * val_f2_3.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 4; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); +// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); +// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); +// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); +// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); +// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); +// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// float4 val_f4_2 = inp_f4[idx+2]; +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); +// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); +// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * +// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var +// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * +// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) +// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = +// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); +// val_h2_2[i] = __float22half2_rn(val_f2_2); +// val_h2_3[i] = __float22half2_rn(val_f2_3); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// output_f4[idx+2] = val_f4_2; +// output_f4[idx+3] = val_f4_3; +// } +// } + +template <> +void launch_layer_norm(float *ln_res, float *vars, float *means, + const float *inp, const float *scale, + const float *bias, int batch_size, int hidden_dim, + cudaStream_t stream) { + if (hidden_dim % 4 != 0) { + throw std::runtime_error("violate hidden_dim % 4 = 0"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); +} + +template <> +void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, + const __half *inp, const __half *scale, + const __half *bias, int batch_size, + int hidden_dim, cudaStream_t stream) { + if (hidden_dim % 8 != 0) { + throw std::runtime_error("violate hidden_dim % 8 = 0"); + } + hidden_dim >>= 3; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<__half><<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); + // if (hidden_dim % 8 != 0) { + // throw std::runtime_error("violate hidden_dim % 8 = 0"); + // } + // hidden_dim >>= 3; + + // if (hidden_dim * 8 < 8192) { + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm<__half><<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { + // hidden_dim >>= 1; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x2<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { + // hidden_dim >>= 2; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x4<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else { + // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + // } +} + +/** +@brief: ker_ln_bw_dgamma_dbetta +Layer norm backword kernel, compute the gradient of gamma and betta. +dbetta = sum(dout, dim=0) +dgamma = sum(xhat * dout, dim=0) +xhat = (input - mean) * rsqrt(var) or + (output - betta) / gamma + + +@thread +gridDim.x = hidden_size / 32 +blockDim.x = 32 +blockDim.y = 32 + +@param +gamma_grad: [hidden_size], gradient of gamma +betta_grad: [hidden_size], gradient of betta +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat, maybe nullptr +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat, maybe nullptr +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +(gamma && betta) ^ (vars && means) should be true +*/ +template +__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, + const T *out_grad, const T *inp_or_out, + const T *gamma, const T *betta, + const T *vars, const T *means, int rows, + int width) { + __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + // Loop across inp height + float dbetta = 0; + float dgamma = 0; + float dout, val; + if (idx < width) { + if (means == nullptr) { + float vbetta = (float)betta[idx]; + float vgamma = (float)gamma[idx]; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is output + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - vbetta) / add_eps(vgamma) * dout); + offset += y_stride; + } + } else { + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is input + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - (float)means[r]) * + rsqrtf((float)vars[r] + LN_EPSILON) * dout); + offset += y_stride; + } + } + } + + // Sum the shared buffer. + betta_buffer[threadIdx.x][threadIdx.y] = dbetta; + gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; + __syncthreads(); + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + if (threadIdx.x == 0 && idx < width) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +/** +@brief: ker_ln_bw_dinp +Layer norm backword kernel, compute the gradient of input. +dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) + * rsqrt(var) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dxhat = dout * gamma + + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, + usually appear in pre-layer-norm for transformer layer, maybe nullptr +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat and dxhat +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat and dinp +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +*/ +template +__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, + const T *gamma, const T *betta, const T *vars, + const T *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + float4 dxhat, xhat; + float var_rsqrt; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + dxhat = ((const float4 *)out_grad)[offset]; + float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; + dxhat.x *= vgamma.x; + dxhat.y *= vgamma.y; + dxhat.z *= vgamma.z; + dxhat.w *= vgamma.w; + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + xhat = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); + xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); + xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); + xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; + xhat.x = (xhat.x - fmean) * var_rsqrt; + xhat.y = (xhat.y - fmean) * var_rsqrt; + xhat.z = (xhat.z - fmean) * var_rsqrt; + xhat.w = (xhat.w - fmean) * var_rsqrt; + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + float reduce_val[2] = {0.f, 0.f}; + if (threadIdx.x < hidden_dim) { + reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; + reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + + dxhat.w * xhat.w; + } + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + dxhat.x += dresidual.x; + dxhat.y += dresidual.y; + dxhat.z += dresidual.z; + dxhat.w += dresidual.w; + } + ((float4 *)inp_grad)[offset] = dxhat; +} + +template <> +__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, + int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + + float2 dxhat[4], xhat[4]; + float var_rsqrt; + float4 vtmp; + __half2 *tmp_h2; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vbetta = __half22float2(betta_h2[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; +} + +__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float var_rsqrt; + float4 vtmp, vtmp_1; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 2; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; +} + +__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float2 dxhat_2[4], xhat_2[4]; + float2 dxhat_3[4], xhat_3[4]; + float var_rsqrt; + float4 vtmp, vtmp_1, vtmp_2, vtmp_3; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + __half2 *tmp_h2_2; + __half2 *tmp_h2_3; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + vtmp_2 = ((const float4 *)out_grad)[offset + 2]; + vtmp_3 = ((const float4 *)out_grad)[offset + 3]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); + tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; + float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; + float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); + __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); + __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vdout_2 = __half22float2(tmp_h2_2[i]); + float2 vdout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + dxhat_2[i].x = vdout_2.x * vgamma_2.x; + dxhat_2[i].y = vdout_2.y * vgamma_2.y; + dxhat_3[i].x = vdout_3.x * vgamma_3.x; + dxhat_3[i].y = vdout_3.y * vgamma_3.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + + dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + + dxhat_3[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; + vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; + float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; + float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); + __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); + __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vout_2 = __half22float2(tmp_h2_2[i]); + float2 vout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + float2 vbetta_2 = __half22float2(betta_h2_2[i]); + float2 vbetta_3 = __half22float2(betta_h2_3[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); + xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); + xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + float2 vinp_2 = __half22float2(tmp_h2_2[i]); + float2 vinp_3 = __half22float2(tmp_h2_3[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; + xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; + xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; + float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); + __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); + __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_2[2 * i])); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_3[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; + ((float4 *)inp_grad)[offset + 2] = vtmp_2; + ((float4 *)inp_grad)[offset + 3] = vtmp_3; +} + +/** +Layer norm backword, + compute the gradient of gamma, betta and input. +dbetta = sum(dout, dim=0) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dgamma = sum(xhat * dout, dim=0) +dxhat = dout * gamma +dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) + * rsqrt(var) + +residual_grad, means, betta can be nullptr. +residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln +means and betta are only used to compute xhat, + (means == nullptr) ^ (betta == nullptr) should be true +*/ +template <> +void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, + const float *out_grad, const float *residual_grad, + const float *inp_or_out, const float *gamma, + const float *betta, const float *vars, + const float *means, int batch, int hidden_dim, + cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 4 != 0 || hidden_dim > 4096) { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, + hidden_dim); +} + +template <> +void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, + __half *inp_grad, const __half *out_grad, + const __half *residual_grad, const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, int batch, + int hidden_dim, cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<__half><<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 8 != 0) { + throw std::runtime_error("hidden_dim % 8 != 0"); + } + hidden_dim >>= 3; + + if (hidden_dim * 8 <= 8192) { + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { + hidden_dim >>= 1; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x2<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x4<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + } +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu index 98af433fe3972db64f8599bf4d23597365b80f4b..3862a699d3c31758975be55146c1db648dd6368f 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu @@ -1,365 +1,365 @@ -#include -#include - -#include -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float EPSILON = 1e-8f; - -/** -@brief: softmax_kernel -Softmax forward kernel for - enc-self-attn, dec-self-attn, encdec-attn - -@thread -gridDim.x = dynamic -gridDim.y = batch_size -gridDim.z = nhead -blockDim.x = from_len - -@param -inp: [batch_size, nhead, from_len, to_len], softmax input. -attn_mask: [batch_size, to_len], padding tokens are -inf, - non padding tokens are 0. - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template -__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // block reduce max - blockReduce(l_max); - // write shared - __shared__ float s_max[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_max[i] = l_max[i]; - } - } - __syncthreads(); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - s_max[i]); - l_sum[i] += val[i][j]; - } - } - // block reduce sum - blockReduce(l_sum); - // write shared - __shared__ float s_sum[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - } - } - __syncthreads(); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * s_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -template -__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // warp reduce max - warpReduce(l_max); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - l_max[i]); - l_sum[i] += val[i][j]; - } - } - // warp reduce sum - warpReduce(l_sum); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * l_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -/* - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template <> -void launch_attn_softmax(float *inp, const float *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 16; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 32; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 64; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -template <> -void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<__half, 32, 1><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<__half, 32, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 8; - ker_attn_softmax<__half, 64, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 16; - ker_attn_softmax<__half, 128, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 32; - ker_attn_softmax<__half, 256, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -/** -@brief: ker_attn_softmax_bw -Softmax backward in self attention. - -@thread -gridDim.x = batch_size * nhead * seq_len / warps_per_block -blockDim.x = WARP_SIZE -blockDim.y = warps_per_block - -@param -grad: [batch_size, nhead, seq_len, seq_len], output grad. -output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. -*/ -template -__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { - int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; - int offset = batch_idx * softmax_length + threadIdx.x; - - grad += offset; - inp += offset; - - T grad_reg[ITERATIONS]; - T inp_reg[ITERATIONS]; - float sum = 0.0; - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) { - grad_reg[i] = grad[i * WARP_SIZE]; - inp_reg[i] = inp[i * WARP_SIZE]; - sum += (float)grad_reg[i] * (float)inp_reg[i]; - } - } - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) - grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); - } -} - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream) { - const int warps_per_block = 4; - // rows = batch_size * nhead * from_len - dim3 grid_dim(rows / warps_per_block); - dim3 block_dim(WARP_SIZE, warps_per_block); - - if (softmax_len <= 32) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 64) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 128) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 256) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 384) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 512) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 768) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 1024) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 2048) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else - throw std::runtime_error( - std::string( - "Special sequence length found in softmax backward, seq_len: ") + - std::to_string(softmax_len)); -} - -template void launch_attn_softmax_bw<__half>(__half *out_grad, - const __half *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); -template void launch_attn_softmax_bw(float *out_grad, - const float *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); +#include +#include + +#include +#include + +#include "block_reduce.h" +#include "kernels.h" + +namespace cg = cooperative_groups; +const float EPSILON = 1e-8f; + +/** +@brief: softmax_kernel +Softmax forward kernel for + enc-self-attn, dec-self-attn, encdec-attn + +@thread +gridDim.x = dynamic +gridDim.y = batch_size +gridDim.z = nhead +blockDim.x = from_len + +@param +inp: [batch_size, nhead, from_len, to_len], softmax input. +attn_mask: [batch_size, to_len], padding tokens are -inf, + non padding tokens are 0. + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template +__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // block reduce max + blockReduce(l_max); + // write shared + __shared__ float s_max[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_max[i] = l_max[i]; + } + } + __syncthreads(); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - s_max[i]); + l_sum[i] += val[i][j]; + } + } + // block reduce sum + blockReduce(l_sum); + // write shared + __shared__ float s_sum[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + } + } + __syncthreads(); + + /* step 3. compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * s_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +template +__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // warp reduce max + warpReduce(l_max); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - l_max[i]); + l_sum[i] += val[i][j]; + } + } + // warp reduce sum + warpReduce(l_sum); + + /* step 3. compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * l_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +/* + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template <> +void launch_attn_softmax(float *inp, const float *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 16; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 32; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 64; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +template <> +void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<__half, 32, 1><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<__half, 32, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 8; + ker_attn_softmax<__half, 64, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 16; + ker_attn_softmax<__half, 128, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 32; + ker_attn_softmax<__half, 256, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +/** +@brief: ker_attn_softmax_bw +Softmax backward in self attention. + +@thread +gridDim.x = batch_size * nhead * seq_len / warps_per_block +blockDim.x = WARP_SIZE +blockDim.y = warps_per_block + +@param +grad: [batch_size, nhead, seq_len, seq_len], output grad. +output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. +*/ +template +__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + grad += offset; + inp += offset; + + T grad_reg[ITERATIONS]; + T inp_reg[ITERATIONS]; + float sum = 0.0; + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) { + grad_reg[i] = grad[i * WARP_SIZE]; + inp_reg[i] = inp[i * WARP_SIZE]; + sum += (float)grad_reg[i] * (float)inp_reg[i]; + } + } + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) + grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); + } +} + +template +void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, + int softmax_len, cudaStream_t stream) { + const int warps_per_block = 4; + // rows = batch_size * nhead * from_len + dim3 grid_dim(rows / warps_per_block); + dim3 block_dim(WARP_SIZE, warps_per_block); + + if (softmax_len <= 32) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 64) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 128) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 256) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 384) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 512) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 768) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 1024) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 2048) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else + throw std::runtime_error( + std::string( + "Special sequence length found in softmax backward, seq_len: ") + + std::to_string(softmax_len)); +} + +template void launch_attn_softmax_bw<__half>(__half *out_grad, + const __half *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); +template void launch_attn_softmax_bw(float *out_grad, + const float *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu index d03084b22e126fe5facd8ff709ac94a81e511a7e..04de3c092ee093ecb76787a850a71c3aceac8d70 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu @@ -1,312 +1,314 @@ -#include -#include -#include - -#include "kernels.h" - -using namespace cub; - -/** -@brief: transform_0213 -Split the attention heads and reshape input -during backward progress of encoder self-attention - -@thread -gridDim.x = batch_size -gridDim.y = seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -input: [batch_size, seq_len, hidden_dim] -output: [batch_size, nhead, seq_len, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -*/ - -template -__global__ void transform_0213(T *output, const T *input, int hidden_dim, - int head_dim); - -template <> -__global__ void transform_0213(float *output, const float *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -template <> -__global__ void transform_0213<__half>(__half *output, const __half *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -// [b, s, h] -> [b, nh, s, ad] -template <> -void launch_transform_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213 - <<>>(output, input, hidden_dim, head_dim); -} - -template <> -void launch_transform_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213<__half> - <<>>(output, input, hidden_dim, head_dim); -} - -/** -@brief: bias_add_transform_20314 -Add bias to input, transform from -[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] - -@thread -gridDim.x = dim_0 -gridDim.y = dim_1 -gridDim.z = dim_2 -blockDim.x = min(dim_3 * dim_4, MAX_THREADS) - -@param -input: [dim_0, dim_1, dim_2, dim_3, dim_4] -bias: [dim_2, dim_3, dim_4] -output: [dim_2, dim_0, dim_3, dim_1, dim_4] -*/ -template -__global__ void bias_add_transform_20314(T *output, const T *input, - const T *bias, int dim_3, int dim_4); - -template <> -__global__ void -bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_3, int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - vres4.x = vqkv4.x + vbias4.x; - vres4.y = vqkv4.y + vbias4.y; - vres4.z = vqkv4.z + vbias4.z; - vres4.w = vqkv4.w + vbias4.w; - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -template <> -__global__ void -bias_add_transform_20314<__half>(__half *output, const __half *input, - const __half *bias, int dim_3, int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); - __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); - __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); - h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); - h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); - h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template <> -void launch_bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 2; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314 - <<>>(output, input, bias, dim_3, dim_4); -} - -template <> -void launch_bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 3; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314<__half> - <<>>(output, input, bias, dim_3, dim_4); -} - -/** -@brief: transform4d_0213 -Reshape the input matrix to merge the heads - -@thread -gridDim.x = (num_all + max_block_thread - 1) / max_block_thread -blockDim.x = max_block_thread - -@param -input: [trans_count, batch_size, nhead, seq_len, head_dim] -output: [batch_size, seq_len, trans_count, nhead, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -trans_count: 1 or 3, the count of matrice need to be transformed -*/ -template -__global__ void transform4d_0213(T *output, const T *input, int batch_size, - int seq_len, int trans_count, int nhead, - int head_dim, int num_all) { - int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (offset >= num_all) { - return; - } - int trans_id, batch_id, head_id, token_id, dim_id; - decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, - &batch_id, &head_id, &token_id, &dim_id); - // [b, s, tc, nh, ad] - int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, - seq_len, trans_count, nhead, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - res4[trg_offset] = input4[offset]; -} - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template <> -void launch_transform4d_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} - -template <> -void launch_transform4d_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, - int hidden_dim, int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<__half><<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} +#include +#include +#include + +#include "kernels.h" + +using namespace cub; + +/** +@brief: transform_0213 +Split the attention heads and reshape input +during backward progress of encoder self-attention + +@thread +gridDim.x = batch_size +gridDim.y = seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +input: [batch_size, seq_len, hidden_dim] +output: [batch_size, nhead, seq_len, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +*/ + +template +__global__ void transform_0213(T *output, const T *input, int hidden_dim, + int head_dim); + +template <> +__global__ void transform_0213(float *output, const float *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +template <> +__global__ void transform_0213<__half>(__half *output, const __half *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +// [b, s, h] -> [b, nh, s, ad] +template <> +void launch_transform_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213 + <<>>(output, input, hidden_dim, head_dim); +} + +template <> +void launch_transform_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213<__half> + <<>>(output, input, hidden_dim, head_dim); +} + +/** +@brief: bias_add_transform_20314 +Add bias to input, transform from +[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] + +@thread +gridDim.x = dim_0 +gridDim.y = dim_1 +gridDim.z = dim_2 +blockDim.x = min(dim_3 * dim_4, MAX_THREADS) + +@param +input: [dim_0, dim_1, dim_2, dim_3, dim_4] +bias: [dim_2, dim_3, dim_4] +output: [dim_2, dim_0, dim_3, dim_1, dim_4] +*/ +template +__global__ void bias_add_transform_20314(T *output, const T *input, + const T *bias, int dim_3, int dim_4); + +template <> +__global__ void bias_add_transform_20314(float *output, + const float *input, + const float *bias, int dim_3, + int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + vres4.x = vqkv4.x + vbias4.x; + vres4.y = vqkv4.y + vbias4.y; + vres4.z = vqkv4.z + vbias4.z; + vres4.w = vqkv4.w + vbias4.w; + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +template <> +__global__ void bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_3, + int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); + __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); + __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); + h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); + h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); + h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +// [b, s, 3, h] -> [3, b, nh, s, ad] +template <> +void launch_bias_add_transform_20314(float *output, const float *input, + const float *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 2; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314 + <<>>(output, input, bias, dim_3, dim_4); +} + +template <> +void launch_bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 3; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314<__half> + <<>>(output, input, bias, dim_3, dim_4); +} + +/** +@brief: transform4d_0213 +Reshape the input matrix to merge the heads + +@thread +gridDim.x = (num_all + max_block_thread - 1) / max_block_thread +blockDim.x = max_block_thread + +@param +input: [trans_count, batch_size, nhead, seq_len, head_dim] +output: [batch_size, seq_len, trans_count, nhead, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +trans_count: 1 or 3, the count of matrice need to be transformed +*/ +template +__global__ void transform4d_0213(T *output, const T *input, int batch_size, + int seq_len, int trans_count, int nhead, + int head_dim, int num_all) { + int offset = blockIdx.x * blockDim.x + threadIdx.x; + if (offset >= num_all) { + return; + } + int trans_id, batch_id, head_id, token_id, dim_id; + decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, + &batch_id, &head_id, &token_id, &dim_id); + // [b, s, tc, nh, ad] + int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, + seq_len, trans_count, nhead, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + res4[trg_offset] = input4[offset]; +} + +// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] +template <> +void launch_transform4d_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} + +template <> +void launch_transform4d_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, + int hidden_dim, int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<__half><<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp index 4690277e63db0a49c23c9274e6553da1e6b04103..15a07bb0c7acf1cdc3fae255bbe5b6791efe23c1 100644 --- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp @@ -138,4 +138,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu index ad7066bbd9df1c1582946092460734986cdc2d03..72b84d6ca40f91e337060abdca6fbda42ee9f1ea 100644 --- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu @@ -680,4 +680,4 @@ void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, grad_input->DATA_PTR(), gamma != NULL ? grad_gamma->DATA_PTR() : NULL, gamma != NULL ? grad_beta->DATA_PTR() : NULL);) -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp index 61c8a725052fdf9462d5adf1c9e43680e836caf3..8c0b89eb06d16d5a35c273acb755642399398750 100644 --- a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp @@ -1,97 +1,97 @@ -#include - -torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx); - -std::vector moe_combine_cuda_backward( - int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); - -#define CHECK_CUDA(x) \ - TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -torch::Tensor moe_dispatch_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, torch::Tensor dest_idx) { - CHECK_INPUT(batch_tokens); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx); -} - -torch::Tensor moe_dispatch_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(expert_grad); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx); -} - -torch::Tensor moe_combine_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(expert_tokens); - CHECK_INPUT(logits); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask, - dest_idx); -} - -std::vector moe_combine_backward(int s, int e, int c, int h, - torch::Tensor tokens_grad, - torch::Tensor expert_tokens, - torch::Tensor logits, - torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(tokens_grad); - CHECK_INPUT(logits); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens, - logits, mask, dest_idx); -} - -torch::Tensor moe_cumsum(torch::Tensor mask) { - CHECK_INPUT(mask); - return cumsum_sub_one_in_dim0(mask); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0"); - m.def("dispatch_forward", &moe_dispatch_forward, - "Forward operation in MoE dispatch function"); - m.def("dispatch_backward", &moe_dispatch_backward, - "Backward operation in MoE dispatch function"); - m.def("combine_forward", &moe_combine_forward, - "Combine operation in MoE combine function"); - m.def("combine_backward", &moe_combine_backward, - "Combine operation in MoE combine function"); -} +#include + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +torch::Tensor moe_dispatch_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, torch::Tensor dest_idx) { + CHECK_INPUT(batch_tokens); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx); +} + +torch::Tensor moe_dispatch_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_grad); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx); +} + +torch::Tensor moe_combine_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_tokens); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask, + dest_idx); +} + +std::vector moe_combine_backward(int s, int e, int c, int h, + torch::Tensor tokens_grad, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(tokens_grad); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens, + logits, mask, dest_idx); +} + +torch::Tensor moe_cumsum(torch::Tensor mask) { + CHECK_INPUT(mask); + return cumsum_sub_one_in_dim0(mask); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0"); + m.def("dispatch_forward", &moe_dispatch_forward, + "Forward operation in MoE dispatch function"); + m.def("dispatch_backward", &moe_dispatch_backward, + "Backward operation in MoE dispatch function"); + m.def("combine_forward", &moe_combine_forward, + "Combine operation in MoE combine function"); + m.def("combine_backward", &moe_combine_backward, + "Combine operation in MoE combine function"); +} diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu index 0454377a2fadb4ac4f1ded1359c71f17110b1ea3..66c1e6bd260eb10da09715f6ca423847e3b97145 100644 --- a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu @@ -1,659 +1,659 @@ -#include -#include -#include - -#include - -#include "block_reduce.h" - -template -__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - BlockStore(ts_store).Store(dst_row + idx, pack); - } -} - -template -__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, pack); - BlockStore(ts_store).Store(src_row + idx, pack); - } -} - -template -__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - BlockStore(ts_store).Store(dst_row1 + idx, pack); - BlockStore(ts_store).Store(dst_row2 + idx, pack); - } -} - -template -__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack1[pack_size], pack2[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row1 + idx, pack1); - BlockLoad(ts_load).Load(dst_row2 + idx, pack2); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - pack1[i] += pack2[i]; - } - - BlockStore(ts_store).Store(src_row + idx, pack1); - } -} - -template -__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - pack[i] *= weight; - } - - BlockStore(ts_store).Store(dst_row + idx, pack); - } -} - -template -__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, - T *weight_grad, const T weight, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T grad[pack_size], tokens[pack_size]; - float thread_sum = 0; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, grad); - BlockLoad(ts_load).Load(tks_row + idx, tokens); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - thread_sum += grad[i] * tokens[i]; - grad[i] *= weight; - } - - BlockStore(ts_store).Store(src_row + idx, grad); - } - - blockReduce(&thread_sum); - - if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); -} - -template -__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, - const T weight1, const T weight2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack1[pack_size], pack2[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row1 + idx, pack1); - BlockLoad(ts_load).Load(src_row2 + idx, pack2); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; - } - - BlockStore(ts_store).Store(dst_row + idx, pack1); - } -} - -template -__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, - T *tks_row1, T *tks_row2, T *weight_grad1, - T *weight_grad2, const T weight1, - const T weight2, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size], - sgrad2[pack_size]; - float thread_sum[2] = {0, 0}; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, grad); - BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); - BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - thread_sum[0] += grad[i] * tokens1[i]; - thread_sum[1] += grad[i] * tokens2[i]; - sgrad1[i] = weight1 * grad[i]; - sgrad2[i] = weight2 * grad[i]; - } - - BlockStore(ts_store).Store(src_row1 + idx, sgrad1); - BlockStore(ts_store).Store(src_row2 + idx, sgrad2); - } - - blockReduce(thread_sum); - - if (threadIdx.x == 0) - *weight_grad1 = static_cast(thread_sum[0]); - else if (threadIdx.x == 1) - *weight_grad2 = static_cast(thread_sum[1]); -} - -// DISPATCH KERNELS -------------------------------- - -template -__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, - const int cols, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_dpch_two_fwd(src_row, dst_row1, dst_row2, - cols); - else if (indicator1 != 0) - moe_dpch_one_fwd(src_row, dst_row1, cols); - else if (indicator2 != 0) - moe_dpch_one_fwd(src_row, dst_row2, cols); - else - return; -} - -template -__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, - const int cols, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_dpch_two_bwd(src_row, dst_row1, dst_row2, - cols); - else if (indicator1 != 0) - moe_dpch_one_bwd(src_row, dst_row1, cols); - else if (indicator2 != 0) - moe_dpch_one_bwd(src_row, dst_row2, cols); - else - return; -} - -template -__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, - int *mask1, int *mask2, int *dest1, - int *dest2, const int h) { - int row = blockIdx.x; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - moe_dpch_fwd_selector( - batch_tokens + (row * h), expert_input + (dest1[row] * h), - expert_input + (dest2[row] * h), h, mask1[row], indicator2); -} - -template -__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, - int *mask2, int *dest1, int *dest2, - const int h) { - int row = blockIdx.x; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - moe_dpch_bwd_selector( - tokens_grad + (row * h), expert_grad + (dest1[row] * h), - expert_grad + (dest2[row] * h), h, mask1[row], indicator2); -} - -// COMBINE KERNELS -------------------------------- - -template -__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, - const int cols, const T weight1, - const T weight2, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_cb_two_fwd(src_row1, src_row2, dst_row, - weight1, weight2, cols); - else if (indicator1 != 0) - moe_cb_one_fwd(src_row1, dst_row, weight1, cols); - else if (indicator2 != 0) - moe_cb_one_fwd(src_row2, dst_row, weight2, cols); - else - return; -} - -template -__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, - const int cols, T *tks_row1, T *tks_row2, - T *wt_grad1, T *wt_grad2, const T weight1, - const T weight2, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_cb_two_bwd(src_row1, src_row2, dst_row, - tks_row1, tks_row2, wt_grad1, - wt_grad2, weight1, weight2, cols); - else if (indicator1 != 0) - moe_cb_one_bwd(src_row1, dst_row, tks_row1, - wt_grad1, weight1, cols); - else if (indicator2 != 0) - moe_cb_one_bwd(src_row2, dst_row, tks_row2, - wt_grad2, weight2, cols); - else - return; -} - -template -__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, - T *logits, int *mask1, int *mask2, int *dest1, - int *dest2, const int e, const int c, - const int h) { - int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - T *row_log = logits + (row * e); - moe_cb_fwd_selector( - expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), - combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row], - indicator2); -} - -template -__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, - T *logits, T *logits_grad, int *mask1, - int *mask2, int *dest1, int *dest2, - const int e, const int c, const int h) { - int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); - moe_cb_bwd_selector( - expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), - tokens_grad + (row * h), h, tks + (dest1[row] * h), - tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1], - row_log[eid2], mask1[row], indicator2); -} - -// CUMSUM KERNEL -------------------------------- - -template -__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, - const int e) { - assert(s % pack_size == 0); - constexpr int bpack_size = block_size * pack_size; - int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; - __shared__ int temp[block_size + 1]; - int pack[pack_size]; - - for (int idx = 0; idx < s; idx += bpack_size) { - int offset = 1; - - if (idx + tps < s) { - temp[tid] = inputs[tps * e + bid]; -#pragma unroll - for (int i = 1; i < pack_size; ++i) { - pack[i] = inputs[(tps + i) * e + bid]; - } -#pragma unroll - for (int i = 1; i < pack_size; ++i) { - temp[tid] += pack[i]; - } - } - - for (int i = block_size >> 1; i > 0; i >>= 1) { - __syncthreads(); - if (tid < i) { - int j = offset * (2 * tid + 1) - 1; - temp[j + offset] += temp[j]; - } - offset <<= 1; - } - - if (tid == 0) { - temp[block_size] = temp[block_size - 1]; - temp[block_size - 1] = 0; - } - - for (int i = 1; i < block_size; i <<= 1) { - offset >>= 1; - __syncthreads(); - if (tid < i) { - int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j]; - temp[j] = temp[k]; - temp[k] += ts; - } - } - __syncthreads(); - - if (tid == 0) temp[0] = temp[block_size]; - __syncthreads(); - - if (idx + tps < s) { - temp[tid + 1] += last_sum; -#pragma unroll - for (int i = pack_size - 1; i > 0; --i) { - outputs[(tps + i) * e + bid] = temp[tid + 1]; - temp[tid + 1] -= pack[i]; - } - outputs[tps * e + bid] = temp[tid + 1]; - } - __syncthreads(); - - last_sum += temp[0]; - inputs += bpack_size * e; - outputs += bpack_size * e; - } -} - -// LAUNCH FUNCTIONS -------------------------------- - -template -void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, - int *mask2, int *dest1, int *dest2, const int s, - const int h) { - if (h < 256) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 512) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 1024) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 2048) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); -} - -template -void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, - int *dest1, int *dest2, const int s, const int h) { - if (h < 256) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 512) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 1024) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 2048) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); -} - -template -void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, - int *mask1, int *mask2, int *dest1, int *dest2, - const int s, const int e, const int c, const int h) { - if (h < 256) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 512) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 1024) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 2048) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, - dest2, e, c, h); -} - -template -void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, - T *logits_grad, int *mask1, int *mask2, int *dest1, - int *dest2, const int s, const int e, const int c, - const int h) { - if (h < 256) - moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, - logits, logits_grad, mask1, mask2, - dest1, dest2, e, c, h); - else // if (h < 512) - moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, - logits, logits_grad, mask1, mask2, - dest1, dest2, e, c, h); - // else if (h < 1024) - // moe_cb_bwd_kernel<<>> - // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, - // dest1, dest2, e, c, h); - // else - // moe_cb_bwd_kernel<<>> - // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, - // dest1, dest2, e, c, h); -} - -void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { - if (s <= 256) - cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); - else if (s <= 512) - cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); - else if (s <= 1024) - cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); - else if (s <= 2048) - cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); - else - cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); -} - -// API FUNCTIONS -------------------------------- - -#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Float: { \ - using scalar_t = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented yet for specific data type."); \ - } - -torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - auto res = torch::zeros( - {ec, h}, - torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - batch_tokens.scalar_type(), "moe dispatch forward", - moe_dpch_fwd_launch( - batch_tokens.data(), res.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); - - return res; -} - -torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - auto res = torch::zeros( - {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - expert_grad.scalar_type(), "moe dispatch backward", - moe_dpch_bwd_launch( - res.data(), expert_grad.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); - - return res; -} - -torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - assert(expert_tokens.dtype() == logits.dtype()); - - auto res = torch::zeros( - {s, h}, - torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - expert_tokens.scalar_type(), "moe combine forward", - moe_cb_fwd_launch( - expert_tokens.data(), res.data(), - logits.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, - h)); - - return res; -} - -std::vector moe_combine_cuda_backward( - int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - assert(tokens_grad.dtype() == expert_tokens.dtype()); - assert(expert_tokens.dtype() == logits.dtype()); - - auto egrad = torch::zeros( - {e * c, h}, - torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), - wgrad = torch::zeros( - {s, e}, torch::dtype(logits.dtype()).device(logits.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - tokens_grad.scalar_type(), "moe combine backward", - moe_cb_bwd_launch( - tokens_grad.data(), egrad.data(), - expert_tokens.data(), logits.data(), - wgrad.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, - h)); - - return {egrad, wgrad}; -} - -torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { - assert(mask.dim() == 2); - assert(mask.dtype() == torch::kInt32); - - const int s = mask.size(0), e = mask.size(1); - auto res = - torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); - cumsum_launch(mask.data(), res.data(), s, e); - - return res; -} +#include +#include +#include + +#include + +#include "block_reduce.h" + +template +__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, pack); + BlockStore(ts_store).Store(src_row + idx, pack); + } +} + +template +__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row1 + idx, pack); + BlockStore(ts_store).Store(dst_row2 + idx, pack); + } +} + +template +__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row1 + idx, pack1); + BlockLoad(ts_load).Load(dst_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] += pack2[i]; + } + + BlockStore(ts_store).Store(src_row + idx, pack1); + } +} + +template +__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack[i] *= weight; + } + + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, + T *weight_grad, const T weight, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens[pack_size]; + float thread_sum = 0; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row + idx, tokens); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum += grad[i] * tokens[i]; + grad[i] *= weight; + } + + BlockStore(ts_store).Store(src_row + idx, grad); + } + + blockReduce(&thread_sum); + + if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); +} + +template +__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, + const T weight1, const T weight2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row1 + idx, pack1); + BlockLoad(ts_load).Load(src_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; + } + + BlockStore(ts_store).Store(dst_row + idx, pack1); + } +} + +template +__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, + T *tks_row1, T *tks_row2, T *weight_grad1, + T *weight_grad2, const T weight1, + const T weight2, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size], + sgrad2[pack_size]; + float thread_sum[2] = {0, 0}; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); + BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum[0] += grad[i] * tokens1[i]; + thread_sum[1] += grad[i] * tokens2[i]; + sgrad1[i] = weight1 * grad[i]; + sgrad2[i] = weight2 * grad[i]; + } + + BlockStore(ts_store).Store(src_row1 + idx, sgrad1); + BlockStore(ts_store).Store(src_row2 + idx, sgrad2); + } + + blockReduce(thread_sum); + + if (threadIdx.x == 0) + *weight_grad1 = static_cast(thread_sum[0]); + else if (threadIdx.x == 1) + *weight_grad2 = static_cast(thread_sum[1]); +} + +// DISPATCH KERNELS -------------------------------- + +template +__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_fwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_fwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_fwd(src_row, dst_row2, cols); + else + return; +} + +template +__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_bwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_bwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_bwd(src_row, dst_row2, cols); + else + return; +} + +template +__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, + int *mask1, int *mask2, int *dest1, + int *dest2, const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_fwd_selector( + batch_tokens + (row * h), expert_input + (dest1[row] * h), + expert_input + (dest2[row] * h), h, mask1[row], indicator2); +} + +template +__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_bwd_selector( + tokens_grad + (row * h), expert_grad + (dest1[row] * h), + expert_grad + (dest2[row] * h), h, mask1[row], indicator2); +} + +// COMBINE KERNELS -------------------------------- + +template +__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_fwd(src_row1, src_row2, dst_row, + weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_fwd(src_row1, dst_row, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_fwd(src_row2, dst_row, weight2, cols); + else + return; +} + +template +__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, T *tks_row1, T *tks_row2, + T *wt_grad1, T *wt_grad2, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_bwd(src_row1, src_row2, dst_row, + tks_row1, tks_row2, wt_grad1, + wt_grad2, weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_bwd(src_row1, dst_row, tks_row1, + wt_grad1, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_bwd(src_row2, dst_row, tks_row2, + wt_grad2, weight2, cols); + else + return; +} + +template +__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, + T *logits, int *mask1, int *mask2, int *dest1, + int *dest2, const int e, const int c, + const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e); + moe_cb_fwd_selector( + expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), + combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row], + indicator2); +} + +template +__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, + T *logits, T *logits_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int e, const int c, const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); + moe_cb_bwd_selector( + expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), + tokens_grad + (row * h), h, tks + (dest1[row] * h), + tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1], + row_log[eid2], mask1[row], indicator2); +} + +// CUMSUM KERNEL -------------------------------- + +template +__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, + const int e) { + assert(s % pack_size == 0); + constexpr int bpack_size = block_size * pack_size; + int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; + __shared__ int temp[block_size + 1]; + int pack[pack_size]; + + for (int idx = 0; idx < s; idx += bpack_size) { + int offset = 1; + + if (idx + tps < s) { + temp[tid] = inputs[tps * e + bid]; +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + pack[i] = inputs[(tps + i) * e + bid]; + } +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + temp[tid] += pack[i]; + } + } + + for (int i = block_size >> 1; i > 0; i >>= 1) { + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1; + temp[j + offset] += temp[j]; + } + offset <<= 1; + } + + if (tid == 0) { + temp[block_size] = temp[block_size - 1]; + temp[block_size - 1] = 0; + } + + for (int i = 1; i < block_size; i <<= 1) { + offset >>= 1; + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j]; + temp[j] = temp[k]; + temp[k] += ts; + } + } + __syncthreads(); + + if (tid == 0) temp[0] = temp[block_size]; + __syncthreads(); + + if (idx + tps < s) { + temp[tid + 1] += last_sum; +#pragma unroll + for (int i = pack_size - 1; i > 0; --i) { + outputs[(tps + i) * e + bid] = temp[tid + 1]; + temp[tid + 1] -= pack[i]; + } + outputs[tps * e + bid] = temp[tid + 1]; + } + __syncthreads(); + + last_sum += temp[0]; + inputs += bpack_size * e; + outputs += bpack_size * e; + } +} + +// LAUNCH FUNCTIONS -------------------------------- + +template +void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, + int *mask2, int *dest1, int *dest2, const int s, + const int h) { + if (h < 256) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); +} + +template +void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, + int *dest1, int *dest2, const int s, const int h) { + if (h < 256) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); +} + +template +void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, + int *mask1, int *mask2, int *dest1, int *dest2, + const int s, const int e, const int c, const int h) { + if (h < 256) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 512) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 1024) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 2048) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, + dest2, e, c, h); +} + +template +void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, + T *logits_grad, int *mask1, int *mask2, int *dest1, + int *dest2, const int s, const int e, const int c, + const int h) { + if (h < 256) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + else // if (h < 512) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + // else if (h < 1024) + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); + // else + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); +} + +void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { + if (s <= 256) + cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); + else if (s <= 512) + cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); + else if (s <= 1024) + cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); + else if (s <= 2048) + cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); + else + cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); +} + +// API FUNCTIONS -------------------------------- + +#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented yet for specific data type."); \ + } + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {ec, h}, + torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + batch_tokens.scalar_type(), "moe dispatch forward", + moe_dpch_fwd_launch( + batch_tokens.data(), res.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_grad.scalar_type(), "moe dispatch backward", + moe_dpch_bwd_launch( + res.data(), expert_grad.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(expert_tokens.dtype() == logits.dtype()); + + auto res = torch::zeros( + {s, h}, + torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_tokens.scalar_type(), "moe combine forward", + moe_cb_fwd_launch( + expert_tokens.data(), res.data(), + logits.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return res; +} + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(tokens_grad.dtype() == expert_tokens.dtype()); + assert(expert_tokens.dtype() == logits.dtype()); + + auto egrad = torch::zeros( + {e * c, h}, + torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), + wgrad = torch::zeros( + {s, e}, torch::dtype(logits.dtype()).device(logits.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + tokens_grad.scalar_type(), "moe combine backward", + moe_cb_bwd_launch( + tokens_grad.data(), egrad.data(), + expert_tokens.data(), logits.data(), + wgrad.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return {egrad, wgrad}; +} + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { + assert(mask.dim() == 2); + assert(mask.dtype() == torch::kInt32); + + const int s = mask.size(0), e = mask.size(1); + auto res = + torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); + cumsum_launch(mask.data(), res.data(), s, e); + + return res; +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu index 49ab83e8fc81df4c9887d55ddc5503f20498bb7d..85f935152f8a46e76ec6f8d7a7ea71e839c435a5 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu @@ -379,4 +379,4 @@ void multi_tensor_norm_out_cuda( norm_type, alpha, beta); return; -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu index 54c4220190d80d6309e74c90da412d6ccda32c8f..63771cf40bcb053d0e94445423d64f025ab23965 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu @@ -351,4 +351,4 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, lr, weight_decay, use_nvlamb);) AT_CUDA_CHECK(cudaGetLastError()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu index 360485dcd02fbfc21a76a2bfa6dd6568b8909499..2f58a0f16dce000be6271ac7abc7104c6e548d3d 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu @@ -122,4 +122,4 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, AT_CUDA_CHECK(cudaGetLastError()); // AT_CUDA_CHECK(cudaDeviceSynchronize()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu index 35f2c9b4ed15eab94b1456ce436694180d706a45..7f48dbd5d497d1695592cd0fc1f6b984e4073e3b 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu @@ -164,4 +164,4 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, } AT_CUDA_CHECK(cudaGetLastError()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp index 4ae3c853ca5e844272ca4fdb907c8c95a7f2b787..8c2982b0cff99bf0741abd8115515e18e1e34128 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp @@ -3,82 +3,68 @@ #include #include + #include namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -int get_batch_per_block_cuda( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads); - -torch::Tensor fwd( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) { +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, + int attn_heads); + +torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor) { AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); return fwd_cuda(input, mask, scale_factor); } -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - +torch::Tensor bwd(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, float scale_factor) { AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); } -int get_batch_per_block( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, + attn_heads); } -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", - &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); + m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); m.def("get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size." - ); + &multihead_attn::fused_softmax::scaled_masked_softmax:: + get_batch_per_block, + "Return Batch per block size."); } diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h index 1583030b8235acfb3a3af1a86fa938901ae52bbb..d3e6f04e6093fa9726bdeaca918dbeab5778c0a2 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h @@ -4,12 +4,12 @@ #pragma once #include +#include #include +#include + #include #include -#include -#include -#include namespace { @@ -17,37 +17,53 @@ template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); +} int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; @@ -55,438 +71,468 @@ struct Max { }; template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); } + } } /* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Explicit masking - */ -template + * Extended softmax (from native aten pytorch) with following additional + * features 1) input scaling 2) Explicit masking + */ +template __global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } + output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale, + int micro_batch_size, int element_count, int pad_batches) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = + (blockDim.y * + (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * + WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i * element_count + it * WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); } + } } + } - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; } + copy_vector( + dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } } + } } -template +template __global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = + first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; } + } } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); } + copy_vector( + gradInput + i * element_count + it * WARP_SIZE, out); + } } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; } -} // end of anonymous namespace -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ +template +void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, int key_seq_len, + int batches, int attn_heads, + int pad_batches) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { int log2_elements = log2_ceil(key_seq_len); const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + // use 128 threads per block to maximimize gpu utilization constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) -{ - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } + TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; } + } } -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) -{ - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; - } +template +void dispatch_scaled_masked_softmax_backward(output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, int key_seq_len, + int batches, int attn_heads) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; } + } } diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h index 3af487f9de0ffdc22faaca142cbc2ff86b68d03e..54c8e9133a1b3b588191d90ef9980d49392e1406 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h @@ -4,11 +4,12 @@ #pragma once #include +#include #include +#include + #include #include -#include -#include namespace { @@ -16,53 +17,78 @@ template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float2 *)dst) = *((float2 *)src); +} + template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); +} template __device__ __inline__ void copy_zero_vector(Datatype *dst); template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } +__device__ __inline__ void copy_zero_vector( + c10::BFloat16 *dst) { + *dst = 0.0; +} template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } +__device__ __inline__ void copy_zero_vector( + c10::BFloat16 *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); +} template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { + *dst = 0.0; +} template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); +} int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; @@ -70,431 +96,505 @@ struct Max { }; template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); } + } } /* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Implicit time (diagonal masking) + * Extended softmax (from native aten pytorch) with following additional + * features 1) input scaling 2) Implicit time (diagonal masking) */ -template +template __global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } + output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size, + int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = + (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector( + temp_data, src + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } } + } - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } - } - copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } } + copy_vector( + dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector( + dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } } + } } -template +template __global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; + } } + } } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); } + copy_vector( + gradInput + i * element_count * stride + it * WARP_SIZE, out); + } } + } } -} // end of anonymous namespace +} // end of anonymous namespace -template +template void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } + output_t *dst, const input_t *src, const input_t scale, + int softmax_elements, int softmax_elements_stride, int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + default: + break; } + } } -template +template void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } + output_t *grad_input, input_t *grad, const input_t *output, + const acc_t scale, int softmax_elements, int softmax_elements_stride, + int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + default: + break; } + } } diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/colossalai/kernel/cuda_native/csrc/type_shim.h index 2f180a7783ec98b5d9b8286ccac4887b2a5b4bc5..03ccc02635fa17f7225ad1f9f8e59c65721069f5 100644 --- a/colossalai/kernel/cuda_native/csrc/type_shim.h +++ b/colossalai/kernel/cuda_native/csrc/type_shim.h @@ -171,6 +171,21 @@ using g_scalar_t_##LEVEL = at::Half; \ using p_scalar_t_##LEVEL = at::Half; \ __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::Float && \ + PTYPE == at::ScalarType::BFloat16) { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::BFloat16 && \ + PTYPE == at::ScalarType::Float) { \ + using g_scalar_t_##LEVEL = at::BFloat16; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::BFloat16 && \ + PTYPE == at::ScalarType::BFloat16) { \ + using g_scalar_t_##LEVEL = at::BFloat16; \ + using p_scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ } else { \ AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \ "'"); \ diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py deleted file mode 100644 index d793815ed681c52a7d17b49528d86aee2d441a9f..0000000000000000000000000000000000000000 --- a/colossalai/kernel/cuda_native/flash_attention.py +++ /dev/null @@ -1,635 +0,0 @@ -""" -A general attention module using the flash attention kernels from xformers: -https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha -""" - -import math -import os -import subprocess - -import torch - -try: - from xformers.ops.fmha import memory_efficient_attention - HAS_MEM_EFF_ATTN = True -except ImportError: - HAS_MEM_EFF_ATTN = False - print('please install xformers from https://github.com/facebookresearch/xformers') - -if HAS_MEM_EFF_ATTN: - - from typing import Optional - - from einops import rearrange - from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp - from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias - - from .scaled_softmax import AttnMaskType - - allow_alibi = True - for op in MemoryEfficientAttentionCutlassOp: - allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - - class Unpad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): - ctx.save_for_backward(indices) - # [b, s, ...] - assert tensor.ndim >= 3 - ctx.bsz = tensor.shape[0] - out = rearrange(tensor, 'b s ... -> (b s) ...') - ctx.shape = out.shape - # [1, ntokens, ...] - return out[indices].unsqueeze(0) - - @staticmethod - def backward(ctx, grad_output): - indices, = ctx.saved_tensors - # [b*s, ...] - grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) - grad[indices] = grad_output.squeeze(0) - grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz) - # [b, s, ...] - return grad, None - - class Repad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): - ctx.save_for_backward(indices) - # [ntokens, ...] - tensor = tensor.squeeze(0) - out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) - # [b*s, ...] - out[indices] = tensor - # [b, s, ...] - out = rearrange(out, '(b s) ... -> b s ...', b=batch_size) - return out - - @staticmethod - def backward(ctx, grad_output): - indices, = ctx.saved_tensors - # [b*s, ...] - grad_output = rearrange(grad_output, 'b s ... -> (b s) ...') - grad = grad_output[indices] - # [1, ntokens, ...] - return grad.unsqueeze(0), None, None, None - - class ColoAttention(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0): - super().__init__() - assert embed_dim % num_heads == 0, \ - f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." - self.scale = 1 / math.sqrt(embed_dim // num_heads) - self.dropout = dropout - - @staticmethod - def get_seq_info_from_mask(attn_mask: torch.Tensor): - indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten() - seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten().tolist() - return indices, seqlens - - @staticmethod - def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - return Unpad.apply(tensor, indices) - - @staticmethod - def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: - return Repad.apply(tensor, indices, batch_size, seq_len) - - def forward(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: Optional[AttnMaskType] = None, - bias: Optional[torch.Tensor] = None): - batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] - attn_bias = None - if attn_mask_type == AttnMaskType.padding: # bert style - assert attn_mask is not None, \ - f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." - assert attn_mask.dim() == 2, \ - "attention mask is supposed to have shape (batch_size, seq_len), " + \ - f"but got {attn_mask.dim()} dimensions." - if tgt_len == src_len: - q_indices, q_seqlen = self.get_seq_info_from_mask(attn_mask) - kv_seqlen = None - if batch_size > 1: - query, key, value = self.unpad(torch.stack([query, key, value], dim=2), q_indices).unbind(dim=2) - else: - q_indices = torch.arange(batch_size * tgt_len, dtype=torch.int32, device=query.device) - q_seqlen = torch.LongTensor([tgt_len] * batch_size, device=query.device) - kv_indices, kv_seqlen = self.get_seq_info_from_mask(attn_mask) - if batch_size > 1: - query = rearrange(query, "b s ... -> c (b s) ...", c=1) - key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2) - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) - elif attn_mask_type == AttnMaskType.causal: # gpt style - attn_bias = LowerTriangularMask() - - if bias is not None: # alibi / relative position emebedding - assert allow_alibi, "flash attention with bias is not supported in this system." - assert attn_mask_type == AttnMaskType.causal, \ - "attention with bias is only supported for causal attention so far." - attn_bias = attn_bias.add_bias(bias) - - out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale) - - if attn_mask_type == AttnMaskType.padding and batch_size > 1: - out = self.repad(out, q_indices, batch_size, tgt_len) - - out = rearrange(out, 'b s h d -> b s (h d)') - return out - - -########################################################################## -# the flash attention functions below that are copied -# from the OpenAI/triton repository will be deprecated -# You can find the repository in Triton https://github.com/openai/triton -# You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py -# Reference: -# 1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf -# 2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf - - -def triton_cuda_check(): - cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda") - cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip() - cuda_version = cuda_version.split('release ')[1] - cuda_version = cuda_version.split(',')[0] - cuda_version = cuda_version.split('.') - if len(cuda_version) == 2 and \ - (int(cuda_version[0]) == 11 and int(cuda_version[1]) >= 4) or \ - int(cuda_version[0]) > 11: - return True - return False - - -try: - import triton - import triton.language as tl - if triton_cuda_check(): - HAS_TRITON = True - else: - print("triton requires cuda >= 11.4") - HAS_TRITON = False -except ImportError: - print('please install triton from https://github.com/openai/triton') - HAS_TRITON = False -try: - from flash_attn.flash_attention import FlashAttention - from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_func, - flash_attn_unpadded_kvpacked_func, - flash_attn_unpadded_qkvpacked_func, - ) - HAS_FLASH_ATTN = True -except ImportError: - HAS_FLASH_ATTN = False - print('please install flash_attn from https://github.com/HazyResearch/flash-attention') - -if HAS_TRITON: - # the following functions are adapted from the OpenAI Triton tutorial - # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py - @triton.jit - def _fwd_kernel( - Q, - K, - V, - sm_scale, - TMP, - L, - M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - Z, - H, - N_CTX, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk - off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk - # Initialize pointers to Q, K, V - q_ptrs = Q + off_q - k_ptrs = K + off_k - v_ptrs = V + off_v - # initialize pointer to m and l - t_ptrs = TMP + off_hz * N_CTX + offs_m - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # load q: it will stay in SRAM throughout - q = tl.load(q_ptrs) - # loop over k, v and update accumulator - for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + start_n * stride_kn) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, trans_b=True) - qk *= sm_scale - qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + start_n * stride_vk) - p = p.to(tl.float16) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # rematerialize offsets to save registers - start_m = tl.program_id(0) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # write back l and m - l_ptrs = L + off_hz * N_CTX + offs_m - m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(l_ptrs, l_i) - tl.store(m_ptrs, m_i) - # initialize pointers to output - offs_n = tl.arange(0, BLOCK_DMODEL) - off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - - @triton.jit - def _bwd_preprocess( - Out, - DO, - L, - NewDO, - Delta, - BLOCK_M: tl.constexpr, - D_HEAD: tl.constexpr, - ): - off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) - off_n = tl.arange(0, D_HEAD) - # load - o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - denom = tl.load(L + off_m).to(tl.float32) - # compute - do = do / denom[:, None] - delta = tl.sum(o * do, axis=1) - # write-back - tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) - tl.store(Delta + off_m, delta) - - @triton.jit - def _bwd_kernel( - Q, - K, - V, - sm_scale, - Out, - DO, - DQ, - DK, - DV, - L, - M, - D, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - Z, - H, - N_CTX, - num_block, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - off_hz = tl.program_id(0) - off_z = off_hz // H - off_h = off_hz % H - # offset pointers for batch/head - Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_qz + off_h * stride_qh - V += off_z * stride_qz + off_h * stride_qh - DO += off_z * stride_qz + off_h * stride_qh - DQ += off_z * stride_qz + off_h * stride_qh - DK += off_z * stride_qz + off_h * stride_qh - DV += off_z * stride_qz + off_h * stride_qh - for start_n in range(0, num_block): - lo = start_n * BLOCK_M - # initialize row/col offsets - offs_qm = lo + tl.arange(0, BLOCK_M) - offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) - offs_m = tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - # pointer to row-wise quantities in value-like data - D_ptrs = D + off_hz * N_CTX - m_ptrs = M + off_hz * N_CTX - # initialize dv amd dk - dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # k and v stay in SRAM throughout - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - # loop over rows - for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): - offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - q = tl.load(q_ptrs) - # recompute p = softmax(qk, dim=-1).T - # NOTE: `do` is pre-divided by `l`; no normalization here - qk = tl.dot(q, k, trans_b=True) - qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) - m = tl.load(m_ptrs + offs_m_curr) - p = tl.exp(qk * sm_scale - m[:, None]) - # compute dv - do = tl.load(do_ptrs) - dv += tl.dot(p.to(tl.float16), do, trans_a=True) - # compute dp = dot(v, do) - Di = tl.load(D_ptrs + offs_m_curr) - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp += tl.dot(do, v, trans_b=True) - # compute ds = p * (dp - delta[:, None]) - ds = p * dp * sm_scale - # compute dk = dot(ds.T, q) - dk += tl.dot(ds.to(tl.float16), q, trans_a=True) - # # compute dq - dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds.to(tl.float16), k) - tl.store(dq_ptrs, dq, eviction_policy="evict_last") - # # increment pointers - dq_ptrs += BLOCK_M * stride_qm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_qm - # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) - - class _TritonFlashAttention(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, k, v, sm_scale): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - o = torch.empty_like(q) - grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) - tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - tmp, - L, - m, - o, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), - q.shape[0], - q.shape[1], - q.shape[2], - BLOCK_M=BLOCK, - BLOCK_N=BLOCK, - BLOCK_DMODEL=Lk, - num_warps=num_warps, - num_stages=1, - ) - ctx.save_for_backward(q, k, v, o, L, m) - ctx.BLOCK = BLOCK - ctx.grid = grid - ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = Lk - return o - - @staticmethod - def backward(ctx, do): - q, k, v, o, l, m = ctx.saved_tensors - do = do.contiguous() - dq = torch.zeros_like(q, dtype=torch.float32) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - do_scaled = torch.empty_like(do) - delta = torch.empty_like(l) - _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)]( - o, - do, - l, - do_scaled, - delta, - BLOCK_M=ctx.BLOCK, - D_HEAD=ctx.BLOCK_DMODEL, - ) - - # NOTE: kernel currently buggy for other values of `num_warps` - num_warps = 8 - _bwd_kernel[(ctx.grid[1],)]( - q, - k, - v, - ctx.sm_scale, - o, - do_scaled, - dq, - dk, - dv, - l, - m, - delta, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - q.shape[0], - q.shape[1], - q.shape[2], - ctx.grid[0], - BLOCK_M=ctx.BLOCK, - BLOCK_N=ctx.BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, - num_warps=num_warps, - num_stages=1, - ) - return dq, dk, dv, None - - def triton_flash_attention(q, k, v, sm_scale): - """ - Arguments: - q: (batch, nheads, seq, headdim) - k: (batch, nheads, seq, headdim) - v: (batch, nheads, seq, headdim) - sm_scale: float. The scaling of QK^T before applying softmax. - Return: - out: (batch, nheads, seq, headdim) - """ - if HAS_TRITON: - return _TritonFlashAttention.apply(q, k, v, sm_scale) - else: - raise RuntimeError("Triton kernel requires CUDA 11.4+!") - - -if HAS_FLASH_ATTN: - - def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False): - """ - Arguments: - qkv: (batch * seqlen, 3, nheads, headdim) - batch_size: int. - seq_len: int. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - dropout_p: float. - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - out: (total, nheads, headdim). - """ - max_s = seq_len - cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device) - out = flash_attn_unpadded_qkvpacked_func(qkv, - cu_seqlens, - max_s, - dropout_p, - softmax_scale=sm_scale, - causal=causal) - return out - - def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False): - """ - Arguments: - q: (batch * q_seqlen, nheads, headdim) - kv: (batch * kv_seqlen, 2, nheads, headdim) - batch_size: int. - seq_len: int. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - dropout_p: float. - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - out: (total, nheads, headdim). - """ - cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) - cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, - step=kv_seqlen, - dtype=torch.int32, - device=kv.device) - out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p, - sm_scale, causal) - return out - - def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False): - """ - Arguments: - q: (batch * q_seqlen, nheads, headdim) - k: (batch * kv_seqlen, nheads, headdim) - v: (batch * kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - out: (total, nheads, headdim). - """ - cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) - cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, - step=kv_seqlen, - dtype=torch.int32, - device=k.device) - return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale, - causal) - - -########################################################################## diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py index 40355a41ed0d2c1b1d1b6266b38bf160e28f8a28..c7d2a3a450221bdcc592936a593b8aec275afea6 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/kernel/cuda_native/layer_norm.py @@ -18,7 +18,6 @@ except ImportError: class FusedLayerNormAffineFunction(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, input, weight, bias, normalized_shape, eps): @@ -30,7 +29,6 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): global layer_norm if layer_norm is None: - layer_norm = LayerNormBuilder().load() output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.layernorm_op = layer_norm @@ -43,17 +41,14 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): def backward(ctx, grad_output): input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None - grad_input, grad_weight, grad_bias \ - = layer_norm.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) + grad_input, grad_weight, grad_bias = layer_norm.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) return grad_input, grad_weight, grad_bias, None, None class MixedFusedLayerNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None): super(MixedFusedLayerNorm, self).__init__() @@ -66,13 +61,11 @@ class MixedFusedLayerNorm(torch.nn.Module): self.reset_parameters() def reset_parameters(self): - init.ones_(self.weight) init.zeros_(self.bias) def forward(self, input): - return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps) def __repr__(self): - return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})' + return f"MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})" diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cad36e598d14b627d290dc3f3aec2101b0ee1186 --- /dev/null +++ b/colossalai/kernel/cuda_native/mha/__init__.py @@ -0,0 +1,3 @@ +from .mha import ColoAttention + +__all__ = ["ColoAttention"] diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee83915b1b4275908751f1daa587d36689a34d3 --- /dev/null +++ b/colossalai/kernel/cuda_native/mha/flash_attn_2.py @@ -0,0 +1,80 @@ +import warnings +from typing import Optional + +import torch + + +def is_ampere_or_better_gpu(): + if torch.cuda.is_available(): + device = torch.device("cuda") + properties = torch.cuda.get_device_properties(device) + if properties.major >= 8: # Ampere GPUs or newer + return True + return False + + +# "Check Ampere GPUs or newer" +HAS_FLASH_ATTN = False +if is_ampere_or_better_gpu(): + HAS_FLASH_ATTN = True +else: + warnings.warn("FlashAttention only supports Ampere GPUs or newer.") + HAS_FLASH_ATTN = False +try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + + HAS_FLASH_ATTN = True +except ImportError: + warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention") + HAS_FLASH_ATTN = False + +if HAS_FLASH_ATTN: + pass + + from .utils import SeqLenInfo + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: SeqLenInfo, + seq_len_info_kv: SeqLenInfo, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + """ + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + if padded: + if seq_len_info_kv == None: + seq_len_info_kv = seq_len_info_q + + attn_out = flash_attn_varlen_func( + q, + k, + v, + seq_len_info_q.cu_seqlens, + seq_len_info_kv.cu_seqlens, + seq_len_info_q.max_seqlen, + seq_len_info_kv.max_seqlen, + dropout_p, + scale, + causal, + ) + else: + attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) + return attn_out diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..649e74d61bab2b872c4211b1eb97fb9fbd9cb040 --- /dev/null +++ b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py @@ -0,0 +1,70 @@ +import warnings + +HAS_MEM_EFF_ATTN = False +try: + from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention + from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, + ) + + HAS_MEM_EFF_ATTN = True +except ImportError: + warnings.warn("please install xformers from https://github.com/facebookresearch/xformers") + HAS_MEM_EFF_ATTN = False + +if HAS_MEM_EFF_ATTN: + """ + A general attention module using the flash attention kernels from xformers: + https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha + """ + from typing import Optional + + import torch + + from .utils import SeqLenInfo + + allow_alibi = True + for op in MemoryEfficientAttentionCutlassOp: + allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) + + def mem_eff_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: SeqLenInfo, + seq_len_info_kv: SeqLenInfo, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + attn_bias = None + if padded: # bert style + if not causal: + attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + elif causal: # gpt style + attn_bias = LowerTriangularMask() + + if bias is not None: # alibi / relative position embedding + assert allow_alibi, "flash attention with bias is not supported in this system." + assert causal, "attention with bias is only supported for causal attention so far." + attn_bias = attn_bias.add_bias(bias) + + if padded: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) + + # shape: (b*s, n, d) + if padded: + out = out.squeeze(0) + + return out diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py new file mode 100644 index 0000000000000000000000000000000000000000..1c778439d33f2fbbb4f40f86e9b7e48ca020772d --- /dev/null +++ b/colossalai/kernel/cuda_native/mha/mha.py @@ -0,0 +1,113 @@ +import math +from typing import Optional + +import torch +from einops import rearrange + +from ..scaled_softmax import AttnMaskType +from .flash_attn_2 import HAS_FLASH_ATTN +from .mem_eff_attn import HAS_MEM_EFF_ATTN +from .utils import Repad, SeqLenInfo, Unpad + +if HAS_FLASH_ATTN: + from .flash_attn_2 import flash_attention +if HAS_MEM_EFF_ATTN: + from .mem_eff_attn import mem_eff_attention + + +class ColoAttention(torch.nn.Module): + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): + super().__init__() + assert ( + embed_dim % num_heads == 0 + ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." + if scale is not None: + self.scale = scale + else: + self.scale = 1 / math.sqrt(embed_dim // num_heads) + self.dropout = dropout + + if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN: + raise Exception("flash attention can not support!") + + @staticmethod + def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return Unpad.apply(tensor, indices) + + @staticmethod + def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: + return Repad.apply(tensor, indices, batch_size, seq_len) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + attn_mask_type: Optional[AttnMaskType] = None, + bias: Optional[torch.Tensor] = None, + ): + attn = None + if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None: + attn = flash_attention + else: + attn = mem_eff_attention + + padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 + causal = attn_mask_type is not None and attn_mask_type.value > 1 + + batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] + # unpad + seq_len_info_q = None + seq_len_info_kv = None + if padded: + # bert style, unpad process + assert ( + attn_mask is not None + ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." + assert attn_mask.dim() == 2, ( + "attention mask is supposed to have shape (batch_size, seq_len), " + + f"but got {attn_mask.dim()} dimensions." + ) + + # bert style + if tgt_len == src_len: + seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query, key, value = self.unpad( + torch.stack([query, key, value], dim=2), seq_len_info_q.indices + ).unbind(dim=1) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + seq_len_info_kv = seq_len_info_q + else: + seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) + seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query = rearrange(query, "b s ... -> c (b s) ...", c=1) + key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( + dim=1 + ) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + + out = attn( + query, + key, + value, + seq_len_info_q, + seq_len_info_kv, + dropout_p=self.dropout, + scale=self.scale, + causal=causal, + padded=padded, + ) + + # repad + if padded: + if batch_size > 1: + out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) + out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) + + out = rearrange(out, "b s h d -> b s (h d)") + return out diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe31921b961b05ddf0865d7949fce3b1aea6a515 --- /dev/null +++ b/colossalai/kernel/cuda_native/mha/utils.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from typing import Iterable, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange + +from colossalai.utils.cuda import get_current_device + + +class Unpad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): + ctx.save_for_backward(indices) + # [b, s, ...] + assert tensor.ndim >= 3 + ctx.bsz = tensor.shape[0] + out = rearrange(tensor, "b s ... -> (b s) ...") + ctx.shape = out.shape + # [ntokens, ...] + return out[indices] + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [ntokens, ...] + grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) + grad[indices] = grad_output + grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) + # [b, s, ...] + return grad, None + + +class Repad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): + ctx.save_for_backward(indices) + # [ntokens, ...] + tensor = tensor + out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + # [b*s, ...] + out[indices] = tensor + return out + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [b*s, ...] + grad = grad_output[indices] + # [ntokens, ...] + return grad, None, None, None + + +@dataclass +class SeqLenInfo: + seqlens: Iterable[int] = None + indices: torch.Tensor = None + max_seqlen: int = None + cu_seqlens: torch.Tensor = None + + @staticmethod + def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()): + if attn_mask is not None: + indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) + seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() + else: + batch_size, tgt_len = size[0], size[1] + indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) + seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) + max_seqlen = max(seqlens) + cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) + return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py index 3b6470cdcbb98e622fc68062b609dc590f5301ae..87afc1862847be24b6196ad43e75ffeecb51bdf0 100644 --- a/colossalai/kernel/cuda_native/multihead_attention.py +++ b/colossalai/kernel/cuda_native/multihead_attention.py @@ -36,34 +36,64 @@ colossal_multihead_attention = None @dataclass class Config: - max_batch_tokens: int # max batch token numbers - max_seq_len: int # max sequence length - hidden_size: int # size of transformer hidden layers - nhead: int # number of heads in attention - attn_prob_dropout_ratio: float # attention score dropout ratio - hidden_dropout_ratio: float # dropout ration before residual - norm_first: bool # norm_first - fp16: bool # fp16 presion + max_batch_tokens: int # max batch token numbers + max_seq_len: int # max sequence length + hidden_size: int # size of transformer hidden layers + nhead: int # number of heads in attention + attn_prob_dropout_ratio: float # attention score dropout ratio + hidden_dropout_ratio: float # dropout ration before residual + norm_first: bool # norm_first + fp16: bool # fp16 precision class MultiHeadAttention1DFunc(Function): - @staticmethod - def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, - norm_bias, config): + def forward( + ctx, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + config, + ): cuda_module = colossal_multihead_attention - forward_func = (cuda_module.multihead_attention_fw_fp16 - if config.fp16 else cuda_module.multihead_attention_fw_fp32) + forward_func = ( + cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32 + ) if config.fp16: input = input.to(torch.half) input_mask = input_mask.to(torch.half) - (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, - out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first) + (output,) = forward_func( + config.layer_id, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + config.training, + config.norm_first, + ) if config.is_grad_enabled and config.training: - ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, - out_proj_bias, norm_weight, norm_bias) + ctx.save_for_backward( + output, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + ) ctx.config = config return output @@ -72,11 +102,21 @@ class MultiHeadAttention1DFunc(Function): assert ctx.config.training cuda_module = colossal_multihead_attention - backward_func = (cuda_module.multihead_attention_bw_fp16 - if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32) + backward_func = ( + cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32 + ) - output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, \ - out_proj_bias, norm_weight, norm_bias = ctx.saved_tensors + ( + output, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + ) = ctx.saved_tensors grad_input = None grad_in_proj_weight = None @@ -91,13 +131,39 @@ class MultiHeadAttention1DFunc(Function): output = output.to(torch.half) input = input.to(torch.half) input_mask = input_mask.to(torch.half) - grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \ - grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func( - ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, - in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias) + ( + grad_input, + grad_in_proj_weight, + grad_in_proj_bias, + grad_out_proj_weight, + grad_out_proj_bias, + grad_norm_weight, + grad_norm_bias, + ) = backward_func( + ctx.config.layer_id, + grad_output, + output, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + ) - return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, - grad_norm_weight, grad_norm_bias, None) + return ( + grad_input, + None, + grad_in_proj_weight, + grad_in_proj_bias, + grad_out_proj_weight, + grad_out_proj_bias, + grad_norm_weight, + grad_norm_bias, + None, + ) class MultiHeadAttention(nn.Module): @@ -122,8 +188,9 @@ class MultiHeadAttention(nn.Module): def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): super(MultiHeadAttention, self).__init__() - self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, - fp16) + self.config = Config( + batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16 + ) check_config(self.config) self.pg = pg self.pg_size = 1 @@ -136,13 +203,17 @@ class MultiHeadAttention(nn.Module): global colossal_multihead_attention if colossal_multihead_attention is None: from colossalai.kernel.op_builder import MultiHeadAttnBuilder + multihead_attention = MultiHeadAttnBuilder().load() colossal_multihead_attention = multihead_attention # create the layer in cuda kernels. cuda_module = colossal_multihead_attention - create_layer_func = (cuda_module.create_multihead_attention_fp16 - if self.config.fp16 else cuda_module.create_multihead_attention_fp32) + create_layer_func = ( + cuda_module.create_multihead_attention_fp16 + if self.config.fp16 + else cuda_module.create_multihead_attention_fp32 + ) create_layer_func( self.config.layer_id, @@ -204,13 +275,15 @@ class MultiHeadAttention(nn.Module): with torch.no_grad(): self.in_proj_weight.copy_( - attn_qkvw_global.view(3, hs, hs)[:, - int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / - self.pg_size), :]) + attn_qkvw_global.view(3, hs, hs)[ + :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), : + ] + ) self.in_proj_bias.copy_( - attn_qkvb_global.view(3, hs)[:, - int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / - self.pg_size)]) + attn_qkvb_global.view(3, hs)[ + :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size) + ] + ) attn_ow_global = torch.empty(hs, hs) nn.init.xavier_uniform_(attn_ow_global, 1.0) @@ -218,9 +291,9 @@ class MultiHeadAttention(nn.Module): torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) attn_ow_global = attn_ow_global.cpu() with torch.no_grad(): - self.out_proj_weight.copy_(attn_ow_global[:, - int(hs * rank_in_pg / - self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)]) + self.out_proj_weight.copy_( + attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)] + ) else: attn_qkvw = self.in_proj_weight.view(-1, hs) @@ -238,7 +311,7 @@ class MultiHeadAttention(nn.Module): self.config.training = self.training self.config.is_grad_enabled = torch.is_grad_enabled() hidden_states = hidden_states.contiguous() - encoder_padding_mask = ((encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()) + encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous() bs, sl, dim = hidden_states.size() if bs * sl > self.config.max_batch_tokens: @@ -250,8 +323,16 @@ class MultiHeadAttention(nn.Module): else: assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) - output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight, - self.in_proj_bias, self.out_proj_weight, self.out_proj_bias, - self.norm_weight, self.norm_bias, self.config) + output = MultiHeadAttention1DFunc.apply( + hidden_states, + encoder_padding_mask, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.norm_weight, + self.norm_bias, + self.config, + ) return output.to(self.precision) diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py index 24e458bb3ea53d87fb5cca21339155dbaa35be07..26a5bce16d5c86ab4d3c4d0545544902c31c34d2 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -19,6 +19,7 @@ except ImportError: class AttnMaskType(enum.Enum): padding = 1 causal = 2 + paddedcausal = 3 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): @@ -107,15 +108,16 @@ class FusedScaleMaskSoftmax(nn.Module): super(FusedScaleMaskSoftmax, self).__init__() self.input_in_fp16 = input_in_fp16 self.input_in_bf16 = input_in_bf16 - assert not (self.input_in_fp16 - and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time." + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.attn_mask_type = attn_mask_type self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 self.scale = scale - assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled" + assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" def forward(self, input, mask): # [b, np, sq, sk] @@ -129,17 +131,18 @@ class FusedScaleMaskSoftmax(nn.Module): def is_kernel_available(self, mask, b, np, sq, sk): attn_batches = b * np - if (self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and mask is not None # mask tensor must not be None - and 16 < sk <= 2048 # sk must be 16 ~ 2048 - and sq % 4 == 0 # sq must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): if 0 <= sk <= 2048: batch_per_block = self.get_batch_per_block(sq, sk, b, np) - if self.attn_mask_type == AttnMaskType.causal: + if self.attn_mask_type.value > 1: if attn_batches % batch_per_block == 0: return True else: @@ -151,7 +154,7 @@ class FusedScaleMaskSoftmax(nn.Module): b, np, sq, sk = input.size() scale = self.scale if self.scale is not None else 1.0 - if self.attn_mask_type == AttnMaskType.causal: + if self.attn_mask_type.value > 1: assert sq == sk, "causal mask is only for self attention" # input is 3D tensor (attn_batches, sq, sk) diff --git a/colossalai/kernel/jit/__init__.py b/colossalai/kernel/jit/__init__.py index 57b8fb7b2e996ea0f0336dad1e42ea379d608b15..67a147cd581c5d89d5c7f48da9257b949be918c9 100644 --- a/colossalai/kernel/jit/__init__.py +++ b/colossalai/kernel/jit/__init__.py @@ -1,8 +1,10 @@ -from .option import set_jit_fusion_options -from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference +from .bias_dropout_add import bias_dropout_add_fused_inference, bias_dropout_add_fused_train from .bias_gelu import bias_gelu_impl +from .option import set_jit_fusion_options __all__ = [ - "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl", - "set_jit_fusion_options" + "bias_dropout_add_fused_train", + "bias_dropout_add_fused_inference", + "bias_gelu_impl", + "set_jit_fusion_options", ] diff --git a/colossalai/kernel/jit/bias_dropout_add.py b/colossalai/kernel/jit/bias_dropout_add.py index 3687dde79a08b7f8f192d6516694938828aae659..e046ee2964afb2ee1079948ec84b6ac5bf89c749 100644 --- a/colossalai/kernel/jit/bias_dropout_add.py +++ b/colossalai/kernel/jit/bias_dropout_add.py @@ -9,16 +9,14 @@ def bias_dropout_add(x, bias, residual, prob, training): @torch.jit.script -def bias_dropout_add_fused_train(x: torch.Tensor, - bias: torch.Tensor, - residual: torch.Tensor, - prob: float) -> torch.Tensor: +def bias_dropout_add_fused_train( + x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float +) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script -def bias_dropout_add_fused_inference(x: torch.Tensor, - bias: torch.Tensor, - residual: torch.Tensor, - prob: float) -> torch.Tensor: +def bias_dropout_add_fused_inference( + x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float +) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) diff --git a/colossalai/kernel/jit/bias_gelu.py b/colossalai/kernel/jit/bias_gelu.py index 33b4ac32b044f662e475386a3b8e7504b54b108f..5fa0d07015be43bd625f48e237ed540591c39756 100644 --- a/colossalai/kernel/jit/bias_gelu.py +++ b/colossalai/kernel/jit/bias_gelu.py @@ -29,7 +29,6 @@ def bias_gelu_back(g, bias, y): class GeLUFunction(torch.autograd.Function): - @staticmethod # bias is an optional argument def forward(ctx, input, bias): diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py index aa41f57678fc116ac4acef72f52763f5dadabfed..8bebad894ca4571d906609f65937f3057286aeb7 100644 --- a/colossalai/kernel/jit/option.py +++ b/colossalai/kernel/jit/option.py @@ -1,6 +1,6 @@ import torch -from colossalai.nn.layer.colossalai_layer import Embedding, Linear +from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear from colossalai.utils import get_current_device from .bias_dropout_add import bias_dropout_add_fused_train @@ -10,15 +10,14 @@ JIT_OPTIONS_SET = False def set_jit_fusion_options(): - """Set PyTorch JIT layer fusion options. - """ + """Set PyTorch JIT layer fusion options.""" # LSG: the latest pytorch and CUDA versions may not support # the following jit settings global JIT_OPTIONS_SET if JIT_OPTIONS_SET == False: # flags required to enable jit fusion kernels - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): # nvfuser torch._C._jit_set_profiling_executor(True) @@ -38,12 +37,14 @@ def set_jit_fusion_options(): JIT_OPTIONS_SET = True -def warmup_jit_fusion(batch_size: int, - hidden_size: int, - seq_length: int = 512, - vocab_size: int = 32768, - dtype: torch.dtype = torch.float32): - """ Compilie JIT functions before the main training steps """ +def warmup_jit_fusion( + batch_size: int, + hidden_size: int, + seq_length: int = 512, + vocab_size: int = 32768, + dtype: torch.dtype = torch.float32, +): + """Compile JIT functions before the main training steps""" embed = Embedding(vocab_size, hidden_size).to(get_current_device()) linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device()) diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9830691581c00cd32f537739da698761448bd9a5 --- /dev/null +++ b/colossalai/kernel/triton/__init__.py @@ -0,0 +1,31 @@ +try: + import triton + + HAS_TRITON = True + +except ImportError: + HAS_TRITON = False + print("Triton is not installed. Please install Triton to use Triton kernels.") + +# There may exist import error even if we have triton installed. +if HAS_TRITON: + from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd + from .copy_kv_cache_dest import copy_kv_cache_to_dest + from .fused_layernorm import layer_norm + from .gptq_triton import gptq_fused_linear_triton + from .rms_norm import rmsnorm_forward + from .rotary_embedding_kernel import rotary_embedding_fwd + from .softmax import softmax + from .token_attention_kernel import token_attention_fwd + + __all__ = [ + "llama_context_attn_fwd", + "bloom_context_attn_fwd", + "softmax", + "layer_norm", + "rmsnorm_forward", + "copy_kv_cache_to_dest", + "rotary_embedding_fwd", + "token_attention_fwd", + "gptq_fused_linear_triton", + ] diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..01d54566483abe2468220b4fb63bafaccfee0bca --- /dev/null +++ b/colossalai/kernel/triton/context_attention.py @@ -0,0 +1,566 @@ +import math + +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + """ + this function is modified from + https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 + """ + + @triton.jit + def _context_flash_attention_kernel( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, + alibi_ptr, + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + # suggtest set-up 64, 128, 256, 512 + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + batch_id = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # get batch info + cur_batch_seq_len = tl.load(B_Seqlen + batch_id) + cur_batch_start_index = tl.load(B_Start_Loc + batch_id) + block_start_loc = BLOCK_M * start_m + + load_p_ptrs = ( + Q + + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if alibi_ptr is not None: + alibi_m = tl.load(alibi_ptr + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load( + k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if alibi_ptr is not None: + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_o = ( + (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + @torch.no_grad() + def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk, "context process only supports equal query, key, value length" + assert Lk == Lv, "context process only supports equal query, key, value length" + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / math.sqrt(Lk) + batch, head = b_seq_len.shape[0], q.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + num_warps = 4 if Lk <= 64 else 8 + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + + _context_flash_attention_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + alibi, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + # manually setting this blcok num, we can use tuning config to futher speed-up + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @torch.no_grad() + def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk, "context process only supports equal query, key, value length" + assert Lk == Lv, "context process only supports equal query, key, value length" + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / math.sqrt(Lk) + batch, head = b_seq_len.shape[0], q.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + # num_warps = 4 + _context_flash_attention_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + None, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit + def _fwd_kernel_latest( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + kv_group_num, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + @triton.jit + def _fwd_kernel_old( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + kv_group_num, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + # t_ptrs = TMP + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + + return + + @torch.no_grad() + def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel_latest[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + kv_group_num=kv_group_num, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + elif triton.__version__ == "2.0.0": + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + # num_warps = 4 + _fwd_kernel_old[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + kv_group_num=kv_group_num, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py new file mode 100644 index 0000000000000000000000000000000000000000..02edcc9a903aee4ec7e90f53dd10ab536792cc9e --- /dev/null +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -0,0 +1,71 @@ +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + + @triton.jit + def _fwd_copy_kv_cache_dest( + kv_cache_ptr, + dest_index_ptr, + out, + stride_k_bs, + stride_k_h, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_d, + head_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + ): + cur_index = tl.program_id(0) + offs_h = tl.arange(0, BLOCK_HEAD) + offs_d = tl.arange(0, BLOCK_DMODEL) + + dest_index = tl.load(dest_index_ptr + cur_index) + + cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] + k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets + + o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] + o_ptrs = out + dest_index * stride_o_bs + o_offsets + + k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0) + tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) + return + + @torch.no_grad() + def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): + seq_len = dest_index_ptr.shape[0] + head_num = k_ptr.shape[1] + head_dim = k_ptr.shape[2] + assert head_num == out.shape[1], "head_num should be the same for k_ptr and out" + assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" + + num_warps = 2 + + _fwd_copy_kv_cache_dest[(seq_len,)]( + k_ptr, + dest_index_ptr, + out, + k_ptr.stride(0), + k_ptr.stride(1), + k_ptr.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + head_num, + BLOCK_DMODEL=head_dim, + BLOCK_HEAD=triton.next_power_of_2(head_num), + num_warps=num_warps, + num_stages=2, + ) + return diff --git a/colossalai/kernel/triton/custom_autotune.py b/colossalai/kernel/triton/custom_autotune.py new file mode 100644 index 0000000000000000000000000000000000000000..17bb1cf0070cc8cd678d041b8b215d34cbb353dd --- /dev/null +++ b/colossalai/kernel/triton/custom_autotune.py @@ -0,0 +1,176 @@ +# code from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/nn_modules/triton_utils/custom_autotune.py + +import builtins +import math +import time +from typing import Dict + +import triton + + +class CustomizedTritonAutoTuner(triton.KernelInterface): + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + prune_configs_by: Dict = None, + nearest_power_of_two: bool = False, + ): + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"] + if "early_config_prune" in prune_configs_by: + early_config_prune = prune_configs_by["early_config_prune"] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." + ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) + + try: + # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses + # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default + return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40) + except triton.compiler.OutOfResources: + return (float("inf"), float("inf"), float("inf")) + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to the nearest power of two + # In my testing this gives decent results, and greatly reduces the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): + def decorator(fn): + return CustomizedTritonAutoTuner( + fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two + ) + + return decorator + + +def matmul248_kernel_config_pruner(configs, nargs): + """ + The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. + """ + m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) + n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) + k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) + + used = set() + for config in configs: + block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) + block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) + block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) + group_size_m = config.kwargs["GROUP_SIZE_M"] + + if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used: + continue + + used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps)) + yield triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + }, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) diff --git a/colossalai/kernel/triton/fused_layernorm.py b/colossalai/kernel/triton/fused_layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..24083b0508080b045819d703dd91093591c09cd9 --- /dev/null +++ b/colossalai/kernel/triton/fused_layernorm.py @@ -0,0 +1,78 @@ +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + # CREDITS: These functions are adapted from the Triton tutorial + # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + + @triton.jit + def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, + ): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.0) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y.to(tl.float16), mask=mask) + + @torch.no_grad() + def layer_norm(x, weight, bias, eps): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # enqueue kernel + _layer_norm_fwd_fused[(M,)]( + x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps + ) + return y diff --git a/colossalai/kernel/triton/gptq_triton.py b/colossalai/kernel/triton/gptq_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..8460103e261d59e31312473d2b52a2298895f085 --- /dev/null +++ b/colossalai/kernel/triton/gptq_triton.py @@ -0,0 +1,542 @@ +# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ + +import torch +import triton +import triton.language as tl + +from .custom_autotune import autotune, matmul248_kernel_config_pruner + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def cosh(x): + exp_x = tl.exp(x) + return (exp_x + 1.0 / exp_x) * 0.5 + + +# a Triton implementation of the most used activations +# See for instance http://arxiv.org/abs/1606.08415 for an overview + + +# ReLU +@triton.jit +def relu(x): + """ + ReLU_ activation function + + .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html + """ + return tl.where(x >= 0, x, 0.0) + + +@triton.jit +def squared_relu(x): + """ + Squared ReLU activation, as proposed in the Primer_ paper. + + .. _Primer: https://arxiv.org/abs/2109.08668 + """ + x_sq = x * x + return tl.where(x > 0.0, x_sq, 0.0) + + +@triton.jit +def star_relu(x): + """ + Star ReLU activation, as proposed in the "MetaFormer Baselines for Vision"_ paper. + + .. _ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf + """ + x_sq = x * x + return 0.8944 * tl.where(x > 0.0, x_sq, 0.0) - 0.4472 + + +# Leaky ReLU +@triton.jit +def leaky_relu(x): + """ + LeakyReLU_ activation + + .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html + """ + return tl.where(x >= 0.0, x, 0.01 * x) + + +@triton.jit +def gelu(x): + """ + GeLU_ activation - Gaussian error linear unit + + .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf + """ + return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x))) + + +@triton.jit +def smelu(x): + """ + SmeLU_ activation - Smooth ReLU with beta=2.0 + + .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf + """ + beta = 2.0 + + relu = tl.where(x >= beta, x, 0.0) + return tl.where(tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu) + + +@triton.jit +def silu(x): + return x * tl.sigmoid(x) + + +@autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4 + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + "early_config_prune": matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, +) +@triton.jit +def cai_gptq_matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + bias_ptr, + residual_ptr, + M, + N, + K, + bits, + maxq, + gptq_group_size, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + QKV_FUSED: tl.constexpr, + ADD_BIAS: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, + ACT_TYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + NK = K + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) + qkv_offset = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_bk = offs_k + qkv_offset * NK + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = ( + b_ptr + + qkv_offset * N * NK // infearure_per_bits + + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + # g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :] + zeros_ptrs = ( + zeros_ptr + + qkv_offset * NK * N // gptq_group_size // infearure_per_bits + + (offs_bn[None, :] // infearure_per_bits) + ) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + g_idx_base = tl.arange(0, BLOCK_SIZE_K) + g_idx_base = g_idx_base // gptq_group_size + g_idx = g_idx_base + # tl.device_print("gidx, ", g_idx) + + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = zeros + 1 + + for k in range(0, num_pid_k): + # g_idx = tl.load(g_ptrs) + # if (k + 1) * BLOCK_SIZE_K > currend_group_end: + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = zeros + 1 + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros).to(tl.float16) * scales # Scale and shift + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size + # if (k + 2) * BLOCK_SIZE_K > currend_group_end: + + c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + + if ADD_BIAS: + bias_mask = offs_bn < N + offs_bn += qkv_offset * N + bias_ptrs = bias_ptr + stride_cn * offs_bn + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + accumulator += bias[None, :] + + if ACT_TYPE == 1: + accumulator = relu(accumulator) + elif ACT_TYPE == 2: + accumulator = gelu(accumulator) + elif ACT_TYPE == 3: + accumulator = silu(accumulator) + + if ADD_RESIDUAL: + residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + res = tl.load(residual_ptrs, mask=c_mask, other=0.0) + accumulator += res + + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8 + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4 + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + "early_config_prune": matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, +) +@triton.jit +def cai_gptq_idx_matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + idx_ptr, + bias_ptr, + residual_ptr, + M, + N, + K, + bits, + maxq, + gptq_group_size, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + QKV_FUSED: tl.constexpr, + ADD_BIAS: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, + ACT_TYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + NK = K + + # if QKV_FUSED: + # NK = K//3 + # else: + # NK = K + # NK = K + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) + qkv_offset = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_bk = offs_k + qkv_offset * NK + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = ( + b_ptr + + qkv_offset * N * NK // infearure_per_bits + + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + # g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :] + zeros_ptrs = ( + zeros_ptr + + qkv_offset * NK * N // gptq_group_size // infearure_per_bits + + (offs_bn[None, :] // infearure_per_bits) + ) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + g_ptrs = idx_ptr + offs_k + g_idx = tl.load(g_ptrs) + # tl.device_print("gidx, ", g_idx) + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = zeros + 1 + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros).to(tl.float16) * scales # Scale and shift + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + + if ADD_BIAS: + bias_mask = offs_bn < N + offs_bn += qkv_offset * N + bias_ptrs = bias_ptr + stride_cn * offs_bn + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + accumulator += bias[None, :] + + if ACT_TYPE == 1: + accumulator = relu(accumulator) + elif ACT_TYPE == 2: + accumulator = gelu(accumulator) + elif ACT_TYPE == 3: + accumulator = silu(accumulator) + + if ADD_RESIDUAL: + residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + res = tl.load(residual_ptrs, mask=c_mask, other=0.0) + accumulator += res + + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def gptq_fused_linear_triton( + input, + qweight, + scales, + qzeros, + bias, + residual, + bits, + maxq, + gptq_group_size, + qkv_fused, + add_bias, + add_residual, + g_idx=None, + act_type=0, +): + # print("gptq fused ", qkv_fused, add_bias, add_residual) + assert input.is_cuda, "input is not in cuda" + assert qweight.is_cuda, "qweight is not in cuda" + assert scales.is_cuda, "scales is not in cuda" + assert qzeros.is_cuda, "qzeros is not in cuda" + + with torch.cuda.device(input.device): + if qkv_fused: + grid = lambda META: ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]) + * 3, + ) + output = torch.empty((input.shape[0] * 3, qweight.shape[1]), device=input.device, dtype=torch.float16) + else: + grid = lambda META: ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) + # print("dtype, ", qweight.dtype, output.dtype, scales.dtype, qzeros.dtype, bias.dtype, residual.dtype) + if g_idx is None: + cai_gptq_matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + bias, + residual, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + gptq_group_size, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + QKV_FUSED=qkv_fused, + ADD_BIAS=add_bias, + ADD_RESIDUAL=add_residual, + ACT_TYPE=act_type, + ) + else: + cai_gptq_idx_matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + g_idx, + bias, + residual, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + gptq_group_size, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + QKV_FUSED=qkv_fused, + ADD_BIAS=add_bias, + ADD_RESIDUAL=add_residual, + ACT_TYPE=act_type, + ) + if qkv_fused: + return output.view(3, input.shape[0], qweight.shape[1]) + else: + return output diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..7b5cd2923f0e9b3fa781d47dfa0ade49e0f58563 --- /dev/null +++ b/colossalai/kernel/triton/qkv_matmul_kernel.py @@ -0,0 +1,115 @@ +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + + +if HAS_TRITON: + """ + this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + """ + + @triton.jit + def qkv_gemm_4d_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_ab, + stride_ah, + stride_am, + stride_ak, + stride_bb, + stride_bh, + stride_bk, + stride_bn, + stride_cb, + stride_ch, + stride_cm, + stride_cn, + scale, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr = 64, + BLOCK_SIZE_N: tl.constexpr = 32, + BLOCK_SIZE_K: tl.constexpr = 32, + GROUP_SIZE_M: tl.constexpr = 8, + ): + r"""A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, + where score_matrix is softmax(Q*V^T/sqrt(hidden_size)) + Args: + a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K) + b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K) + c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N) + stride_ab(tl.constexpr): stride for bs-dimention for tensor array A + stride_ah(tl.constexpr): stride for h-dimention for tensor array A + stride_am(tl.constexpr): stride for m-dimention for tensor array A + stride_ak(tl.constexpr): stride for k-dimention for tensor array A + stride_bb(tl.constexpr): stride for bs-dimention for tensor array B + stride_bh(tl.constexpr): stride for h-dimention for tensor array B + stride_bk(tl.constexpr): stride for k-dimention for tensor array B + stride_bn(tl.constexpr): stride for n-dimention for tensor array B + stride_cb(tl.constexpr): stride for bs-dimention for tensor array output + stride_ch(tl.constexpr): stride for h-dimention for tensor array output + stride_cm(tl.constexpr): stride for m-dimention for tensor array output + stride_cn(tl.constexpr): stride for n-dimention for tensor array output + BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a + BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b + BLOCK_SIZE_K : tiling size for K-dimension of a and b + GROUP_SIZE_M : group size for reducing cache miss, more details: + """ + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + batch = tl.program_id(axis=0) + head = tl.program_id(axis=1) + pid = tl.program_id(axis=2) + + # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = ( + a_ptr + batch * stride_ab + head * stride_ah + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + ) + b_ptrs = ( + b_ptr + batch * stride_bb + head * stride_bh + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) + b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + accumulator = accumulator.to(c_ptr.dtype.element_ty) + if scale > 0: + accumulator = accumulator * scale.to(c_ptr.dtype.element_ty) + + offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = ( + c_ptr + + batch * stride_cb + + head * stride_ch + + stride_cm * offs_accumu_m[:, None] + + stride_cn * offs_accumu_n[None, :] + ) + accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N) + tl.store(c_ptrs, accumulator, mask=accumulator_mask) diff --git a/colossalai/kernel/triton/rms_norm.py b/colossalai/kernel/triton/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d6f9d85df1fbd7a25e5df7b239abc5cd824bc6 --- /dev/null +++ b/colossalai/kernel/triton/rms_norm.py @@ -0,0 +1,71 @@ +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + + +if HAS_TRITON: + """ + this kernel function is modified from + https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py + """ + + @triton.jit + def _rms_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, + ): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = x * rstd + y = x_hat * w + # Write output + tl.store(Y + cols, y.to(tl.float16), mask=mask) + + def rmsnorm_forward(x, weight, eps): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.view(-1, x.shape[-1]) + M, N = x_arg.shape + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + # print("BLOCK_SIZE:", BLOCK_SIZE) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # print(BLOCK_SIZE, num_warps, "block_size, numwarps") + BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2 + num_warps = 8 + # enqueue kernel + _rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) + return y diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..fd74ba817551e4b1dc61edc78de03070e2408cb2 --- /dev/null +++ b/colossalai/kernel/triton/rotary_embedding_kernel.py @@ -0,0 +1,212 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm +import torch +import triton +import triton.language as tl + + +@triton.jit +def _rotary_kernel( + q, + Cos, + Sin, + q_bs_stride, + q_h_stride, + q_d_stride, + cos_bs_stride, + cos_d_stride, + total_len, + HEAD_NUM: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + current_head_index = tl.program_id(0) + current_seq_index = tl.program_id(1) + + current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + off_q0 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range0[None, None, :] * q_d_stride + ) + off_q1 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range1[None, None, :] * q_d_stride + ) + + off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride + + q0 = tl.load( + q + off_q0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + q1 = tl.load( + q + off_q1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + + cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + + out0 = q0 * cos - q1 * sin + out1 = q0 * sin + q1 * cos + + tl.store( + q + off_q0, + out0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + tl.store( + q + off_q1, + out1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + + return + + +@torch.no_grad() +def rotary_embedding_fwd(q, cos, sin): + total_len = q.shape[0] + head_num = q.shape[1] + head_dim = q.shape[2] + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + BLOCK_HEAD = 4 + BLOCK_SEQ = 32 + grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + _rotary_kernel[grid]( + q, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + total_len, + HEAD_NUM=head_num, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + HEAD_DIM=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return + + +class Llama2Forwards: + @staticmethod + @triton.jit + def _rotary_kernel( + Q, + Cos, + Sin, + stride_qbs, + stride_qh, + stride_qd, + stride_cosbs, + stride_cosd, + stride_sinbs, + stride_sind, + max_total_len, + H, # N_CTX + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ): + cur_head_index = tl.program_id(0) + cur_seq_index = tl.program_id(1) + + cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 + dim_range1 = dim_range0 + 1 + off_q0 = ( + cur_seq_range[:, None, None] * stride_qbs + + cur_head_range[None, :, None] * stride_qh + + dim_range0[None, None, :] * stride_qd + ) + off_q1 = ( + cur_seq_range[:, None, None] * stride_qbs + + cur_head_range[None, :, None] * stride_qh + + dim_range1[None, None, :] * stride_qd + ) + + cos_range = tl.arange(0, BLOCK_DMODEL // 2) + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd + + q0 = tl.load( + Q + off_q0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), + other=0.0, + ) + q1 = tl.load( + Q + off_q1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), + other=0.0, + ) + + cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + + out0 = q0 * cos - q1 * sin + out1 = q0 * sin + q1 * cos + + tl.store( + Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) + ) + tl.store( + Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) + ) + + return + + @staticmethod + @torch.no_grad() + def rotary_emb_fwd(q, cos, sin): + total_len = q.shape[0] + head_num = q.shape[1] + head_dim = q.shape[2] // 2 + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + BLOCK_HEAD = 4 + BLOCK_SEQ = 32 + grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + Llama2Forwards._rotary_kernel[grid]( + q, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + sin.stride(0), + sin.stride(1), + total_len, + head_num, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/colossalai/kernel/triton/self_attention_nofusion.py b/colossalai/kernel/triton/self_attention_nofusion.py new file mode 100644 index 0000000000000000000000000000000000000000..4b56c8afd67f56983a57b2bc78cf1f49680a8ff6 --- /dev/null +++ b/colossalai/kernel/triton/self_attention_nofusion.py @@ -0,0 +1,162 @@ +import torch + +try: + import triton + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + from .qkv_matmul_kernel import qkv_gemm_4d_kernel + from .softmax import softmax_kernel + + def self_attention_forward_without_fusion( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float + ): + r"""A function to do QKV Attention calculation by calling GEMM and softmax triton kernels + Args: + q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len) + scale: the float scale value which is used to multiply with Q*K^T before doing softmax + + Return: + output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size) + """ + assert len(q.shape) == 4, "the shape of q val must be 4" + batches, M, H, K = q.shape + assert q.shape == k.shape, "the shape of q and the shape of k must be equal" + assert q.shape == v.shape, "the shape of q and the shape of v must be equal" + assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal" + + N = k.shape[1] + + # head_size * num_of_head + d_model = q.shape[-1] * q.shape[-2] + + score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + q, + k, + score_output, + M, + N, + K, + q.stride(0), + q.stride(2), + q.stride(1), + q.stride(3), + k.stride(0), + k.stride(2), + k.stride(3), + k.stride(1), + score_output.stride(0), + score_output.stride(1), + score_output.stride(2), + score_output.stride(3), + scale=scale, + # currently manually setting, later on we can use auto-tune config to match best setting + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + ) + + softmax_output = torch.empty(score_output.shape, device=score_output.device, dtype=score_output.dtype) + score_output_shape = score_output.shape + + score_output = score_output.view(-1, score_output.shape[-1]) + n_rows, n_cols = score_output.shape + + if n_rows <= 350000: + block_size = max(triton.next_power_of_2(n_cols), 2) + num_warps = 4 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + softmax_kernel[(n_rows,)]( + softmax_output, + score_output, + score_output.stride(0), + n_cols, + mask_ptr=input_mask, + num_warps=num_warps, + BLOCK_SIZE=block_size, + ) + + else: + # NOTE: change softmax kernel functions to make it suitable for large size dimension + softmax_output = torch.nn.functional.softmax(score_output, dim=-1) + softmax_output = softmax_output.view(*score_output_shape) + + batches, H, M, K = softmax_output.shape + N = v.shape[-1] + + output = torch.empty((batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + softmax_output, + v, + output, + M, + N, + K, + softmax_output.stride(0), + softmax_output.stride(1), + softmax_output.stride(2), + softmax_output.stride(3), + v.stride(0), + v.stride(2), + v.stride(1), + v.stride(3), + output.stride(0), + output.stride(2), + output.stride(1), + output.stride(3), + BLOCK_SIZE_M=128, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=64, + GROUP_SIZE_M=8, + scale=-1, + ) + return output.view(batches, -1, d_model) + + def self_attention_compute_using_triton( + qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False + ): + assert qkv.is_contiguous() + assert alibi is None, "current triton self-attention does not support alibi" + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model : d_model * 2] + v = qkv[:, :, d_model * 2 :] + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + v = v.view(batches, -1, num_of_heads, head_size) + + data_output_triton = self_attention_forward_without_fusion(q, k, v, input_mask, scale) + + return data_output_triton diff --git a/colossalai/kernel/triton/softmax.py b/colossalai/kernel/triton/softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffce80a3041e651e3027f9270483cafa181ef6e --- /dev/null +++ b/colossalai/kernel/triton/softmax.py @@ -0,0 +1,99 @@ +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + """ + softmax kernel is modified based on + https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py + """ + + @triton.jit + def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr): + r"""the kernel function for implementing softmax operator + Args: + output_ptr: the output after finishing softmax operation, (N, hidden_dim) + input_ptr: the tensor of input, shape should be (N, hidden_dim) + n_cols(tl.constexpr): the number of cols of input + BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim + """ + row_idx = tl.program_id(0) + row_start_ptr = input_ptr + row_idx * row_stride + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf")).to(tl.float32) + row_minus_max = row - tl.max(row, axis=0) + + if mask_ptr is not None: + # load mask into SRAM + mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets + mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) + + # update + row_minus_max = row_minus_max + mask + + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + output_row_start_ptr = output_ptr + row_idx * row_stride + output_ptrs = output_row_start_ptr + col_offsets + # Write back output to DRAM + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) + + def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: + if mask is not None: + assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" + assert dim == -1 or dim == len(input.shape) - 1, "currently softmax layer only support last dimention" + + hidden_dim = input.shape[-1] + output = torch.empty_like(input) + input = input.view(-1, hidden_dim) + if mask is not None: + mask = mask.view(-1, hidden_dim) + assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" + + num_rows, num_cols = input.shape + block_size = max(triton.next_power_of_2(num_cols), 2) + num_warps = 16 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + if num_rows <= 350000: + grid = (num_rows,) + softmax_kernel[grid]( + output, input, input.stride(0), num_cols, mask, BLOCK_SIZE=block_size, num_warps=num_warps + ) + else: + grid = lambda meta: () + + grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_M"]),) + + if block_size >= 4096: + pass + elif block_size >= 2048: + pass + + softmax_kernel[grid]( + output_ptr=output, + input_ptr=input, + row_stride=input.stride(0), + n_rows=num_rows, + n_cols=num_cols, + mask_ptr=mask, + # currently manually setting up size + BLOCK_M=32, + BLOCK_SIZE=block_size, + ) + + return output diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..c27394f0f9cf577458f84a12f7dd370f3c99498a --- /dev/null +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -0,0 +1,841 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm + + +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + + @triton.jit + def _token_attn_1_kernel( + Q, + K, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @triton.jit + def _token_attn_1_alibi_kernel( + Q, + K, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + alibi_m = tl.load(alibi + current_head) + q = tl.load(Q + off_q + start_mark) + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @torch.no_grad() + def token_attn_fwd_1( + q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None + ): + BLOCK = 32 + # shape constraints + q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] + assert q_head_dim == k_head_dim + assert k_head_dim in {16, 32, 64, 128} + sm_scale = 1.0 / (k_head_dim**0.5) + + batch, head_num = kv_cache_loc.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + + num_warps = 4 if k_head_dim <= 64 else 8 + num_warps = 2 + + if alibi is not None: + _token_attn_1_alibi_kernel[grid]( + q, + k, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + _token_attn_1_kernel[grid]( + q, + k, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit + def _token_attn_softmax_fwd( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + logics_head_dim_stride, + logics_batch_stride, + prob_head_dim_stride, + prob_batch_stride, + BLOCK_SIZE: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + row = tl.load( + softmax_logics + + current_head * logics_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, + mask=col_offsets < current_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + tl.store( + softmax_prob_out + + current_head * prob_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, + softmax_output, + mask=col_offsets < current_batch_seq_len, + ) + return + + @torch.no_grad() + def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): + BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) + batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] + + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + _token_attn_softmax_fwd[(batch, head_num)]( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + softmax_logics.stride(0), + softmax_logics.stride(1), + softmax_prob_out.stride(0), + softmax_prob_out.stride(1), + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return + + @triton.jit + def _token_attn_2_kernel( + Prob, + V, + attn_out, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + prob_head_dim_stride, + prob_batch_stride, + v_batch_stride, + v_head_stride, + v_head_dim_stride, + attn_out_batch_stride, + attn_out_head_stride, + attn_out_head_dim_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride + p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride + v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride + + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + for start_n in range(0, current_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load( + Prob + p_offs + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_loc = tl.load( + kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * v_batch_stride, + mask=(start_n + offs_n[:, None]) < current_batch_seq_len, + other=0.0, + ) + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = acc.to(tl.float16) + off_o = ( + current_batch * attn_out_batch_stride + + current_head * attn_out_head_stride + + offs_d * attn_out_head_dim_stride + ) + out_ptrs = attn_out + off_o + tl.store(out_ptrs, acc) + return + + @torch.no_grad() + def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + else: + BLOCK = 64 + batch, head = kv_cache_loc.shape[0], v.shape[1] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + _token_attn_2_kernel[grid]( + prob, + v, + attn_out, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + attn_out.stride(0), + attn_out.stride(1), + attn_out.stride(2), + HEAD_DIM=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @torch.no_grad() + def token_attention_fwd( + q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None + ): + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + total_token_num = k.shape[0] + + att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") + + token_attn_fwd_1( + q.view(calcu_shape1), + k, + att_m_tensor, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=alibi, + ) + + prob = torch.empty_like(att_m_tensor) + + token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + att_m_tensor = None + token_attn_fwd_2( + prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch + ) + + prob = None + + return + + +class Llama2TokenAttentionForwards: + @staticmethod + @triton.jit + def _fwd_kernel( + Logics, + V, + Out, + B_Loc, + B_Start_Loc, + B_Seqlen, + max_input_len, + stride_logic_h, + stride_logic_bs, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_b_loc_b, + stride_b_loc_s, + other_kv_index, # avoid nan information + kv_group_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s + + v_ptrs = V + off_v + + e_max = float("-inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + v_index = tl.load( + B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=other_kv_index, + ) + + qk = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, + mask=start_n + offs_n < cur_batch_seq_len, + other=float("-inf"), + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + old_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + e_sum = e_sum * old_scale + tl.sum(p, 0) + v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) + acc = acc * old_scale + tl.sum(p[:, None] * v, 0) + e_max = n_e_max + + acc = acc / e_sum + off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + return + + @staticmethod + @torch.no_grad() + def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index): + BLOCK = 64 + batch, head = b_seq_len.shape[0], logics.shape[0] + grid = (batch, head) + kv_group_num = logics.shape[0] // v.shape[1] + + num_warps = 1 + Llama2TokenAttentionForwards._fwd_kernel[grid]( + logics, + v, + o, + b_loc, + b_start_loc, + b_seq_len, + max_input_len, + logics.stride(0), + logics.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + b_loc.stride(0), + b_loc.stride(1), + other_kv_index, + kv_group_num, + BLOCK_DMODEL=v.shape[-1], + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=3, + ) + return + + @staticmethod + @triton.jit + def _fwd_kernel_token_softmax( + Logics, + B_Start_Loc, + B_Seqlen, + Prob_Out, + stride_logic_h, + stride_logic_bs, + stride_prob_h, + stride_prob_bs, + BLOCK_SIZE: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + row = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, + mask=col_offsets < cur_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + tl.store( + Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, + softmax_output, + mask=col_offsets < cur_batch_seq_len, + ) + return + + @staticmethod + @torch.no_grad() + def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): + BLOCK_SIZE = triton.next_power_of_2(max_input_len) + batch, head_num = B_Start_Loc.shape[0], Logics.shape[0] + + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + Llama2TokenAttentionForwards._fwd_kernel_token_softmax[(batch, head_num)]( + Logics, + B_Start_Loc, + B_Seqlen, + Prob_Out, + Logics.stride(0), + Logics.stride(1), + Prob_Out.stride(0), + Prob_Out.stride(1), + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return + + @staticmethod + @triton.jit + def _fwd_kernel_token_att1( + Q, + K, + sm_scale, + B_Loc, + B_Start_Loc, + B_Seqlen, + max_input_len, + Att_Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + att_stride_h, + att_stride_bs, + kv_group_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_n = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + cur_batch_start_index = max_input_len - cur_batch_seq_len + cur_batch_end_index = max_input_len + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + offs_n_new = cur_batch_start_index + offs_n + k_loc = tl.load( + B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd + k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs + tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) + return + + @staticmethod + @torch.no_grad() + def token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): + BLOCK = 32 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk**0.5) + + batch, head_num = B_Loc.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK)) + kv_group_num = q.shape[1] // k.shape[1] + + num_warps = 4 if Lk <= 64 else 8 + num_warps = 2 + + Llama2TokenAttentionForwards._fwd_kernel_token_att1[grid]( + q, + k, + sm_scale, + B_Loc, + B_Start_Loc, + B_Seqlen, + max_input_len, + att_out, + B_Loc.stride(0), + B_Loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + att_out.stride(0), + att_out.stride(1), + kv_group_num=kv_group_num, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @staticmethod + @triton.jit + def _fwd_kernel_token_att2( + Prob, + V, + Out, + B_Loc, + B_Start_Loc, + B_Seqlen, + max_input_len, # B_Start_Loc cumsum of input lens if continuous + stride_b_loc_b, + stride_b_loc_s, + stride_ph, + stride_pbs, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + kv_group_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_kv_head = cur_head // kv_group_num + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_index = max_input_len - cur_batch_seq_len + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s + p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs + v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load( + Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 + ) + v_loc = tl.load( + B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = acc.to(tl.float16) + off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + return + + @staticmethod + @torch.no_grad() + def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + else: + BLOCK = 64 + batch, head = B_Loc.shape[0], prob.shape[0] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + kv_group_num = prob.shape[0] // v.shape[1] + + Llama2TokenAttentionForwards._fwd_kernel_token_att2[grid]( + prob, + v, + out, + B_Loc, + B_Start_Loc, + B_Seqlen, + max_input_len, + B_Loc.stride(0), + B_Loc.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + # this is the interface of llama2 attn forward + @staticmethod + @torch.no_grad() + def token_attn( + q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, other_kv_index + ): + total_token_num = k.shape[0] + batch_size, head_num, head_dim = q.shape + calcu_shape1 = (batch_size, head_num, head_dim) + att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") + + Llama2TokenAttentionForwards.token_att_fwd( + q, + k, + att_m_tensor, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + ) + + if triton.__version__ == "2.0.0": + prob = torch.empty_like(att_m_tensor) + Llama2TokenAttentionForwards.token_softmax_fwd( + att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch + ) + att_m_tensor = None + + Llama2TokenAttentionForwards.token_att_fwd2( + prob, + v, + attn_out.view(calcu_shape1), + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + ) + + prob = None + return + + elif triton.__version__ >= "2.1.0": + Llama2TokenAttentionForwards.token_softmax_reducev_fwd( + att_m_tensor, + v, + attn_out.view(calcu_shape1), + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + other_kv_index, + ) + else: + raise Exception("not support triton version") diff --git a/colossalai/lazy/__init__.py b/colossalai/lazy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6b813c500363b19311fb905c2570cbad6b6f51f --- /dev/null +++ b/colossalai/lazy/__init__.py @@ -0,0 +1,6 @@ +from .lazy_init import LazyInitContext, LazyTensor + +__all__ = [ + "LazyInitContext", + "LazyTensor", +] diff --git a/colossalai/lazy/construction.py b/colossalai/lazy/construction.py new file mode 100644 index 0000000000000000000000000000000000000000..6764eaf774abb43303f03eb85ebfe6cd95e869ca --- /dev/null +++ b/colossalai/lazy/construction.py @@ -0,0 +1,87 @@ +from contextlib import contextmanager +from typing import Callable, Dict, Tuple + +import torch + +__all__ = [ + "_LEGACY_TENSOR_CONSTRUCTOR", + "_NO_META_FACTORY", + "_NORMAL_FACTORY", + "ConstructorManager", +] + +# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html +_NORMAL_FACTORY = [ + "arange", + "full", + "empty", + "linspace", + "logspace", + "ones", + "rand", + "randn", + "randint", + "randperm", + "zeros", + "tensor", +] + +# factory function that does not support meta tensor backend +_NO_META_FACTORY = [ + "eye", +] + +_LEGACY_TENSOR_CONSTRUCTOR = { + "FloatTensor": torch.float, + "DoubleTensor": torch.double, + "HalfTensor": torch.half, + "BFloat16Tensor": torch.bfloat16, + "ByteTensor": torch.uint8, + "CharTensor": torch.int8, + "ShortTensor": torch.short, + "IntTensor": torch.int, + "LongTensor": torch.long, + "BoolTensor": torch.bool, +} + + +class ConstructorManager: + # function name: (new, old) + overwrites: Dict[str, Tuple[Callable, Callable]] = {} + changed: bool = False + + @staticmethod + def apply(overwrites: Dict[Callable, Callable]): + ConstructorManager.overwrites.clear() + ConstructorManager.overwrites.update(overwrites) + ConstructorManager.redo() + + @staticmethod + def undo(): + assert ConstructorManager.changed, "No constructor change to undo" + for name, (new, old) in ConstructorManager.overwrites.items(): + setattr(torch, name, old) + ConstructorManager.changed = False + + @staticmethod + def redo(): + assert not ConstructorManager.changed, "Constructor already changed" + for name, (new, old) in ConstructorManager.overwrites.items(): + setattr(torch, name, new) + ConstructorManager.changed = True + + @staticmethod + @contextmanager + def disable(): + enabled = ConstructorManager.changed + if enabled: + ConstructorManager.undo() + yield + if enabled: + ConstructorManager.redo() + + @staticmethod + def clear(): + if ConstructorManager.changed: + ConstructorManager.undo() + ConstructorManager.overwrites.clear() diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py new file mode 100644 index 0000000000000000000000000000000000000000..b130111ba3d9e92b1cb801f738f00071443fe0a1 --- /dev/null +++ b/colossalai/lazy/lazy_init.py @@ -0,0 +1,664 @@ +from types import MethodType +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +from packaging import version +from torch import Tensor +from torch.nn import Parameter +from torch.utils._pytree import tree_map + +from colossalai.logging import get_dist_logger + +from .construction import ConstructorManager +from .pretrained import PretrainedManager + +import colossalai._analyzer._subclasses._meta_registration # noqa + +# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html +_NORMAL_FACTORY = [ + "arange", + "full", + "empty", + "linspace", + "logspace", + "ones", + "rand", + "randn", + "randint", + "randperm", + "zeros", + "tensor", +] + +# factory function that does not support meta tensor backend +_NO_META_FACTORY = [ + "eye", +] + +_EARLY_MATERIALIZED_OPS = ["__getitem__", "split"] + +# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset) +# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block. +# These ops cannot be unwrapped using .data +_CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"] + +# These ops is not related to tensor value and should not be rerun +_NO_RERUN_OPS = ["__get__", "numel", "size", "dim"] + +_LEGACY_TENSOR_CONSTRUCTOR = { + "FloatTensor": torch.float, + "DoubleTensor": torch.double, + "HalfTensor": torch.half, + "BFloat16Tensor": torch.bfloat16, + "ByteTensor": torch.uint8, + "CharTensor": torch.int8, + "ShortTensor": torch.short, + "IntTensor": torch.int, + "LongTensor": torch.long, + "BoolTensor": torch.bool, +} + +# These ops have at least one lazy tensor argument and maybe a scalar argument +# scalar value should be converted to meta tensor +# this is a hack for torch 2.0 +_EXPAND_SCALAR_OPS = [ + "where", + "clamp", + "clamp_min", + "clamp_max", + "clamp_", + "clamp_min_", + "clamp_max_", +] +_old_tensor_factory = torch.tensor + +_EMPTY_DATA = torch.empty(0) + + +class _MyTensor(Tensor): + """This class is only for correctness verification.""" + + _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None + + default_device: Optional[torch.device] = None + + def __new__(cls, func, *args, concrete_data=None, **kwargs) -> "_MyTensor": + cls._pre_op_fn() + if concrete_data is not None: + # uniform api as LazyTensor + data = concrete_data + else: + kwargs["device"] = cls.default_device + data = func(*args, **kwargs) + return Tensor._make_subclass(cls, data, require_grad=data.requires_grad) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + cls._pre_op_fn() + return super().__torch_function__(func, types, args, kwargs) + + +def _data_tolist(tensor: torch.Tensor) -> list: + """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.""" + return tensor.data.tolist() + + +def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor: + """Convert a lazy tensor's class to target's class, with target's data. + + The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models. + If we create a new tensor and update the module by ``setattr(module, name, param)``, the shared parameters will not be updated. And we have to track all shared parameters and update them manually. + + Args: + tensor (LazyTensor): the LazyTensor to be converted + target (torch.Tensor): target tensor + + Returns: + torch.Tensor: the converted tensor + """ + cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor + tensor.__class__ = cls_to_become + if cls_to_become is Parameter: + # to fit UninitializedParameter + delattr(tensor, "_is_param") + tensor.data = target + tensor.requires_grad = target.requires_grad + # subclass of torch.Tensor does not have tolist() method + # overwrite this method after materialization or distribution + tensor.tolist = MethodType(_data_tolist, tensor) + return tensor + + +class LazyTensor(torch.Tensor): + """A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf). + + Usage: + 1. Use ``LazyTensor`` instead of ``torch.Tensor``. + >>> x = LazyTensor(torch.zeros, 2, 3) + >>> x += 1 + >>> y = x * x + >>> y = y.cuda().half() + >>> y[0, 0] = 0 + >>> y = y.materialize() # materialize the tensor + >>> print(y) + tensor([[0., 1., 1.], + [1., 1., 1.]], device='cuda:0', dtype=torch.float16) + + Warnings: + 1. Cases that ``LazyTensor`` can't deal with. + >>> x = LazyTensor(torch.ones, 2, 3) + >>> x[0, 0] = -x[0, 0] # this will cause infinite recursion + >>> y = x.clone() + >>> x.add_(1) # modifying origin tensor after cloning leads to wrong materialization + >>> z = x.tolist() + >>> x.zeros_() # modifying origin tensor after cloning tolist is not allowed + >>> nn.utils.weight_norm(self.conv, name="weight", dim=2) # applying weight norm on a lazy tensor is not allowed + + + 2. Cases that ``LazyTensor`` becomes eager (early materialization). + >>> b = a[:, 2:] # get a slice of a lazy tensor triggers early materialization + >>> chunks = a.split(3) # this also triggers early materialization + >>> x.data = torch.rand(2, 3) # directly setting data of a lazy tensor triggers early materialization + + """ + + _repr = True + _meta_data: Optional[torch.Tensor] = None # shape, dtype, device + _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None + + default_device: Optional[torch.device] = None + _device: torch.device # fake device of mate tensor + + @staticmethod + def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): + # tips for torch 2.0: + # torch 2.0 disables torch dispatch for subclass of tensor + # MetaTensor is cannot be used + # Now lazy tensor contains device injection and meta tensor + if concrete_data is not None: + # some ops don't support meta backend and should have concrete data + elem = concrete_data + else: + if meta_data is None: + with ConstructorManager.disable(): + # to disable create lazy tensor in inner ops, this is a hack for torch 2.0 + meta_data = func(*args, **{**kwargs, "device": "meta"}) + elem = meta_data + # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here + r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad) + r._meta_data = meta_data + + return r + + def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): + self._device = torch.device(kwargs.get("device", None) or "cpu") + if func.__name__ in _NORMAL_FACTORY: + kwargs = {**kwargs, "device": LazyTensor.default_device} + self._factory_method = (func, args, kwargs) # (func, args, kwargs) + self._op_buffer = [] # (func, args, kwargs, replace) + self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data + + @property + def device(self) -> torch.device: + return self._materialized_data.device if self._materialized_data is not None else self._device + + def __repr__(self): + return f"LazyTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" + + def materialize(self) -> torch.Tensor: + """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace). + + Returns: + torch.Tensor: The materialized tensor (self). + """ + target = self._materialize_data() + self.clean() + return _convert_cls(self, target) + + def clean(self) -> None: + """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.""" + delattr(self, "_factory_method") + delattr(self, "_op_buffer") + delattr(self, "_materialized_data") + delattr(self, "_meta_data") + + @staticmethod + def _replace_with_materialized(x): + if isinstance(x, LazyTensor): + return x._materialize_data() + return x + + def _materialize_data(self) -> torch.Tensor: + # self._materialized_data should be generated after the first call of this function + if self._materialized_data is None: + # apply factory method + func, args, kwargs = self._factory_method + # apply cached sequence + self._pre_op_fn() + + init_val = func( + *tree_map(self._replace_with_materialized, args), **tree_map(self._replace_with_materialized, kwargs) + ) + + self._materialized_data = self._rerun_ops(init_val) + return self._materialized_data + + def _rerun_ops(self, target=None) -> torch.Tensor: + """Do lazy execution by rerunning all (stored) related operations. + + Args: + target (torc.Tensor, optional): Intial value of the target tensor (self). Defaults to None. + """ + + def replace(x): + if x is self: + return target + elif isinstance(x, LazyTensor): + return x._materialize_data() + return x + + packed = None + + for func, args, kwargs in self._op_buffer: + if func == torch.Tensor.requires_grad_: + packed = func, args, kwargs # requires grad should be set at last + else: + self._pre_op_fn() + o = func(*tree_map(replace, args), **tree_map(replace, kwargs)) + target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value + + # super-dainiu: set requires_grad after all inplace-ops are done + if packed is not None: + func, args, kwargs = packed + func(*tree_map(replace, args), **tree_map(replace, kwargs)) + + return target + + # cache everything with __torch_function__ + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func.__name__ in _EARLY_MATERIALIZED_OPS: + # These OPs cannot be lazy and related tensors should be early materialized + tree_map(cls._replace_with_materialized, args) + tree_map(cls._replace_with_materialized, kwargs) + is_inplace: bool = ( + func.__name__.endswith("_") + and not (func.__name__.endswith("__")) + or func.__name__ in ("__setitem__", "__set__") + ) + + is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS + + if isinstance(func, torch._C.ScriptMethod): + # FIXME(ver217): torch script functions are not verified + + target = None + + def unwrap(x): + if isinstance(x, LazyTensor): + return x._meta_data + return x + + target: LazyTensor = args[0].clone() + target._op_buffer.append((func, args, kwargs)) + target._meta_data = getattr(target._meta_data, func.name)( + *tree_map(unwrap, args[1:]), **tree_map(unwrap, kwargs) + ) + return target + else: + meta_to_lazy = {} + + def unwrap(x): + if isinstance(x, LazyTensor): + if x._materialized_data is not None: + # for early materialized tensor, use its materialized data directly + return x._materialized_data if is_change_meta_op else x._materialized_data.data + t = x if is_inplace else x.clone() + if func.__name__ not in _NO_RERUN_OPS: + t._op_buffer.append((func, args, kwargs)) + meta = x._meta_data if is_change_meta_op else x._meta_data.data + meta_to_lazy[meta] = t + return meta + elif ( + version.parse(torch.__version__) >= version.parse("2.0.0") + and func.__name__ in _EXPAND_SCALAR_OPS + and not isinstance(x, torch.Tensor) + ): + return _old_tensor_factory(x, device="meta") + return x + + def wrap(y, i=None): + if isinstance(y, torch.Tensor): + if y.is_meta: + if y in meta_to_lazy: + # inplace op, just return origin lazy tensor + return meta_to_lazy[y] + else: + # out of place op, create new lazy tensor + fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i] + fn.__name__ = func.__name__ + lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs) + return lazy_y + else: + # for early materialized tensor + return LazyTensor(lambda: None, concrete_data=y) + return y + + cls._pre_op_fn() + with ConstructorManager.disable(): + # to disable create lazy tensor in inner ops, this is a hack for torch 2.0 + o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + if isinstance(o, (tuple, list)): + return type(o)(wrap(y, i=i) for i, y in enumerate(o)) + return wrap(o) + + def to(self, *args, **kwargs) -> torch.Tensor: + if self._materialized_data is not None: + return LazyTensor(lambda: None, concrete_data=self._materialized_data.to(*args, **kwargs)) + + device = None + + def replace(x): + nonlocal device + if isinstance(x, (str, int, torch.device)) and not isinstance(x, bool): + device = x + return torch.device("meta") + return x + + meta_data = self._meta_data.to(*tree_map(replace, args), **tree_map(replace, kwargs)) + + if meta_data is self._meta_data and device == self.device: + return self + + def factory_fn(t: torch.Tensor, **kw): + return t.to(*args, **kwargs) + + return LazyTensor(factory_fn, self, meta_data=meta_data, device=device) + + def cpu(self, memory_format: torch.memory_format = torch.preserve_format): + return self.to(device=torch.device("cpu"), memory_format=memory_format) + + def cuda(self, device=None, non_blocking=False, memory_format: torch.memory_format = torch.preserve_format): + device = torch.device(device or "cuda") + return self.to(device=device, non_blocking=non_blocking, memory_format=memory_format) + + def clone(self) -> "LazyTensor": + def factory_fn(t: torch.Tensor, **kw): + # if self is materialized, return self + return t.clone() + + target = LazyTensor(factory_fn, self, meta_data=self._meta_data) + + return target + + def detach(self) -> Tensor: + return self + + def __deepcopy__(self, memo): + if not self.is_leaf: + raise RuntimeError( + "Only Tensors created explicitly by the user " + "(graph leaves) support the deepcopy protocol at the moment" + ) + if id(self) in memo: + return memo[id(self)] + + def factory_fn(t: torch.Tensor, **kw): + # if self is materialized, return self + return _copy_tensor(t, t.requires_grad) + + if self._materialized_data is not None: + # self is early materialized + copied = _copy_tensor(self._materialized_data, self.requires_grad) + target = LazyTensor(lambda: None, concrete_data=copied) + else: + target = LazyTensor(factory_fn, self, meta_data=self._meta_data) + + if isinstance(self, Parameter): + # hack isinstance check of parameter + target._is_param = True + + memo[id(self)] = target + return target + + @property + def data(self): + return self + + @data.setter + def data(self, other: "LazyTensor"): + """This is sightly different from oringinal `data` setter. + + E.g.: + >>> a = torch.randn(3, 3) # a is a Tensor + >>> b = torch.rand(2, 2) + >>> a.data = b + >>> b.add_(1) # this will affect a + >>> x = torch.randn(3, 3) # x is a LazyTensor + >>> y = torch.rand(2, 2) # y is a LazyTensor + >>> x.data = y + >>> y.add_(1) # this will not affect x + + """ + if other is self: + return + + def replace(x): + if x is other: + return self + return x + + for func, args, kwargs in [other._factory_method, *other._op_buffer]: + self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs))) + + def tolist(self) -> list: + # Though self.__class__ is modified to torch.Tensor, in C++ side, it is still a subclass of torch.Tensor + # And subclass of torch.Tensor does not have tolist() method + t = self._materialize_data() + return t.tolist() + + def __hash__(self): + return id(self) + + def __rpow__(self, other): + dtype = torch.result_type(self, other) + return torch.tensor(other, dtype=dtype, device=self.device) ** self + + +class LazyInitContext: + """Context manager for lazy initialization. Enables initializing the model without allocating real memory. + + Args: + tensor_cls (Union[_MyTensor, LazyTensor], optional): This is only for test. Defaults to LazyTensor. + default_device (Optional[Union[torch.device, str, int]], optional): Defalt device for initialization. + If it's cuda, initilization will be accelerated, but cuda memory will be allocated. By default, it's cpu. + Defaults to None. + """ + + _replaced: bool = False + + def __init__( + self, + tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor, + default_device: Optional[Union[torch.device, str, int]] = None, + ): + assert tensor_cls is LazyTensor or tensor_cls is _MyTensor + self.tensor_cls = tensor_cls + self.old_default_device = LazyTensor.default_device + self.default_device = default_device + + def __enter__(self): + if LazyInitContext._replaced: + raise RuntimeError(f"LazyInitContext is not reentrant") + LazyInitContext._replaced = True + self.old_default_device = self.tensor_cls.default_device + self.tensor_cls.default_device = self.default_device + + def wrap_factory_method(target): + # factory functions (eg. torch.empty()) + def wrapper(*args, **kwargs): + return self.tensor_cls(target, *args, **kwargs) + + return wrapper, target + + def wrap_factory_like_method(orig_target, target): + # factory_like functions (eg. torch.empty_like()) + def wrapper(*args, **kwargs): + orig_t = args[0] + return self.tensor_cls( + orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs + ) + + return wrapper, target + + def wrap_legacy_constructor(target, dtype): + # legacy constructor (e.g. torch.LongTensor()) + def wrapper(*args, **kwargs): + if len(args) == 1 and isinstance(args[0], torch.Tensor): + # (Tensor other) + return args[0] + elif len(args) == 1: + # (object data, *, torch.device device) + kwargs = {**kwargs, "dtype": dtype} + replaced, orig = self.overrides["tensor"] + return replaced(*args, **kwargs) + elif _is_int_tuple(args): + # (tuple of ints size, *, torch.device device) + kwargs = {**kwargs, "dtype": dtype} + replaced, orig = self.overrides["empty"] + return replaced(*args, **kwargs) + else: + raise TypeError( + f"new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)" + ) + + return wrapper, target + + def wrap_no_meta_factory(target): + # factory functions which don't support meta tensor backend + def wrapper(*args, **kwargs): + tensor = target(*args, **kwargs) + return self.tensor_cls(lambda: None, concrete_data=tensor) + + return wrapper, target + + overrides = { + target: wrap_factory_method(getattr(torch, target)) + for target in _NORMAL_FACTORY + if callable(getattr(torch, target, None)) + } + + overrides.update( + { + target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like")) + for target in _NORMAL_FACTORY + if callable(getattr(torch, target + "_like", None)) + } + ) + + overrides.update( + { + target: wrap_legacy_constructor(getattr(torch, target), dtype) + for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() + if callable(getattr(torch, target, None)) + } + ) + + overrides.update( + { + target: wrap_no_meta_factory(getattr(torch, target)) + for target in _NO_META_FACTORY + if callable(getattr(torch, target, None)) + } + ) + + ConstructorManager.apply(overrides) + PretrainedManager.inject() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.tensor_cls.default_device = self.old_default_device + LazyInitContext._replaced = False + ConstructorManager.clear() + PretrainedManager.recover() + + @staticmethod + def materialize(module: nn.Module, verbose: bool = False) -> nn.Module: + """Initialize all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place. + + Args: + module (nn.Module): Target ``nn.Module`` + verbose (bool): Whether to print lazy initialization rate. Defaults to False. + """ + + def apply_fn(name: str, p: LazyTensor): + p.materialize() + + return _apply_to_lazy_module(module, apply_fn, verbose) + + +def _apply_to_lazy_module( + module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False +) -> nn.Module: + if verbose: + # verbose info + param_cnt = 0 + param_lazy_cnt = 0 + buf_cnt = 0 + buf_lazy_cnt = 0 + total_numel = 0 + non_lazy_numel = 0 + + for name, p in module.named_parameters(): + if verbose: + param_cnt += 1 + total_numel += p.numel() + if getattr(p, "_materialized_data", False) is None: + # if no _materialized_data attr, the tensor is not lazy + param_lazy_cnt += 1 + else: + non_lazy_numel += p.numel() + if isinstance(p, LazyTensor): + apply_fn(name, p) + + for name, buf in module.named_buffers(): + if verbose: + buf_cnt += 1 + total_numel += buf.numel() + if getattr(buf, "_materialized_data", False) is None: + # if no _materialized_data attr, the tensor is not lazy + buf_lazy_cnt += 1 + else: + non_lazy_numel += buf.numel() + if isinstance(buf, LazyTensor): + apply_fn(name, buf) + + if verbose: + non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0 + logger = get_dist_logger() + logger.info(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}", ranks=[0]) + logger.info(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}", ranks=[0]) + logger.info( + f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%", + ranks=[0], + ) + + return module + + +def _is_int_tuple(args) -> bool: + if not isinstance(args, tuple): + return False + for x in args: + if not isinstance(x, int): + return False + return True + + +def _copy_tensor(tensor: Tensor, requires_grad: bool) -> Tensor: + copied = tensor.data.clone() + copied.requires_grad = requires_grad + return copied diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..21d44d4244d37986e268e5897161ec33f29aa132 --- /dev/null +++ b/colossalai/lazy/pretrained.py @@ -0,0 +1,309 @@ +import os +from typing import Callable, Optional, Union + +import torch +from torch.nn import Module + +from colossalai.interface import pretrained as pretrained_interface + + +class PretrainedManager: + old_from_pretrained: Optional[Callable] = None + + @staticmethod + def inject() -> None: + try: + from transformers.modeling_utils import PreTrainedModel + except ImportError: + return + # recover bound method to plain function + PretrainedManager.old_from_pretrained = PreTrainedModel.from_pretrained.__func__ + PreTrainedModel.from_pretrained = new_from_pretrained + + @staticmethod + def recover() -> None: + try: + from transformers.modeling_utils import PreTrainedModel + except ImportError: + return + # convert plain function to class method + PreTrainedModel.from_pretrained = classmethod(PretrainedManager.old_from_pretrained) + PretrainedManager.old_from_pretrained = None + + +@classmethod +def new_from_pretrained( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs +) -> Module: + from transformers import GenerationConfig + from transformers.configuration_utils import PretrainedConfig + from transformers.modeling_utils import ( + ContextManagers, + _add_variant, + cached_file, + download_url, + has_file, + is_offline_mode, + is_remote_url, + no_init_weights, + ) + from transformers.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + is_safetensors_available, + logging, + ) + + logger = logging.get_logger(__name__) + + config = kwargs.pop("config", None) + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + _ = kwargs.pop("mirror", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + _fast_init = kwargs.pop("_fast_init", True) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) + + if len(kwargs) > 0: + logger.warning(f"Below kwargs may be ignored: {list(kwargs.keys())}") + + from_pt = True + + user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + else: + model_kwargs = kwargs + + if commit_hash is None: + commit_hash = getattr(config, "_commit_hash", None) + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. + + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) + ) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant) + ) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) + ) + else: + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path}." + ) + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + # set correct filename + if use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + pass + elif use_safetensors: + raise EnvironmentError( + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + pass + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "use_auth_token": use_auth_token, + } + if variant is not None and has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}" + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}." + ) + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: + resolved_archive_file = None + + if from_pt: + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + dtype_orig = None + + if torch_dtype is not None: + if not isinstance(torch_dtype, torch.dtype): + raise ValueError(f"`torch_dtype` can be either `torch.dtype` or `None`, but received {torch_dtype}") + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + + config.name_or_path = pretrained_model_name_or_path + + # Instantiate model. + init_contexts = [no_init_weights(_enable=_fast_init)] + + with ContextManagers(init_contexts): + model = cls(config, *model_args, **model_kwargs) + + if from_pt: + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except (OSError, TypeError): + logger.info("Generation config file not found, using a generation config created from the model config.") + + # set pretrained path + if resolved_archive_file: + pretrained_interface.set_pretrained_path(model, resolved_archive_file) + + return model diff --git a/colossalai/legacy/__init__.py b/colossalai/legacy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..678a5def5c68e39bbb506d5f32ca997bcd456b66 --- /dev/null +++ b/colossalai/legacy/__init__.py @@ -0,0 +1,17 @@ +from .initialize import ( + get_default_parser, + initialize, + launch, + launch_from_openmpi, + launch_from_slurm, + launch_from_torch, +) + +__all__ = [ + "launch", + "launch_from_openmpi", + "launch_from_slurm", + "launch_from_torch", + "initialize", + "get_default_parser", +] diff --git a/colossalai/legacy/amp/__init__.py b/colossalai/legacy/amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d17d88b4c795d6949963ca398f918282a83b68f --- /dev/null +++ b/colossalai/legacy/amp/__init__.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer + +from colossalai.context import Config + +from .amp_type import AMP_TYPE +from .apex_amp import convert_to_apex_amp +from .naive_amp import convert_to_naive_amp +from .torch_amp import convert_to_torch_amp + +__all__ = ["convert_to_amp", "convert_to_naive_amp", "convert_to_apex_amp", "convert_to_torch_amp", "AMP_TYPE"] + + +def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None): + """A helper function to wrap training components with Torch AMP modules. + + Args: + param model (:class:`torch.nn.Module`): your model object. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object. + criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object. + mode (:class:`colossalai.legacy.amp.AMP_TYPE`): amp mode. + amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes. + + Returns: + A tuple (model, optimizer, criterion). + + Note: + ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode + for more details about ``amp_config``. + For ``apex_amp``, please check + `apex_amp config `_. + For ``naive_amp``, please check + `naive_amp config `_. + For ``torch_amp``, please check + `torch_amp config `_. + """ + assert isinstance(mode, AMP_TYPE), f"expected the argument mode be AMP_TYPE, but got {type(mode)}" + + if amp_config is None: + amp_config = Config() + + if mode == AMP_TYPE.TORCH: + model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config) + elif mode == AMP_TYPE.APEX: + model, optimizer = convert_to_apex_amp(model, optimizer, amp_config) + elif mode == AMP_TYPE.NAIVE: + model, optimizer = convert_to_naive_amp(model, optimizer, amp_config) + + return model, optimizer, criterion diff --git a/colossalai/legacy/amp/amp_type.py b/colossalai/legacy/amp/amp_type.py new file mode 100644 index 0000000000000000000000000000000000000000..5ad5faf08b71e55d438b57401457fa1c7c4e9d82 --- /dev/null +++ b/colossalai/legacy/amp/amp_type.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from enum import Enum + + +class AMP_TYPE(Enum): + APEX = "apex" + TORCH = "torch" + NAIVE = "naive" diff --git a/colossalai/legacy/amp/apex_amp/__init__.py b/colossalai/legacy/amp/apex_amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..680c6e45ca9df92c8e43bacf2aff2ff02dd7d95e --- /dev/null +++ b/colossalai/legacy/amp/apex_amp/__init__.py @@ -0,0 +1,43 @@ +import torch.nn as nn +from torch.optim import Optimizer + +from .apex_amp import ApexAMPOptimizer + + +def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config): + r"""A helper function to wrap training components with Apex AMP modules + + Args: + model (:class:`torch.nn.Module`): your model object. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object. + amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for initializing apex_amp. + + Returns: + Tuple: A tuple (model, optimizer). + + The ``amp_config`` should include parameters below: + :: + + enabled (bool, optional, default=True) + opt_level (str, optional, default="O1") + cast_model_type (``torch.dtype``, optional, default=None) + patch_torch_functions (bool, optional, default=None) + keep_batchnorm_fp32 (bool or str, optional, default=None + master_weights (bool, optional, default=None) + loss_scale (float or str, optional, default=None) + cast_model_outputs (torch.dtype, optional, default=None) + num_losses (int, optional, default=1) + verbosity (int, default=1) + min_loss_scale (float, default=None) + max_loss_scale (float, default=2.**24) + + More details about ``amp_config`` refer to `amp_config `_. + """ + import apex.amp as apex_amp + + model, optimizer = apex_amp.initialize(model, optimizer, **amp_config) + optimizer = ApexAMPOptimizer(optimizer) + return model, optimizer + + +__all__ = ["convert_to_apex_amp", "ApexAMPOptimizer"] diff --git a/colossalai/amp/apex_amp/apex_amp.py b/colossalai/legacy/amp/apex_amp/apex_amp.py similarity index 76% rename from colossalai/amp/apex_amp/apex_amp.py rename to colossalai/legacy/amp/apex_amp/apex_amp.py index e6bdbe4520f92450e80e930c0a7c746881e10bba..048c51891b176a923a7ac7aaf0abf4e46f982e66 100644 --- a/colossalai/amp/apex_amp/apex_amp.py +++ b/colossalai/legacy/amp/apex_amp/apex_amp.py @@ -10,12 +10,12 @@ except ImportError: from torch import Tensor -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.utils import clip_grad_norm_fp32 +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.utils import clip_grad_norm_fp32 -class ApexAMPOptimizer(ColossalaiOptimizer): - """ A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm +class ApexAMPOptimizer(OptimizerWrapper): + """A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm methods """ diff --git a/colossalai/legacy/amp/naive_amp/__init__.py b/colossalai/legacy/amp/naive_amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36e402299147f7a4480aaaa8daef6df4dec1e671 --- /dev/null +++ b/colossalai/legacy/amp/naive_amp/__init__.py @@ -0,0 +1,60 @@ +import inspect + +import torch.nn as nn +from torch.optim import Optimizer + +from colossalai.amp.naive_amp.grad_scaler import ConstantGradScaler, DynamicGradScaler +from colossalai.legacy.utils import is_no_pp_or_last_stage + +from ._fp16_optimizer import FP16Optimizer +from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer + + +def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): + """A helper function to wrap training components with naive AMP modules. In this mode, + we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss, + which is equivalent to Apex O3. + + Args: + model (:class:`torch.nn.Module`): your model object + optimizer (:class:`torch.optim.Optimizer`): your optimizer object + amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp. + + Returns: + Tuple: A tuple (model, optimizer) + + The ``amp_config`` should contain parameters below:: + + verbose (bool, optional): if set to `True`, will print debug info (Default: False). + clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0). + Note that clipping is ignored if clip_grad == 0. + dynamic_grad_scale (bool): whether to use dynamic grad scaler. + """ + if isinstance(model, nn.ModuleList): + # interleaved pipeline + module_list = [] + for chunk, m in enumerate(model): + output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1 + module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32)) + model = nn.ModuleList(module_list) + else: + output_to_fp32 = is_no_pp_or_last_stage() + model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) + + use_dynamic_grad_scaler = amp_config.pop("dynamic_grad_scale", True) + if use_dynamic_grad_scaler: + scaler_class = DynamicGradScaler + else: + scaler_class = ConstantGradScaler + + sig = inspect.signature(scaler_class.__init__) + kwargs = dict() + for param in sig.parameters.values(): + if param.name in amp_config: + kwargs[param.name] = amp_config.pop(param.name) + grad_scaler = scaler_class(**kwargs) + optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config) + return model, optimizer + + +__all__ = ["convert_to_naive_amp", "NaiveAMPOptimizer", "FP16Optimizer"] diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py similarity index 84% rename from colossalai/amp/naive_amp/_fp16_optimizer.py rename to colossalai/legacy/amp/naive_amp/_fp16_optimizer.py index e4699f92b9444005086a2a625ed243f0fa49ec44..97ec57fbd007223f4c7803ad98abbd62b208bebf 100644 --- a/colossalai/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py @@ -6,21 +6,22 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from torch.optim import Optimizer -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes from colossalai.logging import get_dist_logger -from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier +from colossalai.utils import multi_tensor_applier from ._utils import has_inf_or_nan, zero_gard_by_list -from .grad_scaler import BaseGradScaler try: from colossalai._C import fused_optim except: fused_optim = None -__all__ = ['FP16Optimizer'] +__all__ = ["FP16Optimizer"] def load_fused_optim(): @@ -62,13 +63,15 @@ class FP16Optimizer(Optimizer): verbose (bool, optional): if set to `True`, will print debug info. Default False. """ - def __init__(self, - optimizer: Optimizer, - grad_scaler: BaseGradScaler, - verbose: bool = False, - clip_grad_norm=0, - dp_process_group: ProcessGroup = None, - mp_process_group: ProcessGroup = None): + def __init__( + self, + optimizer: Optimizer, + grad_scaler: BaseGradScaler, + verbose: bool = False, + clip_grad_norm=0, + dp_process_group: ProcessGroup = None, + mp_process_group: ProcessGroup = None, + ): # have a defaults for compatibility with pytorch optim self._optimizer = optimizer self._defaults = optimizer.defaults @@ -116,10 +119,10 @@ class FP16Optimizer(Optimizer): fp32_master_params = [] fp32_params = [] # For all the parameters in this group: - for i, param in enumerate(param_group['params']): + for i, param in enumerate(param_group["params"]): if param.requires_grad: # float16 params: - if param.type() in ['torch.cuda.HalfTensor']: + if param.type() in ["torch.cuda.HalfTensor"]: fp16_params.append(param) # Create a fp32 copy @@ -128,7 +131,7 @@ class FP16Optimizer(Optimizer): copy_tensor_parallel_attributes(param, fp32_param) # Replace the optimizer params with the new fp32 copy. - param_group['params'][i] = fp32_param + param_group["params"][i] = fp32_param fp32_master_params.append(fp32_param) # Reset existing state dict key to the new main param. @@ -136,11 +139,13 @@ class FP16Optimizer(Optimizer): self._optimizer.state[fp32_param] = self._optimizer.state.pop(param) # fp32 params. - elif param.type() == 'torch.cuda.FloatTensor': + elif param.type() == "torch.cuda.FloatTensor": fp32_params.append(param) else: - raise TypeError('Expected parameter of type torch.cuda.FloatTensor ' - f'or torch.cuda.HalfTensor, but got {param.type()}') + raise TypeError( + "Expected parameter of type torch.cuda.FloatTensor " + f"or torch.cuda.HalfTensor, but got {param.type()}" + ) self._fp16_param_groups.append(fp16_params) self._fp32_master_param_groups.append(fp32_master_params) @@ -159,12 +164,12 @@ class FP16Optimizer(Optimizer): f"clip_grad_norm = {clip_grad_norm}\n" f"grad_scaler = {self._grad_scaler.__class__.__name__}" f"==========================================", - ranks=[0]) + ranks=[0], + ) @property def max_norm(self): - """Returns the maximum norm of gradient clipping. - """ + """Returns the maximum norm of gradient clipping.""" return self._clip_grad_max_norm @property @@ -210,7 +215,7 @@ class FP16Optimizer(Optimizer): # check for overflow for group in self._optimizer.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is not None and has_inf_or_nan(p.grad): self._found_overflow.fill_(1.0) break @@ -234,7 +239,7 @@ class FP16Optimizer(Optimizer): # set_to_none = True can save some memory space for param_group in self._optimizer.param_groups: - zero_gard_by_list(param_group['params'], set_to_none=set_to_none) + zero_gard_by_list(param_group["params"], set_to_none=set_to_none) def _get_fp32_param_groups_to_update(self): return self._fp32_master_param_groups + self._fp32_param_groups @@ -261,13 +266,12 @@ class FP16Optimizer(Optimizer): for fp16_param, fp32_param in zip(fp16_group, fp32_group): fp16_param_data.append(fp16_param.data) fp32_master_param_data.append(fp32_param.data) - _multi_tensor_copy_this_to_that(this=fp32_master_param_data, - that=fp16_param_data, - overflow_buf=self._dummy_overflow_buf) + _multi_tensor_copy_this_to_that( + this=fp32_master_param_data, that=fp16_param_data, overflow_buf=self._dummy_overflow_buf + ) def step(self): - """Update the model parameters. - """ + """Update the model parameters.""" # Copy gradients from model params to main params. self._assign_grad_to_fp32_master_param() @@ -306,14 +310,13 @@ class FP16Optimizer(Optimizer): scaled_loss.backward() def state_dict(self): - """Returns the states of the fp16 optimizer as a dict object. - """ + """Returns the states of the fp16 optimizer as a dict object.""" state_dict = {} - state_dict['optimizer'] = self._optimizer.state_dict() + state_dict["optimizer"] = self._optimizer.state_dict() if self.grad_scaler: - state_dict['grad_scaler'] = self.grad_scaler.state_dict() - state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups + state_dict["grad_scaler"] = self.grad_scaler.state_dict() + state_dict["fp32_master_param_groups"] = self._fp32_master_param_groups return state_dict def load_state_dict(self, state_dict): @@ -324,16 +327,17 @@ class FP16Optimizer(Optimizer): """ # Optimizer. - self._optimizer.load_state_dict(state_dict['optimizer']) + self._optimizer.load_state_dict(state_dict["optimizer"]) # Grad scaler. - if 'grad_scaler' in state_dict: - self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + if "grad_scaler" in state_dict: + self.grad_scaler.load_state_dict(state_dict["grad_scaler"]) # Copy data for the main params. - if 'fp32_master_param_groups' in state_dict: - for current_group, ckpt_group in zip(self._fp32_master_param_groups, - state_dict['fp32_master_param_groups']): + if "fp32_master_param_groups" in state_dict: + for current_group, ckpt_group in zip( + self._fp32_master_param_groups, state_dict["fp32_master_param_groups"] + ): for current_param, ckpt_param in zip(current_group, ckpt_group): current_param.data.copy_(ckpt_param.data) @@ -345,7 +349,7 @@ class FP16Optimizer(Optimizer): """ params = [] for param_group in self._optimizer.param_groups: - for param in param_group['params']: + for param in param_group["params"]: params.append(param) return clip_grad_norm_fp32(params, clip_grad) diff --git a/colossalai/legacy/amp/naive_amp/_utils.py b/colossalai/legacy/amp/naive_amp/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5a91146bb05d0923dfc5e32351b8d8ed92cf20 --- /dev/null +++ b/colossalai/legacy/amp/naive_amp/_utils.py @@ -0,0 +1,49 @@ +from typing import List + +from torch import Tensor + + +def has_inf_or_nan(tensor): + """Check if tensor has inf or nan values. + + Args: + tensor (:class:`torch.Tensor`): a torch tensor object + + Returns: + bool: Whether the tensor has inf or nan. True for yes and False for no. + """ + try: + # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as tensor + # (which is true for some recent version of pytorch). + tensor_sum = float(tensor.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # tensor_sum = float(tensor.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum: + return True + return False + + +def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None: + """Clear the gradient of a list of tensors, + + Note: copied from torch.optim.optimizer. + """ + for param in tensor_list: + if param.grad is not None: + if set_to_none: + param.grad = None + else: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/legacy/amp/naive_amp/naive_amp.py similarity index 87% rename from colossalai/amp/naive_amp/naive_amp.py rename to colossalai/legacy/amp/naive_amp/naive_amp.py index 6a39d518d3f42716b800b7673fd128a4d6afe91b..f9c298941fa91cb84de91a31a100056425b66369 100644 --- a/colossalai/amp/naive_amp/naive_amp.py +++ b/colossalai/legacy/amp/naive_amp/naive_amp.py @@ -11,14 +11,14 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ReduceOp from torch.optim import Optimizer -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from ._fp16_optimizer import FP16Optimizer -class NaiveAMPOptimizer(ColossalaiOptimizer): +class NaiveAMPOptimizer(OptimizerWrapper): """A wrapper class for optimizer to cast all parameters to fp16 Args: @@ -45,9 +45,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer): def clip_grad_norm(self, model: nn.Module, max_norm: float): if self.optim.max_norm == max_norm: return - raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). " - "If you have supplied clip_grad_norm in the amp_config, " - "executing the method clip_grad_norm is not allowed.") + raise RuntimeError( + "NaiveAMP optimizer has clipped gradients during optimizer.step(). " + "If you have supplied clip_grad_norm in the amp_config, " + "executing the method clip_grad_norm is not allowed." + ) class NaiveAMPModel(nn.Module): @@ -57,7 +59,7 @@ class NaiveAMPModel(nn.Module): Args: model (torch.nn.Module): torch.nn.Module to be wrapped. output_to_fp32 (bool, optional): Whether cast output of this module into fp32. (Default: True) - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this module. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this module. (Default: ``ParallelMode.DATA``) sync_buffer (bool, optional): whether to synchronize buffer. (Default: True) @@ -66,11 +68,13 @@ class NaiveAMPModel(nn.Module): in `parallel_mode `_. """ - def __init__(self, - model: nn.Module, - output_to_fp32: bool = True, - parallel_mode: ParallelMode = ParallelMode.DATA, - sync_buffer: bool = True): + def __init__( + self, + model: nn.Module, + output_to_fp32: bool = True, + parallel_mode: ParallelMode = ParallelMode.DATA, + sync_buffer: bool = True, + ): super().__init__() self.model = model.half() self._output_to_fp32 = output_to_fp32 diff --git a/colossalai/legacy/amp/torch_amp/__init__.py b/colossalai/legacy/amp/torch_amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad2416eef06a58b765dee3b2d382188a8c982d52 --- /dev/null +++ b/colossalai/legacy/amp/torch_amp/__init__.py @@ -0,0 +1,44 @@ +from typing import Optional + +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer + +from colossalai.context import Config + +from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer + + +def convert_to_torch_amp( + model: nn.Module, optimizer: Optimizer, criterion: Optional[_Loss] = None, amp_config: Optional[Config] = None +): + """A helper function to wrap training components with Pytorch AMP modules + + Args: + model (:class:`torch.nn.Module`): your model object. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object + criterion (:class:`torch.nn.modules.loss._Loss`, optional): your loss function object + amp_config (:class:`colossalai.context.Config` or dict, optional): configuration for Pytorch AMP. + + The ``amp_config`` should include parameters below: + :: + + init_scale (float, optional, default=2.**16) + growth_factor (float, optional, default=2.0) + backoff_factor (float, optional, default=0.5) + growth_interval (int, optional, default=2000) + enabled (bool, optional, default=True) + + Returns: + A tuple (model, optimizer, criterion) + """ + model = TorchAMPModel(model) + if amp_config is None: + amp_config = dict() + optimizer = TorchAMPOptimizer(optimizer, **amp_config) + if criterion: + criterion = TorchAMPLoss(criterion) + return model, optimizer, criterion + + +__all__ = ["convert_to_torch_amp", "TorchAMPModel", "TorchAMPLoss", "TorchAMPOptimizer"] diff --git a/colossalai/amp/torch_amp/_grad_scaler.py b/colossalai/legacy/amp/torch_amp/_grad_scaler.py similarity index 90% rename from colossalai/amp/torch_amp/_grad_scaler.py rename to colossalai/legacy/amp/torch_amp/_grad_scaler.py index 7b78998fb8c233f13f34fdf64df95bdfd1601ee6..fc1aeec234fd5a573b9fc62cb23598fd78801343 100644 --- a/colossalai/amp/torch_amp/_grad_scaler.py +++ b/colossalai/legacy/amp/torch_amp/_grad_scaler.py @@ -13,8 +13,8 @@ import torch.distributed as dist from packaging import version from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc class _MultiDeviceReplicator(object): @@ -23,7 +23,7 @@ class _MultiDeviceReplicator(object): """ def __init__(self, master_tensor: torch.Tensor) -> None: - assert master_tensor.is_cuda or master_tensor.device.type == 'xla' + assert master_tensor.is_cuda or master_tensor.device.type == "xla" self.master = master_tensor self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} @@ -118,7 +118,7 @@ class GradScaler(object): invokes the underlying ``optimizer.step()``, and other methods become no-ops. """ - def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True): + def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True): if enabled and not torch.cuda.is_available(): warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.") self._enabled = False @@ -174,7 +174,7 @@ class GradScaler(object): # Short-circuit for the common case. if isinstance(outputs, torch.Tensor): - assert outputs.is_cuda or outputs.device.type == 'xla' + assert outputs.is_cuda or outputs.device.type == "xla" if self._scale is None: self._lazy_init_scale_growth_tracker(outputs.device) assert self._scale is not None @@ -186,7 +186,7 @@ class GradScaler(object): def apply_scale(val): if isinstance(val, torch.Tensor): - assert val.is_cuda or val.device.type == 'xla' + assert val.is_cuda or val.device.type == "xla" if len(stash) == 0: if self._scale is None: self._lazy_init_scale_growth_tracker(val.device) @@ -214,7 +214,7 @@ class GradScaler(object): # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict # Google says mypy struggles with defaultdicts type annotations. - per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] with torch.no_grad(): for group in optimizer.param_groups: for param in group["params"]: @@ -238,9 +238,10 @@ class GradScaler(object): for device, per_dtype_grads in per_device_and_dtype_grads.items(): for grads in per_dtype_grads.values(): - torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device), - per_device_inv_scale.get(device)) - # For tensor parallel paramters it should be all-reduced over tensor parallel process group + torch._amp_foreach_non_finite_check_and_unscale_( + grads, per_device_found_inf.get(device), per_device_inv_scale.get(device) + ) + # For tensor parallel parameters it should be all-reduced over tensor parallel process group if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: vals = [val for val in per_device_found_inf._per_device_tensors.values()] coalesced = _flatten_dense_tensors(vals) @@ -328,7 +329,7 @@ class GradScaler(object): .. warning:: Closure use is not currently supported. """ - if (not self._enabled): + if not self._enabled: return optimizer.step(*args, **kwargs) if "closure" in kwargs: @@ -343,7 +344,7 @@ class GradScaler(object): retval = None - if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + if hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling: # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. # The contract with custom optimizers is that their step() should accept an additional, # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: @@ -391,14 +392,14 @@ class GradScaler(object): if new_scale is not None: # Accept a new user-defined scale. if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] + self._scale.fill_(new_scale) # type: ignore[union-attr] else: reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." # type: ignore[attr-defined] assert isinstance(new_scale, torch.cuda.FloatTensor), reason assert new_scale.numel() == 1, reason assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] + self._scale.copy_(new_scale) # type: ignore[union-attr] else: # Consume shared inf/nan data collected from optimizers to update the scale. # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. @@ -416,11 +417,23 @@ class GradScaler(object): found_inf_combined += found_infs[i] if self._higher_than_torch18: - torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor, - self._backoff_factor, self._growth_interval) + torch._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) else: - self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor, - self._backoff_factor, self._growth_interval) + self._scale = torch._amp_update_scale( + _growth_tracker, + _scale, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) # To prepare for next iteration, clear the data collected from optimizers this iteration. self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) @@ -507,13 +520,17 @@ class GradScaler(object): If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` should be called after :meth:`update`. """ - return { - "scale": self.get_scale(), - "growth_factor": self._growth_factor, - "backoff_factor": self._backoff_factor, - "growth_interval": self._growth_interval, - "_growth_tracker": self._get_growth_tracker() - } if self._enabled else {} + return ( + { + "scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker(), + } + if self._enabled + else {} + ) def load_state_dict(self, state_dict): r""" @@ -526,8 +543,10 @@ class GradScaler(object): return if len(state_dict) == 0: - raise RuntimeError("The source state dict is empty, possibly because it was saved " - "from a disabled instance of GradScaler.") + raise RuntimeError( + "The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler." + ) self._init_scale = state_dict["scale"] if self._scale is not None: @@ -542,15 +561,17 @@ class GradScaler(object): def __getstate__(self): state = self.__dict__.copy() if self._enabled: - assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ - "of an iteration, or at the end after scaler.update()." + assert len(self._per_optimizer_states) == 0, ( + "A GradScaler instance may only be pickled at the beginning " + "of an iteration, or at the end after scaler.update()." + ) # Pickling _scale and _growth_tracker Tensors directly triggers # "warnings.warn("pickle support for Storage will be removed in 1.5..." # so instead, we set the unpickled instance up to reinitialize them lazily. - state['_init_scale'] = self.get_scale() - state['_init_growth_tracker'] = self._get_growth_tracker() - state['_scale'] = None - state['_growth_tracker'] = None + state["_init_scale"] = self.get_scale() + state["_init_growth_tracker"] = self._get_growth_tracker() + state["_scale"] = None + state["_growth_tracker"] = None return state def __setstate__(self, state): @@ -562,8 +583,9 @@ class GradScaler(object): dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device) found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device) - self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ - self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = self._unscale_grads_( + optimizer, dummy_inv_scale, found_inf, True + ) return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/colossalai/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py similarity index 93% rename from colossalai/amp/torch_amp/torch_amp.py rename to colossalai/legacy/amp/torch_amp/torch_amp.py index 65718d77c2e00cdaf83ca8c27e9c26caed0d9362..ced5cc3e66478f3bb95dddbdce6b9f258ba3193c 100644 --- a/colossalai/amp/torch_amp/torch_amp.py +++ b/colossalai/legacy/amp/torch_amp/torch_amp.py @@ -7,13 +7,13 @@ from torch import Tensor from torch.nn.modules.loss import _Loss from torch.optim import Optimizer -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.utils import clip_grad_norm_fp32 +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.utils import clip_grad_norm_fp32 from ._grad_scaler import GradScaler -class TorchAMPOptimizer(ColossalaiOptimizer): +class TorchAMPOptimizer(OptimizerWrapper): """A wrapper class which integrate Pytorch AMP with an optimizer Args: @@ -42,8 +42,7 @@ class TorchAMPOptimizer(ColossalaiOptimizer): self.scaler.scale(loss).backward() def step(self): - """Update the parameters of the model - """ + """Update the parameters of the model""" self.scaler.step(self.optim) self.scaler.update() diff --git a/colossalai/legacy/builder/__init__.py b/colossalai/legacy/builder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9af3d139d3b13351285fa239bc474f4b1ce49d11 --- /dev/null +++ b/colossalai/legacy/builder/__init__.py @@ -0,0 +1,3 @@ +from .builder import build_from_config, build_from_registry, build_gradient_handler + +__all__ = ["build_gradient_handler", "build_from_config", "build_from_registry"] diff --git a/colossalai/builder/builder.py b/colossalai/legacy/builder/builder.py similarity index 79% rename from colossalai/builder/builder.py rename to colossalai/legacy/builder/builder.py index 4a907601327c9c938243bfee121165937c02537c..dec3bc1c2487c58cfc36235d44d7252dced0db59 100644 --- a/colossalai/builder/builder.py +++ b/colossalai/legacy/builder/builder.py @@ -3,7 +3,7 @@ import inspect -from colossalai.registry import * +from colossalai.legacy.registry import * def build_from_config(module, config: dict): @@ -19,7 +19,7 @@ def build_from_config(module, config: dict): AssertionError: Raises an AssertionError if `module` is not a class """ - assert inspect.isclass(module), 'module must be a class' + assert inspect.isclass(module), "module must be a class" return module(**config) @@ -45,15 +45,15 @@ def build_from_registry(config, registry: Registry): Raises: Exception: Raises an Exception if an error occurred when building from registry. """ - config_ = config.copy() # keep the original config untouched - assert isinstance(registry, Registry), f'Expected type Registry but got {type(registry)}' + config_ = config.copy() # keep the original config untouched + assert isinstance(registry, Registry), f"Expected type Registry but got {type(registry)}" - mod_type = config_.pop('type') - assert registry.has(mod_type), f'{mod_type} is not found in registry {registry.name}' + mod_type = config_.pop("type") + assert registry.has(mod_type), f"{mod_type} is not found in registry {registry.name}" try: obj = registry.get_module(mod_type)(**config_) except Exception as e: - print(f'An error occurred when building {mod_type} from registry {registry.name}', flush=True) + print(f"An error occurred when building {mod_type} from registry {registry.name}", flush=True) raise e return obj @@ -71,9 +71,9 @@ def build_gradient_handler(config, model, optimizer): optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler Returns: - An object of :class:`colossalai.engine.BaseGradientHandler` + An object of :class:`colossalai.legacy.engine.BaseGradientHandler` """ config_ = config.copy() - config_['model'] = model - config_['optimizer'] = optimizer + config_["model"] = model + config_["optimizer"] = optimizer return build_from_registry(config_, GRADIENT_HANDLER) diff --git a/colossalai/legacy/communication/__init__.py b/colossalai/legacy/communication/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f4492b074425e1e4267e23be9f0da6e790cdf8d4 --- /dev/null +++ b/colossalai/legacy/communication/__init__.py @@ -0,0 +1,34 @@ +from .collective import all_gather, all_reduce, broadcast, reduce, reduce_scatter +from .p2p import ( + recv_backward, + recv_forward, + send_backward, + send_backward_recv_backward, + send_backward_recv_forward, + send_forward, + send_forward_backward_recv_forward_backward, + send_forward_recv_backward, + send_forward_recv_forward, +) +from .ring import ring_forward +from .utils import recv_obj_meta, send_obj_meta + +__all__ = [ + "all_gather", + "reduce_scatter", + "all_reduce", + "broadcast", + "reduce", + "send_forward", + "send_forward_recv_forward", + "send_forward_backward_recv_forward_backward", + "send_backward", + "send_backward_recv_backward", + "send_backward_recv_forward", + "send_forward_recv_backward", + "recv_backward", + "recv_forward", + "ring_forward", + "send_obj_meta", + "recv_obj_meta", +] diff --git a/colossalai/communication/collective.py b/colossalai/legacy/communication/collective.py similarity index 86% rename from colossalai/communication/collective.py rename to colossalai/legacy/communication/collective.py index 64fb5b8b5296fa8afe7b20c9c96609f7b999e8c0..9cf30f733deeab7b1a4765d45f0bb4cefb11aece 100644 --- a/colossalai/communication/collective.py +++ b/colossalai/legacy/communication/collective.py @@ -6,13 +6,13 @@ import torch.distributed as dist from torch import Tensor from torch.distributed import ReduceOp -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc -_all_gather_func = dist._all_gather_base \ - if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor -_reduce_scatter_func = dist._reduce_scatter_base \ - if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor +_all_gather_func = dist._all_gather_base if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor +_reduce_scatter_func = ( + dist._reduce_scatter_base if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor +) def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor: @@ -26,7 +26,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: Args: tensor (:class:`torch.Tensor`): Tensor to be gathered. dim (int): The dimension concatenating in. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication. async_op (bool, optional): Whether operations are asynchronous. Returns: @@ -50,11 +50,9 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: return out -def reduce_scatter(tensor: Tensor, - dim: int, - parallel_mode: ParallelMode, - op: ReduceOp = ReduceOp.SUM, - async_op: bool = False) -> Tensor: +def reduce_scatter( + tensor: Tensor, dim: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False +) -> Tensor: r"""Reduces all tensors then scatters it in a specific dimension to all members in the parallel group. @@ -65,7 +63,7 @@ def reduce_scatter(tensor: Tensor, Args: tensor (:class:`torch.Tensor`): Tensor to be reduce_scattered. dim (int): The dimension concatenating in. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication. op (torch.distributed.ReduceOp, optional): The type of reduce operation, should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. More details about ReduceOp please refer to @@ -93,10 +91,9 @@ def reduce_scatter(tensor: Tensor, return out -def all_reduce(tensor: Tensor, - parallel_mode: ParallelMode, - op: ReduceOp = ReduceOp.SUM, - async_op: bool = False) -> Tensor: +def all_reduce( + tensor: Tensor, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False +) -> Tensor: r"""Reduces the tensor data across whole parallel group in such a way that all get the final result. Note: @@ -105,7 +102,7 @@ def all_reduce(tensor: Tensor, Args: tensor (:class:`torch.Tensor`): Tensor to be all-reduced. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication. op (torch.distributed.ReduceOp, optional): The type of reduce operation, should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. More details about ReduceOp please refer to @@ -141,7 +138,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b Args: tensor (:class:`torch.Tensor`): Tensor to be broadcast. src (int): Source rank. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication. async_op (bool, optional): Whether operations are asynchronous. Returns: @@ -173,7 +170,7 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = Args: tensor (:class:`torch.Tensor`): Tensor to be reduced. dst (int): Destination rank. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication. async_op (bool, optional): Whether operations are asynchronous. Returns: @@ -201,16 +198,17 @@ def scatter_object_list(scatter_object_output_list, scatter_object_input_list, s if dist.distributed_c10d._rank_not_in_group(group): return - if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1): + if not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1: raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.") # set tensor device to cuda if backend is nccl - device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu") + device = torch.cuda.current_device() if dist.get_backend(group) == "nccl" else torch.device("cpu") - my_rank = dist.get_rank() # use global rank + my_rank = dist.get_rank() # use global rank if my_rank == src: tensor_list, tensor_sizes = zip( - *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list]) + *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list] + ) tensor_list = list(map(lambda x: x.to(device), tensor_list)) tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes)) diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py new file mode 100644 index 0000000000000000000000000000000000000000..19c3919b6e29cfd35176e53c915c273bbf7fe2dc --- /dev/null +++ b/colossalai/legacy/communication/p2p.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import operator +from functools import reduce +from typing import List, Tuple, Union + +import torch +import torch.distributed as dist + +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.utils import get_current_device + +from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks + +TensorShape = Union[torch.Size, List[int], Tuple[int]] + + +def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]: + """get the exact tensor shape when communicating and return whether the tensor is a chunk + + Args: + tensor_shape (:class:`torch.Size`): shape of tensor + chunk_tensor (bool, optional): whether to chunk tensor, defaults to False + + Returns: + Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor + """ + if chunk_tensor: + tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) + tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR) + if tensor_chunk_shape % tensor_parallel_world_size == 0: + tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size + else: + tensor_chunk_shape = tensor_shape + chunk_tensor = False + else: + tensor_chunk_shape = tensor_shape + return tensor_chunk_shape, chunk_tensor + + +def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors): + if isinstance(recv_shapes, torch.Size): + recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors) + buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + return buffer_recv, recv_split + buffer_recv = [] + for recv_shape in recv_shapes: + recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors) + tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + buffer_recv.append(tensor_recv) + return buffer_recv, recv_split + + +def process_object_to_send(object_send, scatter_gather_tensors): + if isinstance(object_send, torch.Tensor): + send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1] + if send_split: + object_send = split_tensor_into_1d_equal_chunks(object_send) + return object_send + + object_send_list = [] + for tensor_send in object_send: + send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1] + if send_split: + object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send)) + else: + object_send_list.append(tensor_send) + object_send = tuple(object_send_list) + + return object_send + + +def filling_ops_queue(obj, comm_op, comm_rank, ops_queue): + if isinstance(obj, torch.Tensor): + op_to_add = dist.P2POp(comm_op, obj, comm_rank) + ops_queue.append(op_to_add) + else: + for tensor_to_comm in obj: + op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank) + ops_queue.append(op_to_add) + + +def _communicate( + object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, + object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, + recv_prev: bool = False, + recv_next: bool = False, + recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, + recv_next_shape: Union[torch.Size, List[torch.Size]] = None, + prev_rank: int = None, + next_rank: int = None, + dtype: torch.dtype = None, + scatter_gather_tensors: bool = False, +) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: + """ + Adapted from megatron.p2p_communication. + Communicate tensors between stages. Used as helper method in other + communication methods that are used in pipeline schedule. + Takes the following arguments: + object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank (no tensor sent if + set to None). + object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank (no tensor sent if + set to None). + recv_prev (bool): boolean for whether tensor should be received from + previous rank. + recv_next (bool): boolean for whether tensor should be received from + next rank. + recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the previous stage, defaults to None. + recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the next stage, defaults to None. + prev_rank (int): the rank of the previous pipeline stage, defaults to None, + next_rank (int): the rank of the next pipeline stage, defaults to None, + dtype (torch.dtype): data type of intermediate buffers, defaults to None + scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False + + Returns: + Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next + """ + + # Create placeholder tensors for receive in forward and backward directions + # if needed. + tensor_recv_prev = None + tensor_recv_next = None + + if recv_prev: + assert recv_prev_shape is not None + tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes( + recv_prev_shape, dtype, scatter_gather_tensors + ) + + if recv_next: + assert recv_next_shape is not None + tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes( + recv_next_shape, dtype, scatter_gather_tensors + ) + + if object_send_prev is not None or recv_prev: + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + + if object_send_next is not None or recv_next: + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + + if object_send_prev is not None: + object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors) + + if object_send_next is not None: + object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors) + + ops = [] + if object_send_prev is not None: + filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops) + + if tensor_recv_prev is not None: + filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops) + + if tensor_recv_next is not None: + filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops) + + if object_send_next is not None: + filling_ops_queue(object_send_next, dist.isend, next_rank, ops) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() + + if recv_prev and recv_prev_split: + if isinstance(tensor_recv_prev, torch.Tensor): + tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_() + else: + for index in range(len(tensor_recv_prev)): + tensor_recv_prev[index] = ( + gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_() + ) + + if recv_next and recv_next_split: + if isinstance(tensor_recv_next, torch.Tensor): + tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_() + else: + for index in range(len(tensor_recv_next)): + tensor_recv_next[index] = ( + gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_() + ) + + return tensor_recv_prev, tensor_recv_next + + +def recv_forward( + input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + + Args: + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list. + """ + if gpc.is_pipeline_first_stage(): + input_tensor = None + else: + input_tensor, _ = _communicate( + recv_prev=True, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) + return input_tensor + + +def recv_backward( + output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + + Args: + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradient tensor list. + """ + if gpc.is_pipeline_last_stage(): + output_tensor_grad = None + else: + _, output_tensor_grad = _communicate( + recv_next=True, + recv_next_shape=output_grad_shape, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) + return output_tensor_grad + + +def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None: + """Sends the input tensor to the next stage in pipeline. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not gpc.is_pipeline_last_stage(): + _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors) + + +def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + + Args: + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not gpc.is_pipeline_first_stage(): + _communicate( + object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors + ) + + +def send_forward_recv_backward( + output_tensor, output_grad_shape, recv_next=True, next_rank=None, dtype=torch.float, scatter_gather_tensors=False +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the input tensor to the + next stage in pipeline, while receives the gradient tensor from the + next stage in pipeline as the input gradient tensor of this stage. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor. + """ + if gpc.is_pipeline_last_stage(): + output_tensor_grad = None + else: + _, output_tensor_grad = _communicate( + object_send_next=output_tensor, + recv_next=recv_next, + recv_next_shape=output_grad_shape, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) + return output_tensor_grad + + +def send_backward_recv_forward( + input_tensor_grad, + input_tensor_shape, + recv_prev=True, + prev_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the gradient tensor to the + previous stage in pipeline, while receives the output tensor from the + previous stage in pipeline as the input of this stage. + + Args: + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor. + """ + if gpc.is_pipeline_first_stage(): + input_tensor = None + else: + input_tensor, _ = _communicate( + object_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) + return input_tensor + + +def send_forward_recv_forward( + output_tensor, + input_tensor_shape, + recv_prev=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the input tensor to the + next stage in pipeline, while receives the output tensor from the + previous stage in pipeline as the input of this stage. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor. + """ + input_tensor, _ = _communicate( + object_send_next=output_tensor, + recv_prev=recv_prev, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) + return input_tensor + + +def send_backward_recv_backward( + input_tensor_grad, + output_grad_shape, + recv_next=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Batched communication operation. Sends the gradient tensor to the + previous stage in pipeline, while receives the gradient tensor from the + next member in pipeline as the input of this stage. + + Args: + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent. + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor. + """ + _, output_tensor_grad = _communicate( + object_send_prev=input_tensor_grad, + recv_next=recv_next, + recv_next_shape=output_grad_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) + return output_tensor_grad + + +def send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + input_tensor_shape, + output_grad_shape, + recv_prev=True, + recv_next=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: + """Batched communication operation. Sends the input tensor to the next stage in pipeline and + the gradient tensor to the previous stage, while receives the input gradient tensor from the + next stage and the input tensor from the previous stage. + + Args: + output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next. + input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous. + input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the previous. + output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the next. + + Returns: + Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor) + """ + input_tensor, output_tensor_grad = _communicate( + object_send_next=output_tensor, + object_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + recv_prev_shape=input_tensor_shape, + recv_next_shape=output_grad_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) + return input_tensor, output_tensor_grad diff --git a/colossalai/communication/p2p_v2.py b/colossalai/legacy/communication/p2p_v2.py similarity index 92% rename from colossalai/communication/p2p_v2.py rename to colossalai/legacy/communication/p2p_v2.py index 0dacd8c3c9b5bacf65aee3168bcbd170c5a5b6dc..7c8d8bede0697e0756a12162d40512a64b730680 100644 --- a/colossalai/communication/p2p_v2.py +++ b/colossalai/legacy/communication/p2p_v2.py @@ -10,8 +10,8 @@ import torch.distributed as dist from torch.distributed import ProcessGroupNCCL from torch.distributed import distributed_c10d as c10d -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc TensorShape = Union[torch.Size, List[int], Tuple[int]] _pg_manager = {} @@ -19,7 +19,7 @@ _unpickler = pickle.Unpickler def init_process_group(): - """intialise process group by dist.new_group in the adjacent stages + """initialise process group by dist.new_group in the adjacent stages Args: None @@ -62,10 +62,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - Any: object after unpickled """ buf = tensor.numpy().tobytes()[:tensor_size] - if b'cuda' in buf: + if b"cuda" in buf: buf_array = bytearray(buf) device_index = torch.cuda.current_device() - buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index + buf_array[buf_array.find(b"cuda") + 5] = 48 + device_index buf = bytes(buf_array) io_bytes = io.BytesIO(buf) @@ -123,8 +123,8 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No if local_rank == src: object_tensor = torch.cat(tensor_list) else: - object_tensor = torch.empty( # type: ignore[call-overload] - torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] dtype=torch.uint8, ) @@ -138,7 +138,7 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No if local_rank != src: for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset:offset + obj_size] + obj_view = object_tensor[offset : offset + obj_size] obj_view = obj_view.type(torch.uint8) if obj_view.device != torch.device("cpu"): obj_view = obj_view.cpu() @@ -147,8 +147,10 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) # unconsistence in device - if isinstance(unpickle_object, - torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): + if ( + isinstance(unpickle_object, torch.Tensor) + and unpickle_object.device.index != torch.cuda.current_device() + ): unpickle_object = unpickle_object.cuda() object_list[i] = unpickle_object diff --git a/colossalai/legacy/communication/ring.py b/colossalai/legacy/communication/ring.py new file mode 100644 index 0000000000000000000000000000000000000000..a61dae56cd429980ac1f475dcb411cb65fd2b480 --- /dev/null +++ b/colossalai/legacy/communication/ring.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.utils import get_current_device, synchronize + + +def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor: + """Sends a tensor to the next member and receives a tensor from the previous member. + This function returns the received tensor from the previous member. + + Args: + tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member + parallel_mode (ParallelMode): Parallel group mode used in this communication + + Returns: + :class:`torch.Tensor`: The tensor received from the previous. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + buffer_shape = tensor_send_next.size() + + ops = [] + current_rank = gpc.get_global_rank() + + tensor_recv_prev = torch.empty( + buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype + ) + + # send to next rank + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_next, gpc.get_next_global_rank(parallel_mode) + ) + ops.append(send_next_op) + + # receive from prev rank + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_prev, gpc.get_prev_global_rank(parallel_mode) + ) + ops.append(recv_prev_op) + + if current_rank % 2 == 0: + ops = ops[::-1] + + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # To protect against race condition when using batch_isend_irecv(). + synchronize() + + return tensor_recv_prev diff --git a/colossalai/legacy/communication/utils.py b/colossalai/legacy/communication/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d77f3753fe8515c7a1aba3e393e044c83216f38 --- /dev/null +++ b/colossalai/legacy/communication/utils.py @@ -0,0 +1,127 @@ +from typing import List, Tuple, Union + +import torch +import torch.distributed as dist + +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.utils import get_current_device + +TensorShape = Union[torch.Size, List[int], Tuple[int]] + + +def send_meta_helper(obj, next_rank, tensor_kwargs): + send_shape = torch.tensor(obj.size(), **tensor_kwargs) + send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs) + dist.send(send_ndims, next_rank) + dist.send(send_shape, next_rank) + + +def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: + """Sends obj meta information before sending a specific obj. + Since the recipient must know the shape of the obj in p2p communications, + meta information of the obj should be sent before communications. This function + synchronizes with :func:`recv_obj_meta`. + + Args: + obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent. + need_meta (bool, optional): If False, meta information won't be sent. + next_rank (int): The rank of the next member in pipeline parallel group. + + Returns: + bool: False + """ + if need_meta: + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + + tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + if isinstance(obj, torch.Tensor): + send_obj_nums = torch.tensor(1, **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + send_meta_helper(obj, next_rank, tensor_kwargs) + else: + send_obj_nums = torch.tensor(len(obj), **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + for tensor_to_send in obj: + send_meta_helper(tensor_to_send, next_rank, tensor_kwargs) + + return False + + +def recv_meta_helper(prev_rank, tensor_kwargs): + recv_ndims = torch.empty((), **tensor_kwargs) + dist.recv(recv_ndims, prev_rank) + recv_shape = torch.empty(recv_ndims, **tensor_kwargs) + dist.recv(recv_shape, prev_rank) + return recv_shape + + +def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size: + """Receives obj meta information before receiving a specific obj. + Since the recipient must know the shape of the obj in p2p communications, + meta information of the obj should be received before communications. This function + synchronizes with :func:`send_obj_meta`. + + Args: + obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received. + prev_rank (int): The rank of the source of the obj. + + Returns: + Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received. + """ + if obj_shape is None: + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + + tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + recv_obj_nums = torch.empty((), **tensor_kwargs) + dist.recv(recv_obj_nums, prev_rank) + if recv_obj_nums.item() == 1: + recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) + obj_shape = torch.Size(recv_shape) + else: + obj_shape = [] + for i in range(recv_obj_nums.item()): + recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) + obj_shape.append(torch.Size(recv_shape)) + + return obj_shape + + +def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor: + """Break a tensor into equal 1D chunks. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be split before communication. + new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor. + + Returns: + :class:`torch.Tensor`: The split tensor + """ + partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D) + start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D) + end_index = start_index + partition_size + if new_buffer: + data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) + data.copy_(tensor.view(-1)[start_index:end_index]) + else: + data = tensor.view(-1)[start_index:end_index] + return data + + +def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: + """Opposite of above function, gather values from model parallel ranks. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be gathered after communication. + Returns: + :class:`torch.Tensor`: The gathered tensor. + """ + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + numel = torch.numel(tensor) + numel_gathered = world_size * numel + gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) + chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)] + dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D)) + return gathered diff --git a/colossalai/legacy/constants.py b/colossalai/legacy/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..5d64b676e73dbb3d15d64e543f441073adf7447a --- /dev/null +++ b/colossalai/legacy/constants.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +ALLOWED_MODES = [None, "1d", "2d", "2.5d", "3d", "sequence"] +TENSOR_PARALLEL_MODE = "tensor_parallel_mode" + +# initializer +INITIALIZER_MAPPING = { + "data": "Initializer_Data", + "tensor": "Initializer_Tensor", + "pipeline": "Initializer_Pipeline", + "embedding": "Initializer_Embedding", + "1d": "Initializer_1D", + "2d": "Initializer_2D", + "2.5d": "Initializer_2p5D", + "3d": "Initializer_3D", + "sequence": "Initializer_Sequence", + "model": "Initializer_Model", + "moe": "Initializer_Moe", +} + +# 3D parallelism groups +INPUT_GROUP_3D = "input_group_3d" +WEIGHT_GROUP_3D = "weight_group_3d" +OUTPUT_GROUP_3D = "output_group_3d" +INPUT_X_WEIGHT_3D = "input_x_weight_group_3d" +OUTPUT_X_WEIGHT_3D = "output_x_weight_group_3d" + +# Attributes of tensor parallel parameters +IS_TENSOR_PARALLEL = "is_tensor_parallel" +NUM_PARTITIONS = "num_partitions" +TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS] diff --git a/colossalai/legacy/context/__init__.py b/colossalai/legacy/context/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7027945ead7c034f8d88db42cc9c8b8296ac1919 --- /dev/null +++ b/colossalai/legacy/context/__init__.py @@ -0,0 +1,4 @@ +from .parallel_context import ParallelContext +from .parallel_mode import ParallelMode +from .process_group_initializer import * +from .random import * diff --git a/colossalai/context/parallel_context.py b/colossalai/legacy/context/parallel_context.py similarity index 76% rename from colossalai/context/parallel_context.py rename to colossalai/legacy/context/parallel_context.py index 003f0cdd91b6630fe1a88271eed3afdd4021c3b8..48bf8ab279e895952b953e2ee03d0ea036180954 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/legacy/context/parallel_context.py @@ -4,19 +4,18 @@ import random import socket from collections import Counter -from threading import local from typing import Union import numpy as np import torch import torch.distributed as dist -from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING from colossalai.context.config import Config from colossalai.context.singleton_meta import SingletonMeta -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.legacy.constants import ALLOWED_MODES, INITIALIZER_MAPPING +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from colossalai.logging import get_dist_logger -from colossalai.registry import DIST_GROUP_INITIALIZER from .parallel_mode import ParallelMode from .random import add_seed, get_seeds, set_mode @@ -95,8 +94,9 @@ class ParallelContext(metaclass=SingletonMeta): @staticmethod def _check_parallel_mode(parallel_mode: ParallelMode): - assert isinstance(parallel_mode, ParallelMode), \ - f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}' + assert isinstance( + parallel_mode, ParallelMode + ), f"expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}" def get_global_rank(self): """Returns the global rank of the current device. @@ -110,12 +110,12 @@ class ParallelContext(metaclass=SingletonMeta): """Adds the global rank of the current device for `parallel_mode` to the context. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode for the rank. rank (int): The rank to be added Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. """ self._check_parallel_mode(parallel_mode) self._global_ranks[parallel_mode] = rank @@ -124,11 +124,11 @@ class ParallelContext(metaclass=SingletonMeta): """Returns the local rank of the current device. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: int: The local rank of the current device for `parallel_mode`. @@ -140,12 +140,12 @@ class ParallelContext(metaclass=SingletonMeta): """Adds the local rank of the current device for `parallel_mode` to the context. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode for the rank. rank (int): The rank to be added. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. """ self._check_parallel_mode(parallel_mode) self._local_ranks[parallel_mode] = rank @@ -154,11 +154,11 @@ class ParallelContext(metaclass=SingletonMeta): """Returns the global rank of the next device. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: int: The global rank of the next device for `parallel_mode`. @@ -176,11 +176,11 @@ class ParallelContext(metaclass=SingletonMeta): """Returns the global rank of the previous device. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: int: The global rank of the previous device for `parallel_mode`. @@ -199,11 +199,11 @@ class ParallelContext(metaclass=SingletonMeta): among its group for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: bool: a boolean value indicating whether the current device is the first one @@ -217,11 +217,11 @@ class ParallelContext(metaclass=SingletonMeta): among its group for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: bool: a boolean value indicating whether the current device is the first one @@ -239,8 +239,10 @@ class ParallelContext(metaclass=SingletonMeta): def is_pipeline_last_stage(self, ignore_virtual=False): if not ignore_virtual: - if self.virtual_pipeline_parallel_size \ - is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1: + if ( + self.virtual_pipeline_parallel_size is not None + and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1 + ): return False return self.is_last_rank(ParallelMode.PIPELINE) @@ -248,11 +250,11 @@ class ParallelContext(metaclass=SingletonMeta): """Returns the world size for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: int: The world size for `parallel_mode`. @@ -264,12 +266,12 @@ class ParallelContext(metaclass=SingletonMeta): """Adds world size for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode corresponding to the process group + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode corresponding to the process group world_size (int): The world size to be added Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. """ self._check_parallel_mode(parallel_mode) self._world_sizes[parallel_mode] = world_size @@ -278,11 +280,11 @@ class ParallelContext(metaclass=SingletonMeta): """Returns the group of the current device for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`. @@ -294,12 +296,12 @@ class ParallelContext(metaclass=SingletonMeta): """Adds the group of the current device for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. group (torch.distributed.ProcessGroup): The group to be added Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. """ self._check_parallel_mode(parallel_mode) self._groups[parallel_mode] = group @@ -308,9 +310,9 @@ class ParallelContext(metaclass=SingletonMeta): """Returns the Gloo group of the current device for `parallel_mode`. :param parallel_mode: The chosen parallel mode - :type parallel_mode: :class:`colossalai.context.ParallelMode` + :type parallel_mode: :class:`colossalai.legacy.context.ParallelMode` :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode` + of :class:`colossalai.legacy.context.ParallelMode` :return: The group of the current device for `parallel_mode` :rtype: torch.distributed.ProcessGroup """ @@ -321,11 +323,11 @@ class ParallelContext(metaclass=SingletonMeta): """Adds the Gloo group of the current device for `parallel_mode`. :param parallel_mode: The chosen parallel mode - :type parallel_mode: :class:`colossalai.context.ParallelMode` + :type parallel_mode: :class:`colossalai.legacy.context.ParallelMode` :param group: The group to be added :type group: torch.distributed.ProcessGroup :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode` + of :class:`colossalai.legacy.context.ParallelMode` """ self._check_parallel_mode(parallel_mode) self._cpu_groups[parallel_mode] = group @@ -334,11 +336,11 @@ class ParallelContext(metaclass=SingletonMeta): """Returns the rank of the current device for `parallel_mode` in the group. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. Returns: int: The rank of the current device for `parallel_mode` in the group. @@ -350,12 +352,12 @@ class ParallelContext(metaclass=SingletonMeta): """Adds the ranks of the current device for `parallel_mode` in the group. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. ranks (list): List of ranks to be added Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance - of :class:`colossalai.context.ParallelMode`. + of :class:`colossalai.legacy.context.ParallelMode`. """ self._check_parallel_mode(parallel_mode) self._ranks_in_group[parallel_mode] = ranks @@ -371,12 +373,12 @@ class ParallelContext(metaclass=SingletonMeta): port (str): the master port for distributed training """ # initialize the default process group - init_method = f'tcp://[{host}]:{port}' + init_method = f"tcp://[{host}]:{port}" dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # None will give the default global process group for pytorch dist operations ranks = list(range(world_size)) - cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None + cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL) self.add_global_rank(ParallelMode.GLOBAL, rank) @@ -398,10 +400,11 @@ class ParallelContext(metaclass=SingletonMeta): pps = self.pipeline_parallel_size tps = self.tensor_parallel_size ws = self.world_size - assert ws == dps * pps * \ - tps, f"Expected the world size {ws} to be equal to data" \ - f" parallel size ({dps}) * pipeline parallel size " \ - f"({pps}) * tensor parallel size ({tps})" + assert ws == dps * pps * tps, ( + f"Expected the world size {ws} to be equal to data" + f" parallel size ({dps}) * pipeline parallel size " + f"({pps}) * tensor parallel size ({tps})" + ) def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str): if key in config: @@ -409,10 +412,11 @@ class ParallelContext(metaclass=SingletonMeta): if isinstance(ele, int): setattr(self, attr_name, ele) elif isinstance(ele, dict): - setattr(self, attr_name, ele['size']) + setattr(self, attr_name, ele["size"]) else: raise NotImplementedError( - f'{"Parallel configuration does not support this kind of argument, please use int or dict"}') + f'{"Parallel configuration does not support this kind of argument, please use int or dict"}' + ) def init_parallel_groups(self): """Initializes the parallel groups. @@ -427,10 +431,10 @@ class ParallelContext(metaclass=SingletonMeta): self.world_size = world_size # set parallel size as attributes for global context - parallel_config = self.config.get('parallel', None) + parallel_config = self.config.get("parallel", None) if parallel_config is not None: - self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size') - self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size') + self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size") + self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size") # the user should not set the data parallel size manually # instead, it should be calculated based on other parallel config @@ -438,33 +442,33 @@ class ParallelContext(metaclass=SingletonMeta): # get the tensor parallel mode and check tensor_parallel_mode = None - if parallel_config is not None and 'tensor' in \ - parallel_config and 'mode' in parallel_config['tensor']: - tensor_parallel_mode = parallel_config['tensor']['mode'] - assert tensor_parallel_mode in ALLOWED_MODES, \ - f"mode in the parallel config must be set to one of {ALLOWED_MODES}" + if parallel_config is not None and "tensor" in parallel_config and "mode" in parallel_config["tensor"]: + tensor_parallel_mode = parallel_config["tensor"]["mode"] + assert ( + tensor_parallel_mode in ALLOWED_MODES + ), f"mode in the parallel config must be set to one of {ALLOWED_MODES}" env.mode = tensor_parallel_mode self.check_sanity() pg_init = [] # LSG: init data parallel process group for compatibility with other parallel module such as zero - pg_init.append(dict(type=INITIALIZER_MAPPING['data'])) + pg_init.append(dict(type=INITIALIZER_MAPPING["data"])) # LSG: init model parallel process group for compatibility with amp and clip grad - pg_init.append(dict(type=INITIALIZER_MAPPING['model'])) + pg_init.append(dict(type=INITIALIZER_MAPPING["model"])) if self.pipeline_parallel_size > 1: - pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline'])) - pg_init.append(dict(type=INITIALIZER_MAPPING['tensor'])) + pg_init.append(dict(type=INITIALIZER_MAPPING["pipeline"])) + pg_init.append(dict(type=INITIALIZER_MAPPING["tensor"])) # init specific tensor parallel group if tensor_parallel_mode is not None: - tensor_parallel_cfg = parallel_config['tensor'].copy() + tensor_parallel_cfg = parallel_config["tensor"].copy() # remove duplicate parameters - tensor_parallel_cfg.pop('mode') - tensor_parallel_cfg.pop('size') + tensor_parallel_cfg.pop("mode") + tensor_parallel_cfg.pop("size") # add this config to initialize later pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg)) @@ -472,11 +476,16 @@ class ParallelContext(metaclass=SingletonMeta): # run initialization of different process groups for initializer_cfg in pg_init: cfg = initializer_cfg.copy() - initializer_type = cfg.pop('type') - initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config, - self.data_parallel_size, - self.pipeline_parallel_size, - self.tensor_parallel_size, **cfg) + initializer_type = cfg.pop("type") + initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)( + rank, + world_size, + self.config, + self.data_parallel_size, + self.pipeline_parallel_size, + self.tensor_parallel_size, + **cfg, + ) parallel_setting = initializer.init_dist_group() if isinstance(parallel_setting, list): for args in parallel_setting: @@ -489,7 +498,7 @@ class ParallelContext(metaclass=SingletonMeta): in the current system. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Returns: bool: a boolean value indicating whether `parallel_mode` is initialized in the current system. @@ -497,8 +506,7 @@ class ParallelContext(metaclass=SingletonMeta): return parallel_mode in self._groups def destroy(self): - """Destroys the current distributed parallel environment. - """ + """Destroys the current distributed parallel environment.""" for mode, group in self._groups.items(): if mode is not ParallelMode.GLOBAL: dist.destroy_process_group(group) @@ -519,7 +527,7 @@ class ParallelContext(metaclass=SingletonMeta): torch.cuda.set_device(device_ordinal) if self._verbose: - self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}') + self._logger.info(f"process rank {global_rank} is bound to device {device_ordinal}") def set_seed(self, seed: int): """Sets seeds for all random libraries. @@ -552,21 +560,25 @@ class ParallelContext(metaclass=SingletonMeta): set_mode(ParallelMode.DATA) seeds = get_seeds() - seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()]) + seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()]) if self._verbose: - self._logger.info(f"initialized seed on rank {global_rank}, " - f"numpy: {seed}, python random: {seed}, {seed_str}," - f"the default parallel seed is {ParallelMode.DATA}.") + self._logger.info( + f"initialized seed on rank {global_rank}, " + f"numpy: {seed}, python random: {seed}, {seed_str}," + f"the default parallel seed is {ParallelMode.DATA}." + ) else: if self._verbose: self._logger.info( f"initialized seed on rank {global_rank}, " f"numpy: {seed}, python random: {seed}, pytorch: {seed}", - ranks=[0]) + ranks=[0], + ) self._logger.info( - 'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states', - ranks=[0]) + "WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states", + ranks=[0], + ) def set_virtual_pipeline_parallel_size(self, size): self.virtual_pipeline_parallel_size = size diff --git a/colossalai/legacy/context/parallel_mode.py b/colossalai/legacy/context/parallel_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..ceb52ff20da796f925399311bf8a8424d73fb379 --- /dev/null +++ b/colossalai/legacy/context/parallel_mode.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from enum import Enum + + +# parallel modes +class ParallelMode(Enum): + """This is an enumeration class containing all possible parallel modes.""" + + GLOBAL = "global" + + # common parallel + DATA = "data" + + # model parallel - containing tensor and pipeline parallel groups + # this is added to facilitate amp and grad clipping in hybrid parallel + MODEL = "model" + + # pipeline parallel + PIPELINE = "pipe" + + # containing all ranks in tensor parallel + TENSOR = "tensor" + + # sequence parallel + SEQUENCE = "sequence" + SEQUENCE_DP = "sequence_dp" + + # 1D Parallel + PARALLEL_1D = "1d" + + # 2D parallel + PARALLEL_2D_ROW = "2d_row" + PARALLEL_2D_COL = "2d_col" + + # 3D parallel + PARALLEL_3D_INPUT = "3d_input" + PARALLEL_3D_WEIGHT = "3d_weight" + PARALLEL_3D_OUTPUT = "3d_output" + PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight" + PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight" + + # 2.5D parallel + PARALLEL_2P5D_ROW = "2p5d_row" + PARALLEL_2P5D_COL = "2p5d_col" + PARALLEL_2P5D_DEP = "2p5d_dep" + PARALLEL_2P5D_XZ = "2p5d_xz" diff --git a/colossalai/legacy/context/process_group_initializer/__init__.py b/colossalai/legacy/context/process_group_initializer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a83165e40a8fafc9bca355df08e619f4f0ac7f02 --- /dev/null +++ b/colossalai/legacy/context/process_group_initializer/__init__.py @@ -0,0 +1,23 @@ +from .initializer_1d import Initializer_1D +from .initializer_2d import Initializer_2D +from .initializer_2p5d import Initializer_2p5D +from .initializer_3d import Initializer_3D +from .initializer_data import Initializer_Data +from .initializer_model import Initializer_Model +from .initializer_pipeline import Initializer_Pipeline +from .initializer_sequence import Initializer_Sequence +from .initializer_tensor import Initializer_Tensor +from .process_group_initializer import ProcessGroupInitializer + +__all__ = [ + "Initializer_Tensor", + "Initializer_Sequence", + "Initializer_Pipeline", + "Initializer_Data", + "Initializer_2p5D", + "Initializer_2D", + "Initializer_3D", + "Initializer_1D", + "ProcessGroupInitializer", + "Initializer_Model", +] diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/legacy/context/process_group_initializer/initializer_1d.py similarity index 88% rename from colossalai/context/process_group_initializer/initializer_1d.py rename to colossalai/legacy/context/process_group_initializer/initializer_1d.py index 4c05028041cef2a9ad453b05d17d35b09ec2617d..110a42cf880e60c8364ac6371cc625e2fd00c73d 100644 --- a/colossalai/context/process_group_initializer/initializer_1d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_1d.py @@ -2,8 +2,9 @@ # -*- encoding: utf-8 -*- import torch.distributed as dist -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.registry import DIST_GROUP_INITIALIZER + +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode from .process_group_initializer import ProcessGroupInitializer @@ -44,7 +45,7 @@ class Initializer_1D(ProcessGroupInitializer): for i in range(self.num_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/context/process_group_initializer/initializer_2d.py b/colossalai/legacy/context/process_group_initializer/initializer_2d.py similarity index 88% rename from colossalai/context/process_group_initializer/initializer_2d.py rename to colossalai/legacy/context/process_group_initializer/initializer_2d.py index 7fbe3be5901f73b8c670c71582771ab861e9fccd..1c08d4d4296a1b3d02383944b37b85538f6d94f8 100644 --- a/colossalai/context/process_group_initializer/initializer_2d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_2d.py @@ -2,8 +2,8 @@ import math import torch.distributed as dist -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.registry import DIST_GROUP_INITIALIZER +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode from .process_group_initializer import ProcessGroupInitializer @@ -14,9 +14,10 @@ def _check_summa_env_var(summa_dim): env_summa_dim = env.summa_dim if env_summa_dim: - assert int(env_summa_dim) == summa_dim, \ - 'SUMMA_DIM has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' + assert int(env_summa_dim) == summa_dim, ( + "SUMMA_DIM has been set in the current environment and " + "does not match with the value passed to this initialized" + ) else: env.summa_dim = summa_dim @@ -57,7 +58,7 @@ class Initializer_2D_Row(ProcessGroupInitializer): for j in range(self.summa_dim): ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k for k in range(self.summa_dim)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -106,7 +107,7 @@ class Initializer_2D_Col(ProcessGroupInitializer): for j in range(self.summa_dim): ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim for k in range(self.summa_dim)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -137,8 +138,9 @@ class Initializer_2D(ProcessGroupInitializer): self.num_group = self.world_size // self.tensor_parallel_size self.summa_dim = int(math.sqrt(self.tensor_parallel_size)) - assert self.tensor_parallel_size == self.summa_dim ** 2, \ - "2D summa dim should equal to tensor parallel size ^ 0.5" + assert ( + self.tensor_parallel_size == self.summa_dim**2 + ), "2D summa dim should equal to tensor parallel size ^ 0.5" _check_summa_env_var(self.summa_dim) self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs) diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py similarity index 87% rename from colossalai/context/process_group_initializer/initializer_2p5d.py rename to colossalai/legacy/context/process_group_initializer/initializer_2p5d.py index 6b6fdc5d715c30169f04cef54abd946c4c46b904..b7d71b96334d6e8bbd9f3be5ba3a38eb93d530b9 100644 --- a/colossalai/context/process_group_initializer/initializer_2p5d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py @@ -4,9 +4,10 @@ import math import torch.distributed as dist + from colossalai.context import Config -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.registry import DIST_GROUP_INITIALIZER +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode from .process_group_initializer import ProcessGroupInitializer @@ -18,12 +19,14 @@ def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int): env_tesseract_dep = env.tesseract_dep if env_tesseract_dim and env_tesseract_dep: - assert int(env_tesseract_dim) == tesseract_dim, \ - 'TESSERACT_DIM has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' - assert int(env_tesseract_dep) == tesseract_dep, \ - 'TESSERACT_DEP has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' + assert int(env_tesseract_dim) == tesseract_dim, ( + "TESSERACT_DIM has been set in the current environment and " + "does not match with the value passed to this initialized" + ) + assert int(env_tesseract_dep) == tesseract_dep, ( + "TESSERACT_DEP has been set in the current environment and " + "does not match with the value passed to this initialized" + ) else: env.tesseract_dim = tesseract_dim env.tesseract_dep = tesseract_dep @@ -49,8 +52,9 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer): self.num_group = self.world_size // self.tensor_parallel_size self.tesseract_dep = tesseract_dep self.tesseract_dim = tesseract_dim - assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ - "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" + assert ( + self.tensor_parallel_size == self.tesseract_dim**2 * self.tesseract_dep + ), "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" def init_dist_group(self): """Initialize 2.5D tensor row parallel groups, and assign local_ranks and groups to each gpu. @@ -74,7 +78,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer): for i in range(self.tesseract_dim) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -128,7 +132,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer): for j in range(self.tesseract_dim) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -182,7 +186,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer): for k in range(self.tesseract_dep) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -237,7 +241,7 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer): for j in range(self.tesseract_dim) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -264,16 +268,25 @@ class Initializer_2p5D(ProcessGroupInitializer): depth (int): The depth of 2.5d parallel. """ - def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, - tensor_parallel_size: int, depth: int): + def __init__( + self, + rank: int, + world_size: int, + config: Config, + data_parallel_size: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + depth: int, + ): args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size) super().__init__(*args) self.num_group = self.world_size // self.tensor_parallel_size self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth)) self.tesseract_dep = depth - assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ - "2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5" + assert ( + self.tensor_parallel_size == self.tesseract_dim**2 * self.tesseract_dep + ), "2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5" _check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep) self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args) @@ -292,6 +305,6 @@ class Initializer_2p5D(ProcessGroupInitializer): self.col_initializer.init_dist_group(), self.row_initializer.init_dist_group(), self.dep_initializer.init_dist_group(), - self.xz_initializer.init_dist_group() + self.xz_initializer.init_dist_group(), ] return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/legacy/context/process_group_initializer/initializer_3d.py similarity index 91% rename from colossalai/context/process_group_initializer/initializer_3d.py rename to colossalai/legacy/context/process_group_initializer/initializer_3d.py index 1ed8eec86efc83315ee8b549a9a035bc36dca6da..5f96405e90aae57e1aecd148745fb12c2a2d83f3 100644 --- a/colossalai/context/process_group_initializer/initializer_3d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_3d.py @@ -5,8 +5,8 @@ import math import torch.distributed as dist -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.registry import DIST_GROUP_INITIALIZER +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode from .process_group_initializer import ProcessGroupInitializer @@ -17,9 +17,10 @@ def _check_depth_env_var(depth): env_depth = env.depth_3d if env_depth: - assert int(env_depth) == depth, \ - 'DEPTH_3D has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' + assert int(env_depth) == depth, ( + "DEPTH_3D has been set in the current environment and " + "does not match with the value passed to this initialized" + ) else: env.depth_3d = depth @@ -63,7 +64,7 @@ class Initializer_3D_Input(ProcessGroupInitializer): for k in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -114,7 +115,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer): for j in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -165,7 +166,7 @@ class Initializer_3D_Output(ProcessGroupInitializer): for j in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -219,7 +220,7 @@ class Initializer_3D_InputxWeight(ProcessGroupInitializer): for i in range(self.depth) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -273,7 +274,7 @@ class Initializer_3D_OutputxWeight(ProcessGroupInitializer): for i in range(self.depth) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -302,8 +303,9 @@ class Initializer_3D(ProcessGroupInitializer): super().__init__(*args) self.num_group = self.world_size // self.tensor_parallel_size self.depth = round(math.pow(self.tensor_parallel_size, 1 / 3)) - assert self.tensor_parallel_size == self.depth ** 3, \ - f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})' + assert ( + self.tensor_parallel_size == self.depth**3 + ), f"3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})" _check_depth_env_var(self.depth) self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args) @@ -324,6 +326,6 @@ class Initializer_3D(ProcessGroupInitializer): self.weight_initializer.init_dist_group(), self.output_initializer.init_dist_group(), self.input_x_weight_initializer.init_dist_group(), - self.output_x_weight_initializer.init_dist_group() + self.output_x_weight_initializer.init_dist_group(), ] return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_data.py b/colossalai/legacy/context/process_group_initializer/initializer_data.py similarity index 91% rename from colossalai/context/process_group_initializer/initializer_data.py rename to colossalai/legacy/context/process_group_initializer/initializer_data.py index 9715ebff7f00f0fc8a3f13a5dfca436c9b0e144b..9c8bcf353c20bd061ed85129e7dbb079dfe06435 100644 --- a/colossalai/context/process_group_initializer/initializer_data.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_data.py @@ -3,7 +3,7 @@ from torch import distributed as dist -from colossalai.registry import DIST_GROUP_INITIALIZER +from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode from .process_group_initializer import ProcessGroupInitializer @@ -43,7 +43,7 @@ class Initializer_Data(ProcessGroupInitializer): for i in range(self.num_data_parallel_group): ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/context/process_group_initializer/initializer_model.py b/colossalai/legacy/context/process_group_initializer/initializer_model.py similarity index 92% rename from colossalai/context/process_group_initializer/initializer_model.py rename to colossalai/legacy/context/process_group_initializer/initializer_model.py index 99b9cc0d4edce35915c52c01fa5875545256ba97..6aeae27756e7601e4897e5653a058d6e07d58e75 100644 --- a/colossalai/context/process_group_initializer/initializer_model.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_model.py @@ -2,9 +2,11 @@ # -*- encoding: utf-8 -*- import torch.distributed as dist -from colossalai.registry import DIST_GROUP_INITIALIZER -from .process_group_initializer import ProcessGroupInitializer + +from colossalai.legacy.registry import DIST_GROUP_INITIALIZER + from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer @DIST_GROUP_INITIALIZER.register_module @@ -43,7 +45,7 @@ class Initializer_Model(ProcessGroupInitializer): for i in range(self.num_group): ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py b/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..3e69be75ff7e7034b5842cee6986f913d039741c --- /dev/null +++ b/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from torch import distributed as dist + +from colossalai.legacy.registry import DIST_GROUP_INITIALIZER + +from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Pipeline(ProcessGroupInitializer): + """A ProcessGroupInitializer for pipeline parallelism. + + Args: + rank (int): The rank of current process + world_size (int): Size of whole communication world + config (Config): Running configuration + data_parallel_size (int): Size of data parallel + pipeline_parallel_size (int): Size of pipeline parallel + tensor_parallel_size (int): Size of tensor parallel + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.data_group_size = self.world_size // self.data_parallel_size + self.pipeline_stage_size = self.data_group_size // self.pipeline_parallel_size + + def init_dist_group(self): + """Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]: + A Pipeline parallelism's information in list of tuples. + """ + dist_settings = list() + for i in range(self.data_parallel_size): + for j in range(self.pipeline_stage_size): + pipe_ranks = list( + range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size) + ) + pipe_group_size = len(pipe_ranks) + pipe_group = dist.new_group(pipe_ranks) + group_cpu = dist.new_group(pipe_ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group + + if self.rank in pipe_ranks: + local_rank = pipe_ranks.index(self.rank) + group_world_size = pipe_group_size + process_group = pipe_group + cpu_group = group_cpu + ranks_in_group = pipe_ranks + dist_settings.append( + tuple( + ( + local_rank, + group_world_size, + process_group, + cpu_group, + ranks_in_group, + ParallelMode.PIPELINE, + ) + ) + ) + + return dist_settings diff --git a/colossalai/context/process_group_initializer/initializer_sequence.py b/colossalai/legacy/context/process_group_initializer/initializer_sequence.py similarity index 89% rename from colossalai/context/process_group_initializer/initializer_sequence.py rename to colossalai/legacy/context/process_group_initializer/initializer_sequence.py index eaacb14d22825db7913e1c87cfe08063ab5865ee..638b6d5ef2a61561e189c480ee2f7c5807c26704 100644 --- a/colossalai/context/process_group_initializer/initializer_sequence.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_sequence.py @@ -2,7 +2,7 @@ # -*- encoding: utf-8 -*- import torch.distributed as dist -from colossalai.registry import DIST_GROUP_INITIALIZER +from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode from .initializer_tensor import Initializer_Tensor @@ -46,7 +46,7 @@ class Initializer_Sequence_DP(ProcessGroupInitializer): for i in range(self.num_group): ranks = [i * self.dp_size + j for j in range(self.dp_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -91,11 +91,17 @@ class Initializer_Sequence(ProcessGroupInitializer): parallel_setting = [] - local_rank, group_world_size, process_group, cpu_grop, ranks_in_group, mode = \ - self._sequence_initializer.init_dist_group() + ( + local_rank, + group_world_size, + process_group, + cpu_group, + ranks_in_group, + mode, + ) = self._sequence_initializer.init_dist_group() # change mode to sequence mode = ParallelMode.SEQUENCE - parallel_setting.append((local_rank, group_world_size, process_group, cpu_grop, ranks_in_group, mode)) + parallel_setting.append((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode)) parallel_setting.append(self._sequence_dp_initializer.init_dist_group()) return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_tensor.py b/colossalai/legacy/context/process_group_initializer/initializer_tensor.py similarity index 91% rename from colossalai/context/process_group_initializer/initializer_tensor.py rename to colossalai/legacy/context/process_group_initializer/initializer_tensor.py index d2b5be9cfffbe9eb7234411c6526d4055c078f12..cb19a43bd373090278678efe1757270e0f8be15a 100644 --- a/colossalai/context/process_group_initializer/initializer_tensor.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_tensor.py @@ -3,9 +3,10 @@ import torch.distributed as dist -from colossalai.registry import DIST_GROUP_INITIALIZER -from .process_group_initializer import ProcessGroupInitializer +from colossalai.legacy.registry import DIST_GROUP_INITIALIZER + from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer @DIST_GROUP_INITIALIZER.register_module @@ -42,7 +43,7 @@ class Initializer_Tensor(ProcessGroupInitializer): for i in range(self.num_tensor_parallel_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/context/process_group_initializer/process_group_initializer.py b/colossalai/legacy/context/process_group_initializer/process_group_initializer.py similarity index 82% rename from colossalai/context/process_group_initializer/process_group_initializer.py rename to colossalai/legacy/context/process_group_initializer/process_group_initializer.py index 98150ce8e428a3b9bf81185719685b38efc2bdfd..98b5d7fc3882cf2c95251c9d5da42e5cfaaeea7e 100644 --- a/colossalai/context/process_group_initializer/process_group_initializer.py +++ b/colossalai/legacy/context/process_group_initializer/process_group_initializer.py @@ -18,8 +18,15 @@ class ProcessGroupInitializer(ABC): tensor_parallel_size (int): Size of tensor parallel. """ - def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, - tensor_parallel_size: int): + def __init__( + self, + rank: int, + world_size: int, + config: Config, + data_parallel_size: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ): self.rank = rank self.world_size = world_size self.data_parallel_size = data_parallel_size diff --git a/colossalai/legacy/context/random/__init__.py b/colossalai/legacy/context/random/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e8d82922ddc4af2644cba1ef80479854c5db084 --- /dev/null +++ b/colossalai/legacy/context/random/__init__.py @@ -0,0 +1,27 @@ +from ._helper import ( + add_seed, + get_current_mode, + get_seeds, + get_states, + moe_set_seed, + reset_seeds, + seed, + set_mode, + set_seed_states, + sync_states, + with_seed, +) + +__all__ = [ + "seed", + "set_mode", + "with_seed", + "add_seed", + "get_seeds", + "get_states", + "get_current_mode", + "set_seed_states", + "sync_states", + "moe_set_seed", + "reset_seeds", +] diff --git a/colossalai/context/random/_helper.py b/colossalai/legacy/context/random/_helper.py similarity index 90% rename from colossalai/context/random/_helper.py rename to colossalai/legacy/context/random/_helper.py index 973c4d9faa325820aa1dedc5e133551430778057..be1d951d122917b5ed5d89ec11c067ebdfe2e92b 100644 --- a/colossalai/context/random/_helper.py +++ b/colossalai/legacy/context/random/_helper.py @@ -7,8 +7,8 @@ from contextlib import contextmanager import torch.cuda from torch import Tensor -from .seed_manager import SeedManager from ..parallel_mode import ParallelMode +from .seed_manager import SeedManager _SEED_MANAGER = SeedManager() @@ -53,11 +53,11 @@ def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False): """Adds a seed to the seed manager for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. seed (int): The seed to be added Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of - :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added. + :class:`colossalai.legacy.context.ParallelMode` or the seed for `parallel_mode` has been added. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -70,7 +70,7 @@ def set_mode(parallel_mode: ParallelMode): """Sets the current mode of the seed manager. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found @@ -83,7 +83,7 @@ def set_seed_states(parallel_mode: ParallelMode, state: Tensor): """Sets the state of the seed manager for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. state (:class:`torch.Tensor`): the state to be set. Raises: @@ -100,7 +100,7 @@ def sync_states(): @contextmanager def seed(parallel_mode: ParallelMode): - """ A context for seed switch + """A context for seed switch Examples: @@ -161,7 +161,8 @@ def with_seed(func, parallel_mode: ParallelMode): def moe_set_seed(seed): if torch.cuda.is_available(): - from colossalai.core import global_context as gpc + from colossalai.legacy.core import global_context as gpc + global_rank = gpc.get_global_rank() diff_seed = seed + global_rank add_seed(ParallelMode.TENSOR, diff_seed, True) diff --git a/colossalai/context/random/seed_manager.py b/colossalai/legacy/context/random/seed_manager.py similarity index 77% rename from colossalai/context/random/seed_manager.py rename to colossalai/legacy/context/random/seed_manager.py index 956f9001200d8706bbd45e1c9b09a175ff10b82d..c90e849631a15cdbca8449c34e3763aad7b5de22 100644 --- a/colossalai/context/random/seed_manager.py +++ b/colossalai/legacy/context/random/seed_manager.py @@ -4,7 +4,7 @@ import torch from torch import Tensor -from colossalai.context.parallel_mode import ParallelMode +from colossalai.legacy.context.parallel_mode import ParallelMode class SeedManager: @@ -36,20 +36,20 @@ class SeedManager: """Sets the state of the seed manager for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. state (:class:`torch.Tensor`): the state to be set. Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager. """ - assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager' + assert parallel_mode in self._seed_states, f"Parallel mode {parallel_mode} is not found in the seed manager" self._seed_states[parallel_mode] = state def set_mode(self, parallel_mode: ParallelMode): """Sets the current mode of the seed manager. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. """ if self.current_mode: # save the current state for current mode @@ -63,17 +63,17 @@ class SeedManager: """Adds a seed to the seed manager for `parallel_mode`. Args: - parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. seed (int): The seed to be added. overwrite (bool, optional): Whether allows to overwrite the seed that has been set already Raises: - AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.context.ParallelMode` + AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.legacy.context.ParallelMode` or the seed for `parallel_mode` has been added. """ - assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided' + assert isinstance(parallel_mode, ParallelMode), "A valid ParallelMode must be provided" if overwrite is False: - assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added' + assert parallel_mode not in self._seed_states, f"The seed for {parallel_mode} has been added" elif parallel_mode in self._seed_states: print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True) diff --git a/colossalai/legacy/core.py b/colossalai/legacy/core.py new file mode 100644 index 0000000000000000000000000000000000000000..80b6e4d25bd242067171d3ce15fe470577b1ffcb --- /dev/null +++ b/colossalai/legacy/core.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from colossalai.legacy.context.parallel_context import global_context + +__all__ = ["global_context"] diff --git a/colossalai/legacy/engine/__init__.py b/colossalai/legacy/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..581760748a169bcd8c9aa2c55edad5f2956a3cf9 --- /dev/null +++ b/colossalai/legacy/engine/__init__.py @@ -0,0 +1,4 @@ +from ._base_engine import Engine +from .gradient_handler import * + +__all__ = ["Engine"] diff --git a/colossalai/engine/_base_engine.py b/colossalai/legacy/engine/_base_engine.py similarity index 83% rename from colossalai/engine/_base_engine.py rename to colossalai/legacy/engine/_base_engine.py index ff8979d82401931b04649ffadff615932a1e1b37..0954e2be3eb1e28101ab032fb8b5dd74c915269e 100644 --- a/colossalai/engine/_base_engine.py +++ b/colossalai/legacy/engine/_base_engine.py @@ -8,10 +8,16 @@ from torch import Tensor from torch.nn import Module from torch.nn.modules.loss import _Loss -from colossalai.engine.gradient_handler import BaseGradientHandler -from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.engine.gradient_handler import BaseGradientHandler +from colossalai.legacy.engine.schedule import ( + BaseSchedule, + InterleavedPipelineSchedule, + NonPipelineSchedule, + PipelineSchedule, +) +from colossalai.legacy.zero.gemini import BaseOpHook, register_ophooks_recursively from colossalai.logging import get_dist_logger -from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively class Engine: @@ -21,7 +27,7 @@ class Engine: Args: model (``torch.nn.Module``): The neural network model. - optimizer (``colossalai.nn.optimizer.ColossalaiOptimizer``): Optimizer for updating the parameters. + optimizer (``colossalai.interface.OptimizerWrapper``): Optimizer for updating the parameters. criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss. gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handler used in backward. clip_grad_norm (float, optional): The norm of gradient clipping. @@ -53,15 +59,17 @@ class Engine: `Run resnet cifar10 with engine `_. """ - def __init__(self, - model: Module, - optimizer: "ColossalaiOptimizer", - criterion: Optional[_Loss] = None, - gradient_handlers: Optional[List[BaseGradientHandler]] = None, - clip_grad_norm: float = 0.0, - ophook_list: Optional[List[BaseOpHook]] = None, - verbose: bool = True, - schedule: Optional[BaseSchedule] = None): + def __init__( + self, + model: Module, + optimizer: "OptimizerWrapper", + criterion: Optional[_Loss] = None, + gradient_handlers: Optional[List[BaseGradientHandler]] = None, + clip_grad_norm: float = 0.0, + ophook_list: Optional[List[BaseOpHook]] = None, + verbose: bool = True, + schedule: Optional[BaseSchedule] = None, + ): self._model = model self._optimizer = optimizer self._criterion = criterion @@ -70,7 +78,7 @@ class Engine: self._logger = get_dist_logger() # state - self.training = True # default + self.training = True # default # build gradient handler if gradient_handlers: @@ -85,8 +93,9 @@ class Engine: # build schedule if schedule: - assert isinstance(schedule, BaseSchedule), \ - f'expected schedule to be of type BaseSchedule, but got {type(schedule)}' + assert isinstance( + schedule, BaseSchedule + ), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}" self._schedule = schedule else: self._schedule = NonPipelineSchedule() @@ -143,15 +152,13 @@ class Engine: logger.warning(f"removing hooks is currently not supported") def zero_grad(self): - """Set the gradient of parameters to zero - """ + """Set the gradient of parameters to zero""" self.optimizer.zero_grad() def step(self): - """Execute parameter update - """ + """Execute parameter update""" self._all_reduce_gradients() - self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm) + self.optimizer.clip_grad_by_norm(self._clip_grad_norm) return self.optimizer.step() def backward(self, loss: Tensor): @@ -186,8 +193,7 @@ class Engine: return self.model(*args, **kwargs) def _all_reduce_gradients(self): - """Handles all-reduce operations of gradients across different parallel groups. - """ + """Handles all-reduce operations of gradients across different parallel groups.""" for handler in self._gradient_handlers: handler.handle_gradient() @@ -202,13 +208,11 @@ class Engine: return output, label, loss def train(self): - """Sets the model to training mode. - """ + """Sets the model to training mode.""" self.training = True self._model.train() def eval(self): - """Sets the model to evaluation mode. - """ + """Sets the model to evaluation mode.""" self.training = False self._model.eval() diff --git a/colossalai/legacy/engine/gradient_accumulation/__init__.py b/colossalai/legacy/engine/gradient_accumulation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0835318ed9fedd37769bf79caac16524d35fbb9 --- /dev/null +++ b/colossalai/legacy/engine/gradient_accumulation/__init__.py @@ -0,0 +1,62 @@ +from typing import Iterable, List + +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + +from colossalai.legacy.engine import BaseGradientHandler + +from ._gradient_accumulation import ( + GradAccumDataloader, + GradAccumGradientHandler, + GradAccumLrSchedulerByStep, + GradAccumOptimizer, +) + +__all__ = [ + "accumulate_gradient", + "GradAccumDataloader", + "GradAccumOptimizer", + "GradAccumLrSchedulerByStep", + "GradAccumGradientHandler", +] + + +def accumulate_gradient( + model: nn.Module, + optimizer: Optimizer, + dataloader: Iterable, + accumulate_size: int, + gradient_handlers: List[BaseGradientHandler] = None, + lr_scheduler: _LRScheduler = None, +): + r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation. + + Args: + model (:class:`torch.nn.Module`): your model object for gradient accumulation. + optimizer (:class:`torch.optim.Optimizer`): your optimizer object for gradient accumulation. + dataloader (:class:`torch.utils.data.DataLoader` or iterable objects): + your dataloader object, would be called like iter(dataloader) + accumulate_size (int): the number of steps to accumulate gradients + gradient_handlers (List[:class:`colossalai.legacy.engine.BaseGradientHandler`]): + list of gradient handler objects. Default is None. + lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`): + your ``lr_scheduler`` object for gradient accumulation. Defaults to None. + + More details about `gradient_handlers` could be found in + `Gradient_handler `_. + + More details about `lr_scheduler` could be found + `lr_scheduler `_. and + `how to adjust learning rate `_. + """ + optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model) + dataloader = GradAccumDataloader(dataloader, accumulate_size=accumulate_size) + + if gradient_handlers is not None: + gradient_handlers = [GradAccumGradientHandler(handler, accumulate_size) for handler in gradient_handlers] + + if lr_scheduler is not None: + lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size) + + return optimizer, dataloader, gradient_handlers, lr_scheduler diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py similarity index 95% rename from colossalai/engine/gradient_accumulation/_gradient_accumulation.py rename to colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py index cf66be1cd8218e8dc35177f9d69115ec1553c687..9de0f6c0ffd989cb8dbd1827a667a33b0fc7bf26 100644 --- a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py +++ b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py @@ -10,12 +10,12 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from colossalai.engine import BaseGradientHandler -from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.engine import BaseGradientHandler from colossalai.utils import conditional_context -class GradAccumOptimizer(ColossalaiOptimizer): +class GradAccumOptimizer(OptimizerWrapper): """A wrapper for the optimizer to enable gradient accumulation by skipping the steps before accumulation size is reached. @@ -74,7 +74,7 @@ class GradAccumOptimizer(ColossalaiOptimizer): if self.accumulate_step < self.accumulate_size: pass else: - self.optim.clip_grad_norm(model, max_norm) + self.optim.clip_grad_by_norm(max_norm) def backward(self, loss: Tensor) -> None: """Execute backward pass. @@ -262,7 +262,7 @@ class GradAccumGradientHandler: before accumulation size is reached. Args: - grad_handler (:class:`colossalai.engine.BaseGradientHandler`): + grad_handler (:class:`colossalai.legacy.engine.BaseGradientHandler`): Your ``gradient_handler`` object for gradient accumulation, would be called when achieving `accumulate_size`. accumulate_size (int): The number of steps to accumulate gradients. @@ -272,8 +272,9 @@ class GradAccumGradientHandler: """ def __init__(self, grad_handler: BaseGradientHandler, accumulate_size: int) -> None: - assert isinstance(grad_handler, BaseGradientHandler), \ - f'expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}' + assert isinstance( + grad_handler, BaseGradientHandler + ), f"expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}" self.grad_handler = grad_handler self.accumulate_size = accumulate_size self.accumulate_step = 0 diff --git a/colossalai/legacy/engine/gradient_handler/__init__.py b/colossalai/legacy/engine/gradient_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78928b138842c7becd1739ca22771b8140403dd1 --- /dev/null +++ b/colossalai/legacy/engine/gradient_handler/__init__.py @@ -0,0 +1,15 @@ +from ._base_gradient_handler import BaseGradientHandler +from ._data_parallel_gradient_handler import DataParallelGradientHandler +from ._moe_gradient_handler import MoeGradientHandler +from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler +from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler +from ._zero_gradient_handler import ZeROGradientHandler + +__all__ = [ + "BaseGradientHandler", + "DataParallelGradientHandler", + "ZeROGradientHandler", + "PipelineSharedModuleGradientHandler", + "MoeGradientHandler", + "SequenceParallelGradientHandler", +] diff --git a/colossalai/engine/gradient_handler/_base_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py similarity index 98% rename from colossalai/engine/gradient_handler/_base_gradient_handler.py rename to colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py index 7d96dd8a88a63d9f0c40ceefb99bf2809a37662d..e594bb00f96b177887ca80bdd3831071614d7055 100644 --- a/colossalai/engine/gradient_handler/_base_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py @@ -22,4 +22,3 @@ class BaseGradientHandler(ABC): """A method to accumulate gradients across different parallel groups. Users should write their own functions or just use the functions in pre-defined subclasses. """ - pass diff --git a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py similarity index 83% rename from colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py rename to colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py index 5cc7169c5a9f630dcb9e1b981f33c3fb35548cc0..3782adaf718727afe588f207feff354ce44d7d71 100644 --- a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py @@ -1,7 +1,7 @@ -from colossalai.core import global_context as gpc -from colossalai.registry import GRADIENT_HANDLER +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.registry import GRADIENT_HANDLER -from ...context.parallel_mode import ParallelMode from ._base_gradient_handler import BaseGradientHandler from .utils import bucket_allreduce @@ -20,8 +20,7 @@ class DataParallelGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in a data parallel group. - """ + """A method running a all-reduce operation in a data parallel group.""" # TODO: add memory buffer if gpc.data_parallel_size > 1: bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA)) diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py similarity index 83% rename from colossalai/engine/gradient_handler/_moe_gradient_handler.py rename to colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py index b499345d4e184662b3a242aa71d24859ad843c7c..6a7224cff7bdcc2b509e4a7d580e07e5f2de8934 100644 --- a/colossalai/engine/gradient_handler/_moe_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py @@ -1,9 +1,9 @@ from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.core import global_context as gpc -from colossalai.registry import GRADIENT_HANDLER +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.utils.moe import get_moe_epsize_param_dict -from ...context.parallel_mode import ParallelMode from ._base_gradient_handler import BaseGradientHandler from .utils import bucket_allreduce @@ -42,5 +42,6 @@ class MoeGradientHandler(BaseGradientHandler): for ep_size in epsize_param_dict: if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: - bucket_allreduce(param_list=epsize_param_dict[ep_size], - group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group) + bucket_allreduce( + param_list=epsize_param_dict[ep_size], group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group + ) diff --git a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py similarity index 77% rename from colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py rename to colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py index 5b49a9c0360dca600b8f1226ce0334f959e2265b..3a65f65abf73a4b5e4b2bcbc5204ecc5ac53a3bf 100644 --- a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py @@ -6,8 +6,8 @@ import torch import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from colossalai.core import global_context as gpc -from colossalai.registry import GRADIENT_HANDLER +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.registry import GRADIENT_HANDLER from ._base_gradient_handler import BaseGradientHandler @@ -26,17 +26,21 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in sub pipeline parallel groups. - """ + """A method running a all-reduce operation in sub pipeline parallel groups.""" if gpc.pipeline_parallel_size > 1: # bucketize and all-reduce buckets = defaultdict(lambda: defaultdict(list)) # Pack the buckets. for param in self._model.parameters(): - group = getattr(param, 'pipeline_shared_module_pg', None) - if param.requires_grad and group is not None and ( - (hasattr(param, 'colo_attr') and not param.colo_attr.saved_grad.is_null()) - or param.grad is not None): + group = getattr(param, "pipeline_shared_module_pg", None) + if ( + param.requires_grad + and group is not None + and ( + (hasattr(param, "colo_attr") and not param.colo_attr.saved_grad.is_null()) + or param.grad is not None + ) + ): tp = param.data.type() buckets[group][tp].append(param) @@ -44,7 +48,7 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler): for group, group_buckets in buckets.items(): for tp, bucket in group_buckets.items(): grads = [ - param.colo_attr.grad_payload if hasattr(param, 'colo_attr') else param.grad.data + param.colo_attr.grad_payload if hasattr(param, "colo_attr") else param.grad.data for param in bucket ] coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device()) diff --git a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py similarity index 83% rename from colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py rename to colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py index ea4f0fbb1c718965deae37fc0a148aafca3d104a..6d507bcc0269c69ab6981355b5d81422b730f2c8 100644 --- a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py @@ -1,7 +1,7 @@ -from colossalai.core import global_context as gpc -from colossalai.registry import GRADIENT_HANDLER +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.registry import GRADIENT_HANDLER -from ...context.parallel_mode import ParallelMode from ._base_gradient_handler import BaseGradientHandler from .utils import bucket_allreduce @@ -20,7 +20,6 @@ class SequenceParallelGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in a data parallel group. - """ + """A method running a all-reduce operation in a data parallel group.""" if gpc.get_world_size(ParallelMode.SEQUENCE_DP) > 1: bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.SEQUENCE_DP)) diff --git a/colossalai/engine/gradient_handler/_zero_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py similarity index 90% rename from colossalai/engine/gradient_handler/_zero_gradient_handler.py rename to colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py index 19fd1e97f86f035826666f766fa1983eb9aae2cc..63ec6e70ba062e5591d2335b541afe5638728c79 100644 --- a/colossalai/engine/gradient_handler/_zero_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py @@ -1,4 +1,4 @@ -from colossalai.registry import GRADIENT_HANDLER +from colossalai.legacy.registry import GRADIENT_HANDLER from ._base_gradient_handler import BaseGradientHandler @@ -16,6 +16,5 @@ class ZeROGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in a data parallel group. - """ + """A method running a all-reduce operation in a data parallel group.""" self._optimizer.sync_grad() diff --git a/colossalai/engine/gradient_handler/utils.py b/colossalai/legacy/engine/gradient_handler/utils.py similarity index 100% rename from colossalai/engine/gradient_handler/utils.py rename to colossalai/legacy/engine/gradient_handler/utils.py diff --git a/colossalai/legacy/engine/schedule/__init__.py b/colossalai/legacy/engine/schedule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..017231a9b4a8a9965d96d429647ff9cf84536bdb --- /dev/null +++ b/colossalai/legacy/engine/schedule/__init__.py @@ -0,0 +1,5 @@ +from ._base_schedule import BaseSchedule +from ._non_pipeline_schedule import NonPipelineSchedule +from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape + +__all__ = ["BaseSchedule", "NonPipelineSchedule", "PipelineSchedule", "InterleavedPipelineSchedule", "get_tensor_shape"] diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/legacy/engine/schedule/_base_schedule.py similarity index 83% rename from colossalai/engine/schedule/_base_schedule.py rename to colossalai/legacy/engine/schedule/_base_schedule.py index a2d50041127ace67726f1390fbb58331925e8af5..4a3ccfda1bb5a522819da6c7809ead96a02f847c 100644 --- a/colossalai/engine/schedule/_base_schedule.py +++ b/colossalai/legacy/engine/schedule/_base_schedule.py @@ -47,7 +47,8 @@ class BaseSchedule(ABC): data = {k: self._move_tensor(v) for k, v in data.items()} else: raise TypeError( - f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}" + ) return data def _get_batch_size(self, data): @@ -72,7 +73,7 @@ class BaseSchedule(ABC): Tuple (:class:`Tensor`, :class:`torch.Tensor`): A tuple of (data, label). """ if data_iter is None: - raise RuntimeError('Dataloader is not defined.') + raise RuntimeError("Dataloader is not defined.") batch_data = next(data_iter) if to_gpu: @@ -81,27 +82,26 @@ class BaseSchedule(ABC): return batch_data def pre_processing(self, engine): - """To perform actions before running the schedule. - """ - pass + """To perform actions before running the schedule.""" @abstractmethod - def forward_backward_step(self, - engine, - data_iter: Iterable, - forward_only: bool, - return_loss: bool = True, - return_output_label: bool = True): + def forward_backward_step( + self, + engine, + data_iter: Iterable, + forward_only: bool, + return_loss: bool = True, + return_output_label: bool = True, + ): """The process function over a batch of dataset for training or evaluation. Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. + engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference. data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader). forward_only (bool): If True, the process won't include backward. return_loss (bool, optional): If False, the loss won't be returned. return_output_label (bool, optional): If False, the output and label won't be returned. """ - pass @staticmethod def _call_engine(engine, inputs): @@ -113,13 +113,14 @@ class BaseSchedule(ABC): return engine(**inputs) else: TypeError( - f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}") + f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}" + ) @staticmethod def _call_engine_criterion(engine, outputs, labels): - assert isinstance(outputs, - (torch.Tensor, list, tuple, - dict)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}' + assert isinstance( + outputs, (torch.Tensor, list, tuple, dict) + ), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}" if isinstance(outputs, torch.Tensor): outputs = (outputs,) if isinstance(labels, torch.Tensor): @@ -134,6 +135,8 @@ class BaseSchedule(ABC): elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)): raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}") else: - raise TypeError(f"Expected model outputs and labels to be of type torch.Tensor ' \ + raise TypeError( + f"Expected model outputs and labels to be of type torch.Tensor ' \ '(which is auto-converted to tuple), list, tuple, or dict, ' \ - 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)") + 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)" + ) diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py similarity index 78% rename from colossalai/engine/schedule/_non_pipeline_schedule.py rename to colossalai/legacy/engine/schedule/_non_pipeline_schedule.py index b9239d928a7ba4e9471071f1c4e08c8443f5edb1..08c6cfd60f28480cf2ea7493e4b6a47a081c93fb 100644 --- a/colossalai/engine/schedule/_non_pipeline_schedule.py +++ b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py @@ -37,24 +37,27 @@ class NonPipelineSchedule(BaseSchedule): if data_process_func: sig = inspect.signature(data_process_func) - assert len(sig.parameters) == 1, \ - 'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \ - 'which is a tuple of tensors for the current batch, ' \ - 'i.e. data_process_func(dataloader_output).' + assert len(sig.parameters) == 1, ( + "The data_process_func only takes in one parameter for NonPipelineSchedule, " + "which is a tuple of tensors for the current batch, " + "i.e. data_process_func(dataloader_output)." + ) super().__init__(data_process_func) - def forward_backward_step(self, - engine, - data_iter: Iterable, - forward_only: bool = False, - return_loss: bool = True, - return_output_label: bool = True): + def forward_backward_step( + self, + engine, + data_iter: Iterable, + forward_only: bool = False, + return_loss: bool = True, + return_output_label: bool = True, + ): """The process function that loads a batch of dataset and feeds it to the model. The returned labels and loss will None if :attr:`return_loss` is False. Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. + engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference. data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will be executed. @@ -64,8 +67,9 @@ class NonPipelineSchedule(BaseSchedule): Returns: Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. """ - assert forward_only or return_loss, \ - "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." batch_data = self.load_batch(data_iter) if self.data_process_func: data, label = self.data_process_func(batch_data) diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc5040f69838782cd47293e5a6fda594dd09470 --- /dev/null +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py @@ -0,0 +1,851 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import inspect +from typing import Callable, List, Tuple, Union + +import torch.cuda + +import colossalai.legacy.communication as comm +from colossalai.legacy.amp.naive_amp import NaiveAMPModel +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank +from colossalai.logging import get_dist_logger +from colossalai.utils.cuda import get_current_device + +from ._base_schedule import BaseSchedule + + +def get_tensor_shape(): + if hasattr(gpc.config, "TENSOR_SHAPE"): + return gpc.config.TENSOR_SHAPE + + if not gpc.is_initialized(ParallelMode.PIPELINE): + return None + + if ( + hasattr(gpc.config, "SEQ_LENGTH") + and hasattr(gpc.config, "GLOBAL_BATCH_SIZE") + and hasattr(gpc.config, "GLOBAL_BATCH_SIZE") + and hasattr(gpc.config, "HIDDEN_SIZE") + ): + if gpc.is_initialized(ParallelMode.DATA): + dp_size = gpc.get_world_size(ParallelMode.DATA) + else: + dp_size = 1 + if gpc.is_initialized(ParallelMode.SEQUENCE): + seq_size = gpc.get_world_size(ParallelMode.SEQUENCE) + else: + seq_size = 1 + + tensor_shape = ( + gpc.config.SEQ_LENGTH // seq_size, + gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, + gpc.config.HIDDEN_SIZE, + ) + return tensor_shape + else: + return None + + +def pack_return_tensors(return_tensors): + output, label = tuple(zip(*return_tensors)) + if isinstance(output[0], torch.Tensor): + output = torch.cat(output, dim=0) + elif isinstance(output[0], (list, tuple)): + output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) + else: + raise TypeError(f"Output of model must be tensor or list/tuple of tensors") + if isinstance(label[0], torch.Tensor): + label = torch.cat(label, dim=0) + else: + merged_label = {k: [] for k in label[0].keys()} + for d in label: + for k, v in d.items(): + merged_label[k].append(v) + label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()} + return output, label + + +class PipelineSchedule(BaseSchedule): + """A helper schedule class for pipeline parallelism running environment. + It uses non-interleaved 1F1B strategy. Other properties are similar as + :class:`NonPipelineSchedule`. + + Args: + num_microbatches (int): The number of microbatches. + data_process_func (Callable, optional): + The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + + Example: + + # this shows an example of customized data_process_func + def data_process_func(stage_output, dataloader_output): + output1, output2 = stage_output + item1, item2, item3 = dataloader_output + + # assume item2 is not needed + data = (output1, output2, item1) + label = item3 + return data, label + + """ + + def __init__( + self, + num_microbatches, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False, + ): + # we need to make sure that the signature of the data_process_func is valid + if data_process_func: + sig = inspect.signature(data_process_func) + assert len(sig.parameters) == 2, ( + "The data_process_func only takes in two parameters for NonPipelineSchedule, " + "which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, " + "i.e. data_process_func(stage_output, dataloader_output)." + ) + + super().__init__(data_process_func=data_process_func) + + assert num_microbatches > 0, f"expected num_microbatches to be larger then 1, but got {num_microbatches}" + + self.num_microbatches = num_microbatches + self.dtype = torch.float + assert not isinstance( + tensor_shape, int + ), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]." + if tensor_shape is None: + self.tensor_shape = tensor_shape + elif isinstance(tensor_shape, torch.Size): + self.tensor_shape = tensor_shape + else: + self.tensor_shape = torch.Size(tensor_shape) + self.scatter_gather_tensors = False + if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1: + self.scatter_gather_tensors = scatter_gather_tensors + self._logger = get_dist_logger() + + # cache for the batch data + self.batch_data = None + + def load_batch(self, data_iter): + # Pipeline schedule just puts data in memory + batch_data = super().load_batch(data_iter, to_gpu=False) + self.microbatch_offset = 0 + assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + self.batch_data = batch_data + + def _get_data_slice(self, data, offset): + if isinstance(data, torch.Tensor): + return data[offset : offset + self.microbatch_size] + elif isinstance(data, (list, tuple)): + data_dict = {} + for element in data: + if isinstance(element, dict): + data_dict.update({k: v[offset : offset + self.microbatch_size] for k, v in element.items()}) + elif data_dict: + data_dict["label"] = element[offset : offset + self.microbatch_size] + if data_dict: + return data_dict + return [val[offset : offset + self.microbatch_size] for val in data] + elif isinstance(data, dict): + return {k: v[offset : offset + self.microbatch_size] for k, v in data.items()} + else: + raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + + def load_micro_batch(self): + micro_batch_data = self._get_data_slice(self.batch_data, self.microbatch_offset) + self.microbatch_offset += self.microbatch_size + return self._move_to_device(micro_batch_data) + + def pre_processing(self, engine): + from colossalai.legacy.zero import ShardedModelV2 + + # TODO: remove this after testing new zero with pipeline parallelism + model = engine.model + if isinstance(model, NaiveAMPModel): + self.dtype = torch.half + model = model.model + if isinstance(model, ShardedModelV2): + self.dtype = torch.half + model = model.module + # sig = inspect.signature(model.forward) + # for p in sig.parameters.values(): + # assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' + + @staticmethod + def _call_engine(model, data): + if data is not None: + if isinstance(data, torch.Tensor): + return model(data) + elif isinstance(data, (list, tuple)): + return model(*data) + elif isinstance(data, dict): + stage_output = None + if "stage_output" in data: + stage_output = data.pop("stage_output") + if stage_output is None: + return model(**data) + elif isinstance(stage_output, torch.Tensor): + return model(stage_output, **data) + elif isinstance(stage_output, (tuple, list)): + return model(*stage_output, **data) + else: + raise TypeError( + f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}" + ) + else: + raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + + def _get_actual_forward_func(self, module): + if isinstance(module, NaiveAMPModel): + sig = inspect.signature(module.model.forward) + elif hasattr(module, "colo_attr"): + sig = inspect.signature(module.module.forward) + else: + sig = inspect.signature(module.forward) + return sig + + def _get_data_label_for_current_step(self, stage_output, micro_batch_data, criterion, model): + if self.data_process_func: + # use customized function to get data and label + data, label = self.data_process_func(stage_output, micro_batch_data) + else: + if isinstance(micro_batch_data, (tuple, list)): + if gpc.is_first_rank(ParallelMode.PIPELINE): + # for the first stage, we use the data from the + # dataloader output by default + data, label = micro_batch_data + else: + # for non-first stage, we use the output passed + # by the previous as the model input + data = stage_output + _, label = micro_batch_data + elif isinstance(micro_batch_data, dict): + data = {} + data["stage_output"] = stage_output + if "label" in micro_batch_data: + label = micro_batch_data.pop("label") + else: + label = None + load_data = micro_batch_data + data.update(load_data) + return data, label + + def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): + """Forward step for passed-in model. If it is the first stage, the input tensor + is obtained from data_iterator, otherwise the passed-in input_obj is used. + Returns output tensor. This is a helper function and can be ignored by users. + + Args: + engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. + return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. + return_output_label (bool, optional): Whether returns output labels. + accum_loss (optional): Where accumulated loss stores. + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. + """ + micro_batch_data = self.load_micro_batch() + + data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, engine.model) + + output_obj = self._call_engine(engine.model, data) + + if gpc.is_last_rank(ParallelMode.PIPELINE): + if return_output_label: + return_tensors.append((output_obj, label)) + if accum_loss is not None: + loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches + accum_loss.add_(loss_reduced.detach()) + return loss_reduced + else: + # forward only, it's useless since backward is not needed + return output_obj + else: + if isinstance(output_obj, torch.Tensor): + self._logger.debug( + f"Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}" + ) + return output_obj + + def _backward_step(self, engine, input_obj, output_obj, output_obj_grad): + """Backward step through the passed-in output tensor. If it is the last stage, the + output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. + Returns the gradients with respect to the input tensor (None if first stage). + This is a helper function and can be ignored by users. + + Args: + engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage. + output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage. + output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor. + """ + + # Retain the grad on the input_obj. + if input_obj is not None: + if isinstance(input_obj, torch.Tensor): + input_obj.retain_grad() + else: + for in_tensor in input_obj: + if in_tensor is not None: + in_tensor.retain_grad() + # Backward pass. + if output_obj_grad is None: + engine.backward(output_obj) + else: + engine.backward_by_grad(output_obj, output_obj_grad) + + # Collect the grad of the input_obj. + input_obj_grad = None + if input_obj is not None: + if isinstance(input_obj, torch.Tensor): + input_obj_grad = input_obj.grad + else: + input_obj_grad = [] + for in_tensor in input_obj: + input_obj_grad.append(in_tensor.grad) + + return input_obj_grad + + def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): + """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. + Returns a tuple with losses if the last stage, an empty tuple otherwise. + + Args: + engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). + forward_only (bool, optional): + Whether run forward step only. Default is false. If true, no backward will be run. + return_loss (bool, optional): Whether returns the loss value. Default is true. + return_output_label (bool, optional): If False, the output and label won't be returned. + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + """ + + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." + self.load_batch(data_iter) + num_warmup_microbatches = ( + gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1 + ) + num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) + num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches + + # Input, output tensors only need to be saved when doing backward passes + input_objs = None + output_objs = None + if not forward_only: + input_objs = [] + output_objs = [] + return_tensors = [] + if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + # Used for tensor meta information communication + ft_shapes = self.tensor_shape + bt_shapes = None + fs_checker = self.tensor_shape is None + + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + if not gpc.is_first_rank(ParallelMode.PIPELINE): + ft_shapes = comm.recv_obj_meta(ft_shapes) + input_obj = comm.recv_forward( + ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + if isinstance(output_obj, torch.Tensor): + bt_shapes = output_obj.shape + else: + bt_shapes = [] + for out_tensor in output_obj: + bt_shapes.append(out_tensor.shape) + fs_checker = comm.send_obj_meta(output_obj, fs_checker) + comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) + + if not forward_only: + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Before running 1F1B, need to receive first forward tensor. + # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_microbatches_remaining > 0: + if not gpc.is_first_rank(ParallelMode.PIPELINE): + ft_shapes = comm.recv_obj_meta(ft_shapes) + input_obj = comm.recv_forward( + ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + last_iteration = i == (num_microbatches_remaining - 1) + + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) + if forward_only: + comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) + + if not last_iteration: + input_obj = comm.recv_forward( + ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) + + else: + output_obj_grad = comm.send_forward_recv_backward( + output_obj, bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) + + # Add input_obj and output_obj to end of list. + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) + else: + input_obj = comm.send_backward_recv_forward( + input_obj_grad, ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_warmup_microbatches): + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + output_obj_grad = comm.recv_backward( + bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) + + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + + comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) + + if len(return_tensors) > 0: + output, label = pack_return_tensors(return_tensors) + return output, label, accum_loss + else: + return None, None, accum_loss + + +class InterleavedPipelineSchedule(PipelineSchedule): + def __init__( + self, + num_microbatches: int, + num_model_chunks: int, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False, + ): + """A helper schedule class for pipeline parallelism running environment. + It uses interleaved 1F1B strategy. Other properties are similar as + :class:`NonPipelineSchedule`. + + Args: + num_microbatches (int): The number of microbatches. + num_model_chunks (int): The number of model chunks. + data_process_func (Callable, optional): + The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + """ + assert ( + num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0 + ), "num_microbatches must be an integer multiple of pipeline parallel world size" + assert ( + isinstance(num_model_chunks, int) and num_model_chunks > 0 + ), f"expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}" + super().__init__( + num_microbatches, + data_process_func=data_process_func, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather_tensors, + ) + gpc.set_virtual_pipeline_parallel_size(num_model_chunks) + gpc.set_virtual_pipeline_parallel_rank(0) + self.num_model_chunks = num_model_chunks + + def pre_processing(self, engine): + from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 + + if isinstance(engine.model, ShardedModelV2): + self.dtype = torch.half + elif isinstance(engine.model[0], NaiveAMPModel): + self.dtype = torch.half + for model in engine.model: + if isinstance(model, NaiveAMPModel): + model = model.model + sig = inspect.signature(model.forward) + for p in sig.parameters.values(): + assert p.kind != inspect.Parameter.VAR_POSITIONAL, "*args is not supported" + + def load_batch(self, data_iter): + super().load_batch(data_iter) + # overwrite microbatch_offset, since model chunks load the same microbatch, and should tract the offset + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + + def load_micro_batch(self, model_chunk_id): + data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id]) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return self._move_to_device(data) + + def _forward_step( + self, engine, model_chunk_id, input_obj, return_tensors, return_output_label=True, accum_loss=None + ): + """Forward step for passed-in model. If it is the first stage, the input tensor + is obtained from data_iterator, otherwise the passed-in input_obj is used. + Returns output tensor. This is a helper function and can be ignored by users. + + Args: + engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference. + model_chunk_id (int): The id of model chunks. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. + return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. + return_output_label (bool, optional): Whether returns output labels. + accum_loss (optional): Where accumulated loss stores. + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. + """ + micro_batch_data = self.load_micro_batch(model_chunk_id) + data, label = self._get_data_label_for_current_step( + input_obj, micro_batch_data, engine.criterion, engine.model[model_chunk_id] + ) + + output_obj = self._call_engine(engine.model[model_chunk_id], data) + + if gpc.is_pipeline_last_stage(): + if return_output_label: + return_tensors.append((output_obj, label)) + if accum_loss is not None: + loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches + accum_loss.add_(loss_reduced.detach()) + return loss_reduced + else: + # forward only, it's useless since backward is not needed + return output_obj + else: + if isinstance(output_obj, torch.Tensor): + self._logger.debug( + f"Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}" + ) + return output_obj + + def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): + """Run interleaved 1F1B schedule (model split into model chunks), with + communication between pipeline stages as needed. + + Args: + engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference. + data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). + forward_only (bool, optional): + Whether run forward step only. Default is false. If true, no backward will be run. + return_loss (bool, optional): Whether returns the loss value. Default is true. + return_output_label (bool, optional): If False, the output and label won't be returned. + + Returns: + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + The loss would be returned only in the last stage. + """ + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." + self.load_batch(data_iter) + model = engine.model + input_objs = [[] for _ in range(len(model))] + output_objs = [[] for _ in range(len(model))] + return_tensors = [] + if not forward_only: + output_obj_grads = [[] for _ in range(len(model))] + if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + + # Used for obj meta information communication + input_obj_shapes = [self.tensor_shape for _ in range(len(model))] + output_obj_shapes = [None for _ in range(len(model))] + send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))] + + pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + # Compute number of warmup and remaining microbatches. + num_model_chunks = len(model) + num_microbatches = self.num_microbatches * num_model_chunks + all_warmup_microbatches = False + if forward_only: + num_warmup_microbatches = num_microbatches + else: + # Run all forward passes and then all backward passes if number of + # microbatches is just the number of pipeline stages. + # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on + # all workers, followed by more microbatches after depending on + # stage ID (more forward passes for earlier stages, later stages can + # immediately start with 1F1B). + if self.num_microbatches == pipeline_parallel_size: + num_warmup_microbatches = num_microbatches + all_warmup_microbatches = True + else: + num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + num_microbatches_remaining = num_microbatches - num_warmup_microbatches + + def get_model_chunk_id(microbatch_id, forward): + """Helper method to get the model chunk ID given the iteration number.""" + microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) + model_chunk_id = microbatch_id_in_group // pipeline_parallel_size + if not forward: + model_chunk_id = num_model_chunks - model_chunk_id - 1 + return model_chunk_id + + def _forward_step_helper(microbatch_id): + """Helper method to run forward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + forward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) + gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) + + # forward step + if gpc.is_pipeline_first_stage(): + if len(input_objs[model_chunk_id]) == len(output_objs[model_chunk_id]): + input_objs[model_chunk_id].append(None) + input_obj = input_objs[model_chunk_id][-1] + output_obj = self._forward_step( + engine, + model_chunk_id, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss, + ) + output_objs[model_chunk_id].append(output_obj) + + # if forward-only, no need to save tensors for a backward pass + if forward_only: + input_objs[model_chunk_id].pop() + output_objs[model_chunk_id].pop() + + return output_obj + + def _backward_step_helper(microbatch_id): + """Helper method to run backward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + backward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) + gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) + + if gpc.is_pipeline_last_stage(): + if len(output_obj_grads[model_chunk_id]) == 0: + output_obj_grads[model_chunk_id].append(None) + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + output_obj_grad = output_obj_grads[model_chunk_id].pop(0) + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) + + return input_obj_grad + + # Run warmup forward passes. + gpc.set_virtual_pipeline_parallel_rank(0) + if not gpc.is_pipeline_first_stage(): + input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0]) + input_objs[0].append( + comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) + ) + + for k in range(num_warmup_microbatches): + model_chunk_id = get_model_chunk_id(k, forward=True) + output_obj = _forward_step_helper(k) + if not gpc.is_pipeline_last_stage(): + if isinstance(output_obj, torch.Tensor): + output_obj_shapes[model_chunk_id] = output_obj.shape + else: + output_obj_shapes[model_chunk_id] = [] + for out_tensor in output_obj: + output_obj_shapes[model_chunk_id].append(out_tensor.shape) + send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta( + output_obj, send_tensor_shape_flags[model_chunk_id] + ) + # Determine if tensor should be received from previous stage. + next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) + recv_prev = True + if gpc.is_pipeline_first_stage(ignore_virtual=True): + if next_forward_model_chunk_id == 0: + recv_prev = False + if k == (num_microbatches - 1): + recv_prev = False + + # Don't send tensor downstream if on last stage. + if gpc.is_pipeline_last_stage(): + output_obj = None + + with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id): + if not gpc.is_pipeline_first_stage(): + input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta( + input_obj_shapes[next_forward_model_chunk_id] + ) + # Send and receive tensors as appropriate (send tensors computed + # in this iteration; receive tensors for next iteration). + input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None + if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches: + input_obj_grad = None + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None + input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward( + output_obj, + input_obj_grad, + input_shape, + output_shape, + recv_prev=recv_prev, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + output_obj_grads[num_model_chunks - 1].append(output_obj_grad) + else: + input_obj = comm.send_forward_recv_forward( + output_obj, + input_shape, + recv_prev=recv_prev, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + input_objs[next_forward_model_chunk_id].append(input_obj) + + # Run 1F1B in steady state. + for k in range(num_microbatches_remaining): + # Forward pass. + forward_k = k + num_warmup_microbatches + output_obj = _forward_step_helper(forward_k) + + # Backward pass. + backward_k = k + input_obj_grad = _backward_step_helper(backward_k) + + # Send output_obj and input_obj_grad, receive input_obj + # and output_obj_grad. + + # Determine if current stage has anything to send in either direction, + # otherwise set obj to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id) + if gpc.is_pipeline_last_stage(): + output_obj = None + + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id) + if gpc.is_pipeline_first_stage(): + input_obj_grad = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if gpc.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) + + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False + ) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None + output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None + # Communicate objs. + input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward( + output_obj, + input_obj_grad, + input_shape, + output_shape, + recv_prev=recv_prev, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + + # Put input_obj and output_obj_grad in data structures in the + # right location. + if recv_prev: + input_objs[next_forward_model_chunk_id].append(input_obj) + if recv_next: + output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad) + + # Run cooldown backward passes (flush out pipeline). + if not forward_only: + if all_warmup_microbatches: + output_obj_grads[num_model_chunks - 1].append( + comm.recv_backward( + output_obj_shapes[num_model_chunks - 1], scatter_gather_tensors=self.scatter_gather_tensors + ) + ) + for k in range(num_microbatches_remaining, num_microbatches): + input_obj_grad = _backward_step_helper(k) + next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) + recv_next = True + if gpc.is_pipeline_last_stage(ignore_virtual=True): + if next_backward_model_chunk_id == (num_model_chunks - 1): + recv_next = False + if k == (num_microbatches - 1): + recv_next = False + output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None + output_obj_grads[next_backward_model_chunk_id].append( + comm.send_backward_recv_backward( + input_obj_grad, + output_shape, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + ) + + if len(return_tensors) > 0: + output, label = pack_return_tensors(return_tensors) + return output, label, accum_loss + else: + return None, None, accum_loss diff --git a/colossalai/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py similarity index 77% rename from colossalai/engine/schedule/_pipeline_schedule_v2.py rename to colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py index 28c58bd82b5c3f6969337c0a718a5698346744d9..867c3dfa819b997466ccded9ac5124b713fc2321 100644 --- a/colossalai/engine/schedule/_pipeline_schedule_v2.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py @@ -5,10 +5,10 @@ from typing import Iterable, Tuple import torch.cuda -import colossalai.communication.p2p_v2 as comm -from colossalai import engine -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +import colossalai.legacy.communication.p2p_v2 as comm +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.engine import Engine from colossalai.utils.cuda import get_current_device from ._pipeline_schedule import PipelineSchedule @@ -21,7 +21,7 @@ def pack_return_tensors(return_tensors): elif isinstance(output[0], (list, tuple)): output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) else: - raise TypeError(f'Output of model must be tensor or list/tuple of tensors') + raise TypeError(f"Output of model must be tensor or list/tuple of tensors") if isinstance(label[0], torch.Tensor): label = torch.cat(label, dim=0) else: @@ -59,17 +59,14 @@ class PipelineScheduleV2(PipelineSchedule): """ - def forward_backward_step(self, - engine: engine.Engine, - data_iter: Iterable, - forward_only=False, - return_loss=True, - return_output_label=True) -> Tuple[torch.Tensor]: + def forward_backward_step( + self, engine: Engine, data_iter: Iterable, forward_only=False, return_loss=True, return_output_label=True + ) -> Tuple[torch.Tensor]: """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Returns a tuple with losses if the last stage, an empty tuple otherwise. Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. + engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference. data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). forward_only (bool, optional): Whether run forward step only. Default is false. If true, no backward will be run. @@ -80,14 +77,15 @@ class PipelineScheduleV2(PipelineSchedule): Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. """ - assert forward_only or return_loss, \ - 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." self.load_batch(data_iter) - # num_warmup_microbatches is the step when not all the processers are working - num_warmup_microbatches = \ - (gpc.get_world_size(ParallelMode.PIPELINE) - - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) + # num_warmup_microbatches is the step when not all the processes are working + num_warmup_microbatches = ( + gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1 + ) num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches @@ -109,11 +107,9 @@ class PipelineScheduleV2(PipelineSchedule): for i in range(num_warmup_microbatches): input_obj = comm.recv_forward() - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) comm.send_forward(output_obj) @@ -129,13 +125,11 @@ class PipelineScheduleV2(PipelineSchedule): # Run 1F1B in steady state. for i in range(num_microbatches_remaining): - last_iteration = (i == (num_microbatches_remaining - 1)) + last_iteration = i == (num_microbatches_remaining - 1) - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) if forward_only: comm.send_forward(output_obj) diff --git a/colossalai/legacy/global_variables.py b/colossalai/legacy/global_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..93cd5e60fa613d7a66d7d796dde682fd83e8bdbb --- /dev/null +++ b/colossalai/legacy/global_variables.py @@ -0,0 +1,60 @@ +from typing import Optional + + +class TensorParallelEnv(object): + _instance = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = object.__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self, *args, **kwargs): + self.load(*args, **kwargs) + + def load( + self, + mode: Optional[str] = None, + vocab_parallel: bool = False, + parallel_input_1d: bool = False, + summa_dim: int = None, + tesseract_dim: int = None, + tesseract_dep: int = None, + depth_3d: int = None, + input_group_3d=None, + weight_group_3d=None, + output_group_3d=None, + input_x_weight_group_3d=None, + output_x_weight_group_3d=None, + ): + self.mode = mode + self.vocab_parallel = vocab_parallel + self.parallel_input_1d = parallel_input_1d + self.summa_dim = summa_dim + self.tesseract_dim = tesseract_dim + self.tesseract_dep = tesseract_dep + self.depth_3d = depth_3d + self.input_group_3d = input_group_3d + self.weight_group_3d = weight_group_3d + self.output_group_3d = output_group_3d + self.input_x_weight_group_3d = input_x_weight_group_3d + self.output_x_weight_group_3d = output_x_weight_group_3d + + def save(self): + return dict( + mode=self.mode, + vocab_parallel=self.vocab_parallel, + parallel_input_1d=self.parallel_input_1d, + summa_dim=self.summa_dim, + tesseract_dim=self.tesseract_dim, + tesseract_dep=self.tesseract_dep, + depth_3d=self.depth_3d, + input_group_3d=self.input_group_3d, + weight_group_3d=self.weight_group_3d, + output_group_3d=self.output_group_3d, + input_x_weight_group_3d=self.input_x_weight_group_3d, + output_x_weight_group_3d=self.output_x_weight_group_3d, + ) + + +tensor_parallel_env = TensorParallelEnv() diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..ce9c626553bf97fc2ffc0bad55b7c4cb572e14e1 --- /dev/null +++ b/colossalai/legacy/initialize.py @@ -0,0 +1,502 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import argparse +import os +import pprint +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + +from colossalai.context import Config, ConfigException +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.amp import AMP_TYPE, convert_to_amp +from colossalai.legacy.amp.naive_amp import NaiveAMPModel +from colossalai.legacy.builder.builder import build_gradient_handler +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.engine import Engine +from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient +from colossalai.legacy.engine.schedule import ( + InterleavedPipelineSchedule, + NonPipelineSchedule, + PipelineSchedule, + get_tensor_shape, +) +from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence, sync_model_param +from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2 +from colossalai.legacy.zero.gemini.ophooks import BaseOpHook +from colossalai.logging import get_dist_logger +from colossalai.utils import get_current_device +from colossalai.utils.moe import sync_moe_model_param + + +def get_default_parser(): + """Reads user command line and uses an argument parser to parse the input arguments. + Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. + + Returns: + Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser. + """ + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, help="path to the config file") + parser.add_argument("--host", type=str, help="the master address for distributed training") + parser.add_argument("--port", type=int, help="the master port for distributed training") + parser.add_argument("--world_size", type=int, help="world size for distributed training") + parser.add_argument("--rank", type=int, help="rank for the default process group") + parser.add_argument("--local_rank", type=int, help="local rank on the node") + parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication") + return parser + + +def launch( + config: Union[str, Path, Config, Dict], + rank: int, + world_size: int, + host: str, + port: int, + backend: str = "nccl", + local_rank: int = None, + seed: int = 1024, + verbose: bool = True, +): + """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input + arguments are not given. Then initialize and set distributed environment by calling global_context's functions. + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + rank (int): Rank for the default process group + world_size (int): World size of the default process group + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + local_rank (int, optional): + Rank for the process on the node and is used to set the default CUDA device, + defaults to None. If local_rank = None, the default device ordinal will be calculated automatically. + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + + Raises: + Exception: Raise exception when config type is wrong + """ + gpc.verbose = verbose + + # set config + assert isinstance( + config, (Config, str, Path, dict) + ), f"expected argument config to be Config, str or Path, but got {type(config)}" + if not isinstance(config, Config) and isinstance(config, dict): + config = Config(config) + if isinstance(config, (str, Path)): + config = Config.from_file(config) + gpc.load_config(config) + + # init default process group + gpc.init_global_dist(rank, world_size, backend, host, port) + + # init process groups for different parallel modes from config + gpc.init_parallel_groups() + + # set cuda device + if torch.cuda.is_available(): + # if local rank is not given, calculate automatically + gpc.set_device(local_rank) + + # set the number of processes running on the same node + gpc.detect_num_processes_on_current_node() + + gpc.set_seed(seed) + + if verbose: + logger = get_dist_logger() + logger.info( + f"Distributed environment is initialized, " + f"data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, " + f"tensor parallel size: {gpc.tensor_parallel_size}", + ranks=[0], + ) + + +def launch_from_slurm( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): + """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables + set by SLURM + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + """ + try: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NPROCS"]) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" + ) + + launch( + config=config, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_openmpi( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): + """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables + set by OpenMPI + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + host (str): The master address for distributed training + port (str): The master port for distributed training + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + """ + try: + rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" + ) + + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_torch( + config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True +): + """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size + from the environment variables set by PyTorch + + Args: + config (Union[str, dict, Config]): Config file or config file path are both acceptable + backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl`` + seed (int, optional): Specified random seed for every process. Defaults to 1024. + verbose (bool, optional): Whether to print logs. Defaults to True. + """ + try: + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + host = os.environ["MASTER_ADDR"] + port = int(os.environ["MASTER_PORT"]) + except KeyError as e: + raise RuntimeError( + f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" + ) + + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def initialize( + model: nn.Module, + optimizer: Optimizer, + criterion: Optional[_Loss] = None, + train_dataloader: Optional[Iterable] = None, + test_dataloader: Optional[Iterable] = None, + lr_scheduler: Optional[_LRScheduler] = None, + ophooks: Optional[List[BaseOpHook]] = None, + verbose: bool = True, +) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: + """Core function to wrap the essential training components with our functionality based on the config which is + loaded into gpc.config. + + Args: + model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model. + optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`): + Your optimizer instance. + criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance. + train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training. + test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. + lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional. + verbose (bool, optional): Whether to print logs. + + Returns: + Tuple (engine, train_dataloader, test_dataloader, lr_scheduler): + A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)`` + where only ``engine`` could not be None. + """ + # get logger + logger = get_dist_logger() + gpc.verbose = verbose + + # get config from gpc + config = gpc.config + + # print config + if verbose: + logger.info( + f"\n========== Your Config ========\n" + f"{pprint.pformat(gpc.config)}\n" + f"================================\n", + ranks=[0], + ) + + # cudnn + cudnn_benchmark = config.get("cudnn_benchmark", False) + cudnn_deterministic = config.get("cudnn_deterministic", False) + torch.backends.cudnn.benchmark = cudnn_benchmark + torch.backends.cudnn.deterministic = cudnn_deterministic + if verbose: + logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0]) + + # zero + use_zero = hasattr(gpc.config, "zero") + if use_zero: + zero_cfg = gpc.config.get("zero", None) + if zero_cfg is not None: + cfg_ = zero_cfg.copy() + else: + cfg_ = {} + optimizer_config = zero_cfg.get("optimizer_config", None) + model_config = zero_cfg.get("model_config", None) + model, optimizer = convert_to_zero_v2( + model, optimizer, model_config=model_config, optimizer_config=optimizer_config + ) + + logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) + else: + if isinstance(model, nn.Module): + # first sync model across dp ranks + model.to(get_current_device()) + elif isinstance(model, Callable): + model = model().to(get_current_device()) + + # optimizer maybe a optimizer_cls + if isinstance(optimizer, Callable): + optimizer = optimizer(model.parameters()) + logger.warning("Initializing an non ZeRO model with optimizer class") + + if not use_zero: + if is_using_sequence(): + sync_model_param(model, ParallelMode.SEQUENCE_DP) + elif MOE_CONTEXT.is_initialized: + sync_moe_model_param(model) + elif is_using_ddp(): + sync_model_param(model, ParallelMode.DATA) + else: + logger.warning( + "The parameters of models is not automatically synchronized.\n" + "Please make sure that all parameters are the same in data parallel group.", + ranks=[0], + ) + + # check amp and zero + fp16_cfg = gpc.config.get("fp16", None) + + if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero: + raise ConfigException( + "It is not allowed to set fp16 and zero configuration in your config file at the same time" + ) + + # clip grad norm + clip_grad_norm = gpc.config.get("clip_grad_norm", 0.0) + + # initialize amp + amp_mode = None + if fp16_cfg is not None and fp16_cfg.mode is not None: + cfg_ = fp16_cfg.copy() + amp_mode = cfg_.pop("mode") + if is_using_pp(): + assert amp_mode == AMP_TYPE.NAIVE, "Pipeline only support NaiveAMP currently" + if amp_mode == AMP_TYPE.NAIVE: + cfg_["clip_grad_norm"] = clip_grad_norm + model, optimizer, criterion = convert_to_amp( + model=model, optimizer=optimizer, criterion=criterion, mode=amp_mode, amp_config=cfg_ + ) + + # get torch ddp config + torch_ddp_cfg = gpc.config.get("torch_ddp", dict()) + + # gradient handler + gradient_handler_cfg = gpc.config.get("gradient_handler", None) + if gradient_handler_cfg is None: + # if gradient handler is not specified in the configuration file, + # check in the following order + # 1. if optimizer is ZERO, then use zero grad handler + # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp + # 3. if using pipeline and dp size larger than 1, use data parallel grad handler + if isinstance(optimizer, ShardedOptimizerV2): + gradient_handler_cfg = [dict(type="ZeROGradientHandler")] + if verbose: + logger.info( + "Training with zero is detected, ZeROGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0], + ) + elif is_using_ddp() and MOE_CONTEXT.is_initialized: + gradient_handler_cfg = [dict(type="MoeGradientHandler")] + if verbose: + logger.info( + "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0], + ) + elif is_using_sequence(): + model = DDP( + model, + process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg, + ) + if verbose: + logger.info( + "Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism", ranks=[0] + ) + elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE: + model = DDP( + model, + process_group=gpc.get_group(ParallelMode.DATA), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg, + ) + if verbose: + logger.info("Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism", ranks=[0]) + elif is_using_ddp(): + gradient_handler_cfg = [dict(type="DataParallelGradientHandler")] + if verbose: + logger.info( + "Data parallel training is detected when using pipeline parallel, " + "DataParallelGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0], + ) + # add pipeline parallel gradient handler, if pipeline shared module is detected + for param in model.parameters(): + if getattr(param, "pipeline_shared_module_pg", None) is not None: + if gradient_handler_cfg is None: + gradient_handler_cfg = [dict(type="PipelineSharedModuleGradientHandler")] + else: + gradient_handler_cfg.append(dict(type="PipelineSharedModuleGradientHandler")) + if verbose: + logger.info( + "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0], + ) + break + else: + if not isinstance(gradient_handler_cfg, list): + raise ConfigException( + f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}" + ) + + # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time + # to avoid duplicated buffer synchronization + if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel): + model.module.sync_buffer = False + + # initialize schedule for engine + if is_using_pp(): + tensor_shape = get_tensor_shape() + use_interleaved = hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") + if gpc.is_initialized(ParallelMode.PARALLEL_1D): + scatter_gather = True + else: + scatter_gather = False + if use_interleaved: + if isinstance(model, nn.Sequential): + model = nn.ModuleList([model]) + schedule = InterleavedPipelineSchedule( + gpc.config.NUM_MICRO_BATCHES, + gpc.config.model.num_chunks, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather, + ) + else: + schedule = PipelineSchedule( + gpc.config.NUM_MICRO_BATCHES, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather + ) + else: + schedule = NonPipelineSchedule() + + if gradient_handler_cfg is None: + gradient_handlers = None + if verbose and not isinstance(model, DDP): + logger.warning( + "No PyTorch DDP or gradient handler is set up, please make sure you do not need " + "to all-reduce the gradients after a training step.", + ranks=[0], + ) + else: + gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg] + + # check if optimizer is OptimizerWrapper + if not isinstance(optimizer, (OptimizerWrapper, ShardedOptimizerV2)): + optimizer = OptimizerWrapper(optim=optimizer) + + # gradient accumulation + grad_accum_size = gpc.config.get("gradient_accumulation", None) + if grad_accum_size is not None: + optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient( + model=model, + optimizer=optimizer, + dataloader=train_dataloader, + accumulate_size=grad_accum_size, + gradient_handlers=gradient_handlers, + lr_scheduler=lr_scheduler, + ) + engine = Engine( + model=model, + optimizer=optimizer, + criterion=criterion, + gradient_handlers=gradient_handlers, + clip_grad_norm=clip_grad_norm, + ophook_list=ophooks, + schedule=schedule, + ) + + return engine, train_dataloader, test_dataloader, lr_scheduler diff --git a/colossalai/legacy/nn/__init__.py b/colossalai/legacy/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d30ebf8d5406f7604c0c2c710f796cc2d080cdcd --- /dev/null +++ b/colossalai/legacy/nn/__init__.py @@ -0,0 +1,3 @@ +from .layer import * +from .loss import * +from .metric import * diff --git a/colossalai/legacy/nn/_ops/__init__.py b/colossalai/legacy/nn/_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a35d02ce5edd0e2f4a16831621041009928f129 --- /dev/null +++ b/colossalai/legacy/nn/_ops/__init__.py @@ -0,0 +1 @@ +from ._utils import * diff --git a/colossalai/legacy/nn/_ops/_utils.py b/colossalai/legacy/nn/_ops/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a99f855a4cf4c83792a41205f3795adeba902b --- /dev/null +++ b/colossalai/legacy/nn/_ops/_utils.py @@ -0,0 +1,285 @@ +from typing import List, Optional, Union + +import torch +import torch.distributed as dist + +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.nn.layer.utils import divide +from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup +from colossalai.tensor import ColoTensor + +GeneralTensor = Union[ColoTensor, torch.Tensor] +Number = Union[int, float] + + +def convert_to_colo_tensor(tensor: Optional[GeneralTensor], pg: ProcessGroup) -> Optional[ColoTensor]: + if tensor is not None and not isinstance(tensor, ColoTensor): + tensor = ColoTensor.from_torch_tensor(tensor, ColoTensorSpec(pg)) + return tensor + + +def set_parallel_input(input_parallel: bool): + env.parallel_input_1d = input_parallel + + +def get_parallel_input(): + return env.parallel_input_1d + + +def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + +def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) + + +def _reduce(input_, pg: ProcessGroup): + # skip if only one rank involved + if pg.tp_world_size() == 1: + return input_ + assert input_.device.type == "cuda" + group = pg.tp_process_group() + dist.all_reduce(input_, group=group) + + return input_ + + +def _split(input_, pg: ProcessGroup, dim=-1): + # skip if only one rank involved + world_size = pg.tp_world_size() + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = pg.tp_local_rank() + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, pg: ProcessGroup, dim=-1): + # skip if only one rank involved + world_size = pg.tp_world_size() + if world_size == 1: + return input_ + + # all gather + rank = pg.tp_local_rank() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + assert input_.device.type == "cuda" + group = pg.tp_process_group() + torch.distributed.all_gather(tensor_list, input_, group=group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _ReduceGrad(torch.autograd.Function): + """ + Pass the input to the model parallel region. + + Args: + input_: input matrix. + process_group: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_, process_group): + ctx.mode = process_group + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output, ctx.mode), None + + +class _ReduceInput(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + process_group: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return _reduce(input_) + + @staticmethod + def forward(ctx, input_, process_group): + return _reduce(input_, process_group) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_: input matrix. + process_group: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.mode = process_group + ctx.dim = dim + return _split(input_, process_group, dim) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.mode, ctx.dim), None, None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + process_group: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _gather(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.mode = process_group + ctx.dim = dim + return _gather(input_, process_group, dim) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.mode, ctx.dim), None, None + + +def reduce_grad(input_, process_group): + return _ReduceGrad.apply(input_, process_group) + + +def reduce_input(input_, process_group): + return _ReduceInput.apply(input_, process_group) + + +def split_forward_gather_backward(input_, process_group, dim): + return _SplitForwardGatherBackward.apply(input_, process_group, dim) + + +def gather_forward_split_backward(input_, process_group, dim): + return _GatherForwardSplitBackward.apply(input_, process_group, dim) + + +def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor: + world_size = pg.tp_world_size() + if world_size == 1: + return x + + # TODO: enabling mpi backend to support CPU all_to_all + assert x.device.type == "cuda", f"Currently, the collective function dual_all_to_all only supports nccl backend" + + shapes = list(x.size()) + shapes[scatter_dim] = shapes[scatter_dim] // world_size + + scatter_list = [each.contiguous() for each in torch.tensor_split(x, world_size, scatter_dim)] + gather_list = [torch.empty(*shapes, dtype=x.dtype, device=x.device) for _ in range(world_size)] + torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + + return torch.cat(gather_list, dim=gather_dim).contiguous() + + +class _DualAllToAll(torch.autograd.Function): + @staticmethod + def forward(ctx, x, pg, scatter_dim, gather_dim): + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.pg = pg + return _all_to_all(x, pg, scatter_dim, gather_dim) + + @staticmethod + def backward(ctx, grad): + return _all_to_all(grad, ctx.pg, ctx.gather_dim, ctx.scatter_dim), None, None, None + + +def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int): + return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim) + + +# table wise embedding shard + + +def _all_to_all_for_tablewise( + x: torch.Tensor, pg: ProcessGroup, scatter_strides: List[int], gather_strides: List[int], forward=True +) -> torch.Tensor: + world_size = pg.tp_world_size() + rank = pg.tp_local_rank() + if world_size == 1: + return x + assert x.device.type == "cuda", f"Currently, the collective function dual_all_to_all only supports nccl backend" + if forward: + scatter_list = list(x.split(scatter_strides, 0)) + gather_list = [ + torch.empty(scatter_strides[rank], gather_strides[i], dtype=x.dtype, device=x.device) + for i in range(world_size) + ] + torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + return torch.cat(gather_list, 1).contiguous() + else: + # split on dim 1, lose contiguity + scatter_list = [each.contiguous() for each in x.split(scatter_strides, 1)] + gather_list = [ + torch.empty(gather_strides[i], scatter_strides[rank], dtype=x.dtype, device=x.device) + for i in range(world_size) + ] + torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + return torch.cat(gather_list, 0).contiguous() + + +class _DualAllToAllForTablewise(torch.autograd.Function): + @staticmethod + def forward(ctx, x, pg, scatter_strides, gather_strides): + ctx.pg = pg + ctx.scatter_strides = scatter_strides + ctx.gather_strides = gather_strides + return _all_to_all_for_tablewise(x, pg, scatter_strides, gather_strides, forward=True) + + @staticmethod + def backward(ctx, grad): + return ( + _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, forward=False), + None, + None, + None, + ) + + +def dual_all_to_all_tablewise(x, pg, scatter_strides, gather_strides): + return _DualAllToAllForTablewise.apply(x, pg, scatter_strides, gather_strides) diff --git a/colossalai/legacy/nn/layer/__init__.py b/colossalai/legacy/nn/layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86961dd933a73f292da722fe76467657a20e950a --- /dev/null +++ b/colossalai/legacy/nn/layer/__init__.py @@ -0,0 +1,9 @@ +from .colossalai_layer import * +from .parallel_1d import * +from .parallel_2d import * +from .parallel_2p5d import * +from .parallel_3d import * +from .parallel_sequence import * +from .utils import * +from .vanilla import * +from .wrapper import * diff --git a/colossalai/legacy/nn/layer/base_layer.py b/colossalai/legacy/nn/layer/base_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..66abc6fb1fd116b24da8da88c40de37d32731eae --- /dev/null +++ b/colossalai/legacy/nn/layer/base_layer.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from contextlib import contextmanager + +import torch.nn as nn + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc + + +class ParallelLayer(nn.Module): + global_state_dict: bool = True + + def __init__(self): + super().__init__() + self.data_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + ) + self.data_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_world_size(ParallelMode.DATA) + ) + + self.tensor_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_local_rank(ParallelMode.TENSOR) + ) + self.tensor_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR) + ) + + self.pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + self.pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) + + def _load_from_global_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + return super()._save_to_state_dict(destination, prefix, keep_vars) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + if self.global_state_dict: + if gpc.get_local_rank(ParallelMode.TENSOR) != 0: + missing_keys.clear() + unexpected_keys.clear() + return self._load_from_global_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if self.global_state_dict: + return self._save_to_global_state_dict(destination, prefix, keep_vars) + return super()._save_to_state_dict(destination, prefix, keep_vars) + + @classmethod + @contextmanager + def use_local_state_dict(cls): + try: + cls.global_state_dict = False + yield + finally: + cls.global_state_dict = True diff --git a/colossalai/legacy/nn/layer/colossalai_layer/__init__.py b/colossalai/legacy/nn/layer/colossalai_layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c5449ff5578042ac6b33355f090ed996970c489 --- /dev/null +++ b/colossalai/legacy/nn/layer/colossalai_layer/__init__.py @@ -0,0 +1,7 @@ +from ._utils import partition_batch +from .dropout import Dropout +from .embedding import Embedding, PatchEmbedding +from .linear import Classifier, Linear +from .normalization import LayerNorm + +__all__ = ["Linear", "Classifier", "Embedding", "PatchEmbedding", "LayerNorm", "Dropout", "partition_batch"] diff --git a/colossalai/legacy/nn/layer/colossalai_layer/_utils.py b/colossalai/legacy/nn/layer/colossalai_layer/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..98255142a846ea91955ffd4b68de3cab215f614c --- /dev/null +++ b/colossalai/legacy/nn/layer/colossalai_layer/_utils.py @@ -0,0 +1,40 @@ +import torch.nn as nn +from torch import Tensor + +from ..parallel_2d._operation import split_batch_2d +from ..parallel_2p5d._operation import split_batch_2p5d +from ..parallel_3d._operation import split_batch_3d +from ..utils import get_tensor_parallel_mode + +_parallel_split_batch = {"2d": split_batch_2d, "2.5d": split_batch_2p5d, "3d": split_batch_3d} + + +def partition_batch(input_) -> Tensor: + tensor_parallel_mode = get_tensor_parallel_mode() + if tensor_parallel_mode in _parallel_split_batch: + if isinstance(input_, dict): + return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()} + else: + return _parallel_split_batch[tensor_parallel_mode](input_) + else: + return input_ + + +class ColossalaiModule(nn.Module): + def __init__(self, module: nn.Module, **kwargs): + super().__init__() + self.module = module + for k, v in kwargs.items(): + setattr(self, k, v) + + def __getattr__(self, name: str): + if name == "module": + return super().__getattr__(name) + elif hasattr(self.module, name): + return getattr(self.module, name) + elif name in self.__dict__: + return self.__dict__[name] + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name)) + + def forward(self, *args): + return self.module(*args) diff --git a/colossalai/legacy/nn/layer/colossalai_layer/dropout.py b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6fcc2d8bf439755c7fe382fd0285b79c4571c3 --- /dev/null +++ b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py @@ -0,0 +1,31 @@ +import torch.nn as nn + +from colossalai.legacy.context import ParallelMode, seed + +from ..parallel_1d import * +from ..utils import get_tensor_parallel_mode +from ._utils import ColossalaiModule + + +class Dropout(ColossalaiModule): + """Dropout layer of colossalai. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False) -> None: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel == "1d": + drop = Dropout1D(p, inplace) + else: + drop = nn.Dropout(p, inplace) + super().__init__(drop, tensor_parallel=tensor_parallel) + + def forward(self, *args): + if self.tensor_parallel in [None, "1d"]: + return super().forward(*args) + else: + with seed(ParallelMode.TENSOR): + return super().forward(*args) diff --git a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..e1db0fe98a02de1d2f443bd072d151c7c3172ebb --- /dev/null +++ b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py @@ -0,0 +1,157 @@ +import math +from typing import Callable + +from torch import dtype, nn + +from colossalai.nn import init +from colossalai.utils import get_current_device + +from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D +from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D +from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D +from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D +from ..utils import get_tensor_parallel_mode +from ..vanilla import VanillaPatchEmbedding +from ._utils import ColossalaiModule + +_parallel_embedding = { + "1d": Embedding1D, + "2d": Embedding2D, + "2.5d": Embedding2p5D, + "3d": Embedding3D, +} + +_vocab_parallel_embedding = { + "1d": VocabParallelEmbedding1D, + "2d": VocabParallelEmbedding2D, + "2.5d": VocabParallelEmbedding2p5D, + "3d": VocabParallelEmbedding3D, +} + +_parallel_patchembedding = { + None: VanillaPatchEmbedding, + "1d": PatchEmbedding1D, + "2d": PatchEmbedding2D, + "2.5d": PatchEmbedding2p5D, + "3d": PatchEmbedding3D, +} + + +class Embedding(ColossalaiModule): + r"""Embedding for colossalai. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: dtype = None, + weight_initializer: Callable = init.normal_(), + vocab_parallel_limit: int = 2048, + *args, + **kwargs, + ) -> None: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel is None: + embed = ( + nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs) + .to(dtype) + .to(get_current_device()) + ) + weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) + elif num_embeddings <= vocab_parallel_limit: + embed = _parallel_embedding[tensor_parallel]( + num_embeddings, + embedding_dim, + padding_idx=padding_idx, + dtype=dtype, + weight_initializer=weight_initializer, + *args, + **kwargs, + ) + else: + embed = _vocab_parallel_embedding[tensor_parallel]( + num_embeddings, + embedding_dim, + padding_idx=padding_idx, + dtype=dtype, + weight_initializer=weight_initializer, + *args, + **kwargs, + ) + super().__init__(embed) + + +class PatchEmbedding(ColossalaiModule): + """2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ) -> None: + tensor_parallel = get_tensor_parallel_mode() + embed = _parallel_patchembedding[tensor_parallel]( + img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer, + ) + super().__init__(embed) diff --git a/colossalai/legacy/nn/layer/colossalai_layer/linear.py b/colossalai/legacy/nn/layer/colossalai_layer/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..aa4863e28b81da1cb9922194138eeec8449b4ecd --- /dev/null +++ b/colossalai/legacy/nn/layer/colossalai_layer/linear.py @@ -0,0 +1,144 @@ +import inspect +import math +from typing import Callable + +from torch import dtype, nn + +from colossalai.nn import init + +from ..parallel_1d import * +from ..parallel_2d import * +from ..parallel_2p5d import * +from ..parallel_3d import * +from ..utils import get_tensor_parallel_mode +from ..vanilla import * +from ._utils import ColossalaiModule + +_parallel_linear = {None: VanillaLinear, "1d": Linear1D, "2d": Linear2D, "2.5d": Linear2p5D, "3d": Linear3D} + +_parallel_classifier = { + None: VanillaClassifier, + "1d": Classifier1D, + "2d": Classifier2D, + "2.5d": Classifier2p5D, + "3d": Classifier3D, +} + +_vocab_parallel_classifier = { + "1d": VocabParallelClassifier1D, + "2d": VocabParallelClassifier2D, + "2.5d": VocabParallelClassifier2p5D, + "3d": VocabParallelClassifier3D, +} + + +class Linear(ColossalaiModule): + """Linear layer of colossalai. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + Note: ``kwargs`` would contain different parameters when you use different parallelisms. + + The ``kwargs`` should contain parameters below: + :: + + Linear1D: + gather_output: bool (optional, default to be false) + skip_bias_add: bool (optional, default to be false) + Linear2D: + skip_bias_add: bool (optional, default to be false) + Linear2p5D: + skip_bias_add: bool (optional, default to be false) + Linear3D: + None + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs, + ) -> None: + tensor_parallel = get_tensor_parallel_mode() + linear_cls = _parallel_linear[tensor_parallel] + gather_output = kwargs.pop("gather_output", None) + if "gather_output" in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available + kwargs["gather_output"] = gather_output + layer = linear_cls( + in_features, + out_features, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + **kwargs, + ) + super().__init__(layer) + + +class Classifier(ColossalaiModule): + """Classifier layer of colossalai. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + vocab_parallel_limit: int = 2048, + ) -> None: + tensor_parallel = get_tensor_parallel_mode() + if num_classes <= vocab_parallel_limit or tensor_parallel is None: + layer = _parallel_classifier[tensor_parallel]( + in_features, + num_classes, + weight=weight, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + else: + layer = _vocab_parallel_classifier[tensor_parallel]( + in_features, + num_classes, + weight=weight, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + super().__init__(layer) diff --git a/colossalai/legacy/nn/layer/colossalai_layer/normalization.py b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e317e723f112baf3a381c01074cb2d30bd5361 --- /dev/null +++ b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py @@ -0,0 +1,42 @@ +from torch import nn + +from colossalai.utils import get_current_device + +from ..parallel_1d import LayerNorm1D +from ..parallel_2d import LayerNorm2D +from ..parallel_2p5d import LayerNorm2p5D +from ..parallel_3d import LayerNorm3D +from ..utils import get_tensor_parallel_mode +from ..vanilla import VanillaLayerNorm +from ._utils import ColossalaiModule + +_parallel_layernorm = { + None: VanillaLayerNorm, + "1d": LayerNorm1D, + "2d": LayerNorm2D, + "2.5d": LayerNorm2p5D, + "3d": LayerNorm3D, +} + + +class LayerNorm(ColossalaiModule): + r"""Layer Normalization for colossalai. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel is None: + norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) + else: + norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) + super().__init__(norm) diff --git a/colossalai/legacy/nn/layer/parallel_1d/__init__.py b/colossalai/legacy/nn/layer/parallel_1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..35e9ec40d100ff9bab61f814a5d1256113ff8172 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_1d/__init__.py @@ -0,0 +1,25 @@ +from .layers import ( + Classifier1D, + Dropout1D, + Embedding1D, + LayerNorm1D, + Linear1D, + Linear1D_Col, + Linear1D_Row, + PatchEmbedding1D, + VocabParallelClassifier1D, + VocabParallelEmbedding1D, +) + +__all__ = [ + "Linear1D", + "Linear1D_Col", + "Linear1D_Row", + "Embedding1D", + "Dropout1D", + "Classifier1D", + "VocabParallelClassifier1D", + "VocabParallelEmbedding1D", + "LayerNorm1D", + "PatchEmbedding1D", +] diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..f01da97ba39ab8664949b5eb08ddce6fe8af0861 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py @@ -0,0 +1,96 @@ +import torch +import torch.distributed as dist + +from colossalai.legacy.core import global_context as gpc + +try: + import fused_mix_prec_layer_norm_cuda +except: + fused_mix_prec_layer_norm_cuda = None + + +class FusedLayerNormAffineFunction1D(torch.autograd.Function): + r"""Layernorm + + Args: + input: input matrix. + weight: weight matrix. + bias: bias matrix. + normalized_shape: input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability + """ + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + + return grad_input, grad_weight, grad_bias, None, None + + +class LinearWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.parallel_mode = parallel_mode + ctx.async_grad_allreduce = async_grad_allreduce + + output = torch.matmul(input_, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight) + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) + total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce) diff --git a/colossalai/legacy/nn/layer/parallel_1d/_utils.py b/colossalai/legacy/nn/layer/parallel_1d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..93b476e811a45be74b34b0e4e7b2db6e0073f9b4 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_1d/_utils.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.distributed as dist + +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env + +from ..utils import divide + + +def set_parallel_input(input_parallel: bool): + env.parallel_input_1d = input_parallel + + +def get_parallel_input(): + return env.parallel_input_1d + + +def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + +def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) + + +def _reduce(input_, parallel_mode): + # skip if only one rank involved + if gpc.get_world_size(parallel_mode) == 1: + return input_ + group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) + dist.all_reduce(input_, group=group) + + return input_ + + +def _split(input_, parallel_mode, dim=-1): + # skip if only one rank involved + world_size = gpc.get_world_size(parallel_mode) + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = gpc.get_local_rank(parallel_mode) + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, parallel_mode, dim=-1): + # skip if only one rank involved + world_size = gpc.get_world_size(parallel_mode) + if world_size == 1: + return input_ + + # all gather + rank = gpc.get_local_rank(parallel_mode) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) + torch.distributed.all_gather(tensor_list, input_, group=group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _ReduceGrad(torch.autograd.Function): + """ + Pass the input to the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_, parallel_mode): + ctx.mode = parallel_mode + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output, ctx.mode), None + + +class _ReduceInput(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def symbolic(graph, input_): + return _reduce(input_) + + @staticmethod + def forward(ctx, input_, parallel_mode): + return _reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, parallel_mode, dim): + ctx.mode = parallel_mode + ctx.dim = dim + return _split(input_, parallel_mode, dim) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.mode, ctx.dim), None, None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _gather(input_) + + @staticmethod + def forward(ctx, input_, parallel_mode, dim): + ctx.mode = parallel_mode + ctx.dim = dim + return _gather(input_, parallel_mode, dim) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.mode, ctx.dim), None, None + + +def reduce_grad(input_, parallel_mode): + return _ReduceGrad.apply(input_, parallel_mode) + + +def reduce_input(input_, parallel_mode): + return _ReduceInput.apply(input_, parallel_mode) + + +def split_forward_gather_backward(input_, parallel_mode, dim): + return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) + + +def gather_forward_split_backward(input_, parallel_mode, dim): + return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..8304cd2e1eb770c416856077db9c01d77e17d22e --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py @@ -0,0 +1,1073 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from collections import OrderedDict +from typing import Callable, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.parameter import Parameter + +from colossalai.kernel import LayerNorm +from colossalai.legacy.communication import broadcast +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.context.parallel_context import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.registry import LAYERS +from colossalai.legacy.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) +from colossalai.nn import init as init +from colossalai.utils.cuda import get_current_device + +from ..base_layer import ParallelLayer +from ..colossalai_layer._utils import ColossalaiModule +from ..utils import divide, set_tensor_parallel_attribute_by_partition +from ..vanilla import VanillaPatchEmbedding +from ._operation import linear_with_async_comm +from ._utils import ( + gather_forward_split_backward, + get_parallel_input, + reduce_grad, + reduce_input, + set_parallel_input, + split_forward_gather_backward, +) + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + + Fast_LN = FastLayerNorm +except ImportError: + pass + + +@LAYERS.register_module +class Linear1D(ColossalaiModule): + r"""Linear layer for 1D parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + gather_output (bool, optional): Whether to call all-gather on output, defaults to False. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + parallel_input = get_parallel_input() + if not parallel_input and not gather_output: + layer = Linear1D_Col( + in_features, + out_features, + bias=bias, + dtype=dtype, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + else: + layer = Linear1D_Row( + in_features, + out_features, + bias=bias, + dtype=dtype, + parallel_input=parallel_input, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + super().__init__(layer) + + +@LAYERS.register_module +class LayerNorm1D(ColossalaiModule): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + _fast_ln_supported_sizes = [ + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 12288, + 12800, + 15360, + 16384, + 18432, + 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, + ] + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: + norm = Fast_LN(normalized_shape, eps=eps).to(dtype) + else: + norm = None + try: + from apex.normalization import FusedLayerNorm + + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) + except ImportError: + norm = LayerNorm(normalized_shape, eps=eps).to(dtype) + super().__init__(norm) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) + + +@LAYERS.register_module +class Classifier1D(ParallelLayer): + r"""RowLinear with given weight. Classifier of 1D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = False + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + input_ = input_ + else: + assert ( + divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1] + ), "Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size + ) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + output_parallel = F.linear(input_, self.weight) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if self.bias is not None: + output = output + self.bias + return output + + +@LAYERS.register_module +class VocabParallelClassifier1D(ParallelLayer): + r"""ColLinear with given weight. Classifier of 1D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.gather_output = gather_output + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + # Set up backprop all-reduce. + input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight, self.bias) + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + return output + + +@LAYERS.register_module +class Linear1D_Col(ParallelLayer): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + is_parallel_output = not self.gather_output + set_parallel_input(is_parallel_output) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + # output_parallel = F.linear(input_parallel, self.weight, bias) + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True) + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +@LAYERS.register_module +class Linear1D_Row(ParallelLayer): + r"""Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1, + ): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + input_ = input_ + else: + assert ( + divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size + ) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce( + output_parallel_list[i], group=gpc.get_group(ParallelMode.PARALLEL_1D), async_op=True + ) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias + + +@LAYERS.register_module +class Embedding1D(ParallelLayer): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): + super().__init__() + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + "weight" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + local_state = partition_tensor_parallel_state_dict( + local_state, ParallelMode.PARALLEL_1D, dims={weight_key: -1}, partition_states={weight_key: True} + ) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding1D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype) + ) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + "weight" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + local_state = partition_tensor_parallel_state_dict( + local_state, ParallelMode.PARALLEL_1D, dims={weight_key: 0}, partition_states={weight_key: True} + ) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + return output + + +@LAYERS.register_module +class Dropout1D(ParallelLayer): + """Dropout layer of 1D parallelism. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False): + super().__init__() + self.parallel_input = get_parallel_input() + self.p = p + self.inplace = inplace + + def forward(self, input_: Tensor) -> Tensor: + if self.parallel_input: + with seed(ParallelMode.TENSOR): + output = F.dropout(input_, self.p, self.training, self.inplace) + else: + output = F.dropout(input_, self.p, self.training, self.inplace) + return output + + +@LAYERS.register_module +class PatchEmbedding1D(ColossalaiModule): + """ + 2D Image to Patch Embedding + + :param img_size: image size + :type img_size: int + :param patch_size: patch size + :type patch_size: int + :param in_chans: number of channels of input image + :type in_chans: int + :param embed_size: size of embedding + :type embed_size: int + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param flatten: whether to flatten output tensor, defaults to True + :type flatten: bool, optional + :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer + :type weight_initializer: typing.Callable, optional + :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer + :type bias_initializer: typing.Callable, optional + :param position_embed_initializer: The initializer of position embedding, defaults to zero + :type position_embed_initializer: typing.Callable, optional + """ + + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: torch.dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): + embed = VanillaPatchEmbedding( + img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer, + ) + super().__init__(embed) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + param_keys = [prefix + "weight", prefix + "bias", prefix + "cls_token", prefix + "pos_embed"] + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + for key in param_keys: + param = state_dict.pop(key, None) + if param is not None: + local_state[key] = param + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/legacy/nn/layer/parallel_2d/__init__.py b/colossalai/legacy/nn/layer/parallel_2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d29c66b3a24006296a7e4770938669614ad7fa2 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_2d/__init__.py @@ -0,0 +1,22 @@ +from ._operation import reduce_by_batch_2d, split_batch_2d +from .layers import ( + Classifier2D, + Embedding2D, + LayerNorm2D, + Linear2D, + PatchEmbedding2D, + VocabParallelClassifier2D, + VocabParallelEmbedding2D, +) + +__all__ = [ + "split_batch_2d", + "reduce_by_batch_2d", + "Linear2D", + "LayerNorm2D", + "Classifier2D", + "PatchEmbedding2D", + "Embedding2D", + "VocabParallelEmbedding2D", + "VocabParallelClassifier2D", +] diff --git a/colossalai/legacy/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..f1eff7128e7ae7df1fed8af9da44d729e43c1c55 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py @@ -0,0 +1,985 @@ +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + +from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.utils import get_current_device + + +def matmul_2d( + a, + b, + summa_dim, + out_shape, + row_rank=None, + col_rank=None, + row_parallel_mode=ParallelMode.PARALLEL_2D_ROW, + col_parallel_mode=ParallelMode.PARALLEL_2D_COL, +): + r"""Matrix multiplication for 2D parallelism. + + Args: + a (:class:`torch.tensor`): matrix :math:`A`. + b (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`, optional): + row parallel mode, defaults to ParallelMode.PARALLEL_2D_ROW. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`, optional): + column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL. + + Returns: + :class:`torch.tensor`: :math:`C = AB`. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + if row_rank is None: + row_rank = gpc.get_local_rank(col_parallel_mode) + if col_rank is None: + col_rank = gpc.get_local_rank(row_parallel_mode) + + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) + tensor_parallel_size = summa_dim**2 + return Matmul_AB_2D( + a, + b, + summa_dim, + out_shape, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + +class _Classifier2D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + bias: Optional[Tensor], + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + A = A.clone().detach() + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + B_temp = all_gather(B, -1, col_parallel_mode) + if ctx: + ctx.save_for_backward(A, B_temp) + + C = torch.matmul(A, B_temp.transpose(0, 1)) + + C = all_reduce(C, row_parallel_mode) + + ctx.use_bias = bias is not None + if bias is not None: + C = C + bias + + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = torch.matmul(output_grad, B) + A_grad = A_grad.reshape(ctx.A_shape) + B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A) + B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) + B_grad = B_grad.reshape(ctx.B_shape) + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) + bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) + else: + bias_grad = None + + return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None + + +def classifier_2d( + A: Tensor, + B: Tensor, + bias: Optional[Tensor], + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: + r"""2D parallel classifier. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + bias (:class:`torch.tensor`, optional): matrix of bias. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Classifier2D.apply( + A, + B, + bias, + summa_dim, + out_shape, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + +class Matmul_AB_2D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + # A: [b / q, s, h / q] -> [(b * s) / q, h / q] + # B: [h / q, s / q] + # C: [b / q, s, s / q] -> [(b * s) / q, s / q] + + assert A.shape[-1] == B.shape[-2], "Invalid shapes: A={}, B={} for AB.".format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[-1]) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + B_list = [torch.empty_like(B) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = ( + summa_dim * row_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_b = ( + col_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + + opa = [None] * 2 + opb = [None] * 2 + + A_list[0].copy_(A) + B_list[0].copy_(B) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(summa_dim): + if i != summa_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True) + + if opa[cur] is not None: + opa[cur].wait() + if opb[cur] is not None: + opb[cur].wait() + + torch.addmm(C, A_list[cur], B_list[cur], out=C) + cur = 1 - cur + src_a += 1 + src_b += summa_dim + + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_ABT_2D.apply( + output_grad, + B, + ctx.summa_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2D.apply( + A, + output_grad, + ctx.summa_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ABT_2D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB^T` + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + assert A.shape[-1] == B.shape[-1], "Invalid shapes: A={}, B={} for ABT.".format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[0]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + B_list = [torch.empty_like(B) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_b = ( + col_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + summa_dim * row_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + + opb = [None] * 2 + opr = [None] * 2 + + B_list[0].copy_(B) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(summa_dim): + if i != summa_dim - 1: + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == col_rank: + C.copy_(C_list[cur]) + + if opb[cur] is not None: + opb[cur].wait() + + torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True) + cur = 1 - cur + src_b += summa_dim + src_c += 1 + + for op in opr: + op.wait() + + if summa_dim - 2 == col_rank: + C.copy_(C_list[cur]) + if summa_dim - 1 == col_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = Matmul_AB_2D.apply( + output_grad, + B, + ctx.summa_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2D.apply( + output_grad, + A, + ctx.summa_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ATB_2D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = A^TB`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + summa_dim (int): dimension of SUMMA fo 2D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + assert A.shape[-2] == B.shape[-2], "Invalid shapes: A={}, B={} for ATB.".format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[-1], B.shape[-1]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = ( + summa_dim * row_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + col_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + + opa = [None] * 2 + opr = [None] * 2 + + A_list[0].copy_(A) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + cur = 0 + + for i in range(summa_dim): + if i != summa_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == row_rank: + C.copy_(C_list[cur]) + + if opa[cur] is not None: + opa[cur].wait() + + torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True) + cur = 1 - cur + src_a += 1 + src_c += summa_dim + + for op in opr: + op.wait() + + if summa_dim - 2 == row_rank: + C.copy_(C_list[cur]) + if summa_dim - 1 == row_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.summa_dim = summa_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = Matmul_ABT_2D.apply( + B, + output_grad, + ctx.summa_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_AB_2D.apply( + A, + output_grad, + ctx.summa_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None + + +class _Add_Bias_2D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + input_: Tensor, + bias: Tensor, + output_size_per_partition: int, + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + bias_temp = all_gather(bias, -1, col_parallel_mode) + + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.bias = skip_bias_add + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + if skip_bias_add: + return bias_temp + else: + output = input_ + bias_temp + return output + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + col_parallel_mode = ctx.col_parallel_mode + + if ctx.bias: + grad = reduce_scatter(output_grad, -1, col_parallel_mode) + return None, grad, None, None, None, None, None, None, None, None, None, None + else: + reduce_dim = tuple(range(output_grad.ndim - 1)) + reduce = torch.sum(output_grad, dim=reduce_dim) + grad = reduce_scatter(reduce, -1, col_parallel_mode) + return output_grad, grad, None, None, None, None, None, None, None, None, None, None + + +def add_bias_2d( + input_: Tensor, + bias: Tensor, + output_size_per_partition: int, + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: + r"""Matrix add bias: :math:`C = A + b`. + + Args: + input_ (:class:`torch.tensor`): matrix :math:`A`. + bias (:class:`torch.tensor`): matrix :math:`B`. + output_size_per_partition (int): size of output per partition. + row_rank (int, optional): the rank of row, defaults to None. + col_rank (int, optional): the rank of column, defaults to None. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + skip_bias_add (bool): + If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Add_Bias_2D.apply( + input_, + bias, + output_size_per_partition, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + skip_bias_add, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + +class _Layernorm_2D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward( + ctx: Any, + input_: Tensor, + E_x: Tensor, + Var_x: Tensor, + hidden_size: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + ) -> Tensor: + input_ = input_ - E_x + # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) + ctx.normalized_shape = hidden_size + output = input_ * Var_x + ctx.save_for_backward(output, Var_x) + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + row_parallel_mode = ctx.row_parallel_mode + ctx.col_parallel_mode + x, Var_x = ctx.saved_tensors + # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x + output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_sum, group=gpc.get_group(row_parallel_mode)) + output_grad_sum /= ctx.normalized_shape + + output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode)) + output_grad_mul_x_sum /= ctx.normalized_shape + + input_grad = output_grad.clone() + input_grad -= x * output_grad_mul_x_sum + input_grad -= output_grad_sum + input_grad *= Var_x + + return input_grad, None, None, None, None, None + + +def layernorm_2d( + input_: Tensor, + E_x: Tensor, + Var_x: Tensor, + hidden_size: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, +) -> Tensor: + r"""Layernorm. + + Args: + input_ (:class:`torch.tensor`): input matrix. + E_x (:class:`torch.tensor`): mean. + Var_x (:class:`torch.tensor`): variance. + hidden_size (int): hidden size. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Layernorm_2D.apply(input_, E_x, Var_x, hidden_size, row_parallel_mode, col_parallel_mode) + + +class _AllGatherTensor2D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + ctx.dim = dim + ctx.parallel_mode = parallel_mode + + outputs = all_gather(inputs, dim, parallel_mode) + return outputs + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode) + return grad.contiguous(), None, None + + +def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""All gather the tensor of 2D parallelism. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to gather. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _AllGatherTensor2D.apply(tensor, dim, parallel_mode) + + +def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor: + """Splits 2D tensor in specified dimension across cols. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + + Returns: + :class:`torch.tensor`: The tensor has been split. + """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + + if world_size <= 1: + return input_ + + assert dim_size % world_size == 0, f"The batch size ({dim_size}) is not a multiple of 2D size ({world_size})." + + return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), dim=dim)[ + gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + ].contiguous() + + +class _ReduceTensor2D(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the input. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _ReduceTensor2D.apply(input_, parallel_mode) + + +class _ReduceScatterTensor2D(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None + + +def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Reduce-scatter the input. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to reduce. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, f"The batch size ({dim_size}) is not a multiple of 2D size ({world_size})." + + return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode) + + +class _ReduceByBatch2D(torch.autograd.Function): + @staticmethod + def symbolic(graph, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + return output / reduce_size + return output + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size + return output.clone() + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None + else: + return output_grad, None + + +def reduce_by_batch_2d(input_, reduce_mean: bool = False) -> Tensor: + r"""All-reduce the input from the model parallel region. + + Args: + input_ (:class:`torch.tensor`): input matrix. + reduce_mean (bool, optional): + If set to ``True``, it will divide the output by column parallel size, default to False. + """ + return _ReduceByBatch2D.apply(input_, reduce_mean) diff --git a/colossalai/legacy/nn/layer/parallel_2d/_utils.py b/colossalai/legacy/nn/layer/parallel_2d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe18af26f88fcd19e945131f612e7c2cad8144a9 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_2d/_utils.py @@ -0,0 +1,22 @@ +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env + + +def get_summa_dim_from_env() -> int: + try: + summa_dim = env.summa_dim + assert summa_dim > 0, "SUMMA_DIM must be larger than zero" + return summa_dim + + except KeyError: + raise EnvironmentError( + "SUMMA_DIM is not found in the current environment, " + "please make sure that you have used the correct process group initializer" + ) + + +def assert_summa_initialization(): + assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL) and gpc.is_initialized( + ParallelMode.PARALLEL_2D_ROW + ), "Both TWO_DIMENSION_COL and TWO_DIMENSION_ROW must be initialized by the process group initializer" diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3b2e032e5127b7b67f97ecc06e5fa4096ed7d0c7 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py @@ -0,0 +1,1188 @@ +import math +from collections import OrderedDict +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter + +from colossalai.legacy.communication import broadcast +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.registry import LAYERS +from colossalai.legacy.utils.checkpointing import ( + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) +from colossalai.nn import init as init +from colossalai.utils.cuda import get_current_device + +from ..base_layer import ParallelLayer +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple +from ._operation import ( + Matmul_AB_2D, + Matmul_ABT_2D, + add_bias_2d, + all_gather_tensor_2d, + classifier_2d, + layernorm_2d, + reduce_scatter_tensor_2d, + split_batch_2d, +) +from ._utils import assert_summa_initialization, get_summa_dim_from_env + + +@LAYERS.register_module +class Linear2D(ParallelLayer): + r"""Linear layer for 2D parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.skip_bias_add = skip_bias_add + + # parallel settings + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(self.in_features, self.summa_dim) + self.hidden_size_per_partition = divide(self.out_features, self.summa_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + self.weight = Parameter( + torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) + ) + + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(divide(self.out_features, self.summa_dim**2), **factory_kwargs)) + else: + self.register_parameter("bias", None) + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/q, n/q, k/q] + # output: [m/q, n/q, h/q] + out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + + output = Matmul_AB_2D.apply( + x, + self.weight, + self.summa_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + + if self.bias is not None: + if self.skip_bias_add: + bias = add_bias_2d( + None, + self.bias, + self.hidden_size_per_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + return output, bias + else: + output = add_bias_2d( + output, + self.bias, + self.hidden_size_per_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + return output + else: + return output + + +@LAYERS.register_module +class LayerNorm2D(ParallelLayer): + r"""Layer Normalization for 2D parallelism. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None): + super().__init__() + + # layer norm config + self.normalized_shape = normalized_shape + self.variance_epsilon = eps + + # parallel setting + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.partitioned_partition = divide(normalized_shape, self.summa_dim**2) + + # create parameters + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + + self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs)) + else: + self.bias = None + + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + with torch.no_grad(): + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + E_x /= self.normalized_shape + + # Var_x in the block below is the sum of input^2 + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + Var_x /= self.normalized_shape + + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + # this time 1/sqrt(Var_x + epsilon) + Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) + + output = layernorm_2d( + x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL + ) + scale = add_bias_2d( + None, + self.weight, + self.partitioned_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + if self.bias is not None: + bias = add_bias_2d( + None, + self.bias, + self.partitioned_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + output = torch.addcmul(bias, scale, output) + else: + output = torch.mul(scale, output) + return output + + +@LAYERS.register_module +class PatchEmbedding2D(ParallelLayer): + r"""2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + assert_summa_initialization() + self.summa_dim = get_summa_dim_from_env() + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.embed_size = embed_size + self.embed_size_per_partition = embed_size // (self.summa_dim**2) + + with seed(ParallelMode.TENSOR): + self.weight = Parameter( + torch.empty( + (self.embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype, + ) + ) + self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + + self.cls_token = Parameter( + torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) + self.pos_embed = Parameter( + torch.zeros( + (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + ) + ) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + self._set_tensor_parallel_attribute() + + def _set_tensor_parallel_attribute(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + set_tensor_parallel_attribute_by_partition(self.cls_token, self.summa_dim**2) + set_tensor_parallel_attribute_by_partition(self.pos_embed, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): + with seed(ParallelMode.TENSOR): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + fan_out = self.embed_size + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + # cls token + cls_token = state_dict.pop(cls_token_key, None) + if cls_token is not None: + local_state[cls_token_key] = cls_token + # pos embed + pos_embed = state_dict.pop(pos_embed_key, None) + if pos_embed is not None: + local_state[pos_embed_key] = pos_embed + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + local_state = OrderedDict( + {weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed} + ) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2d(input_) + + B, C, H, W = input_.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + + weight = all_gather_tensor_2d(self.weight, 0, ParallelMode.PARALLEL_2D_COL) + bias = all_gather_tensor_2d(self.bias, 0, ParallelMode.PARALLEL_2D_COL) + + output = F.conv2d(input_, weight, bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL) + pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL) + cls_token = cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + pos_embed + + return output + + +@LAYERS.register_module +class Embedding2D(ParallelLayer): + r"""Embedding for 2D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_ + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): + super().__init__() + + assert_summa_initialization() + self.summa_dim = get_summa_dim_from_env() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, self.summa_dim**2) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2d(input_) + + weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL) + output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding2D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + assert_summa_initialization() + self.summa_dim = get_summa_dim_from_env() + self.num_embeddings_per_partition = divide(self.num_embeddings, self.summa_dim) + self.embed_dim_per_partition = divide(self.embed_dim, self.summa_dim) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype, + ) + ) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) + + output_parallel[input_mask, :] = 0.0 + output = reduce_scatter_tensor_2d(output_parallel, 0, ParallelMode.PARALLEL_2D_COL) + return output + + +@LAYERS.register_module +class Classifier2D(ParallelLayer): + r"""Classifier for 2D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(self.in_features, self.summa_dim**2) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + ) + self.has_weight = True + if bias: + self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)[0] + row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_ROW)[0] + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL) + broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + out_shape = input_.shape[:-1] + (self.num_classes,) + + return classifier_2d( + input_, + self.weight, + self.bias, + self.summa_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + + +@LAYERS.register_module +class VocabParallelClassifier2D(ParallelLayer): + r"""Vocab parallel classifier layer for 2D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + + self.in_features = in_features + self.num_classes = num_classes + + # parallel setting + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(in_features, self.summa_dim) + self.output_size_per_partition = divide(num_classes, self.summa_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs) + ) + self.has_weight = True + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(divide(self.num_classes, self.summa_dim**2), **factory_kwargs)) + else: + self.bias = None + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/q, n/q, k/q] + # output: [m/q, n/q, h/q] + out_shape = x.shape[:-1] + (self.output_size_per_partition,) + + output = Matmul_ABT_2D.apply( + x, + self.weight, + self.summa_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + + if self.bias is not None: + output = add_bias_2d( + output, + self.bias, + self.output_size_per_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + return output diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py b/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46b4d3f3b782a3e355ff8fd9940acddfd49b6d8b --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py @@ -0,0 +1,22 @@ +from ._operation import reduce_by_batch_2p5d, split_batch_2p5d +from .layers import ( + Classifier2p5D, + Embedding2p5D, + LayerNorm2p5D, + Linear2p5D, + PatchEmbedding2p5D, + VocabParallelClassifier2p5D, + VocabParallelEmbedding2p5D, +) + +__all__ = [ + "split_batch_2p5d", + "reduce_by_batch_2p5d", + "Linear2p5D", + "LayerNorm2p5D", + "Classifier2p5D", + "PatchEmbedding2p5D", + "Embedding2p5D", + "VocabParallelClassifier2p5D", + "VocabParallelEmbedding2p5D", +] diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..50900c135cabc0a574d91a51ac1de15177c431fe --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py @@ -0,0 +1,1113 @@ +from typing import Any, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + +from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.utils import get_current_device + + +def get_parallel_group(parallel_mode: ParallelMode): + return gpc.get_group(parallel_mode) + + +def get_global_rank(): + return gpc.get_global_rank() + + +def get_parallel_rank(parallel_mode: ParallelMode): + return gpc.get_local_rank(parallel_mode) + + +class _Classifier2p5D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + bias, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + A = A.clone().detach() + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + B_temp = all_gather(B, -1, col_parallel_mode) + if ctx: + ctx.save_for_backward(A, B_temp) + + C = torch.matmul(A, B_temp.transpose(0, 1)) + + C = all_reduce(C, row_parallel_mode) + + ctx.use_bias = bias is not None + if bias is not None: + C = C + bias + + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = torch.matmul(output_grad, B) + A_grad = A_grad.reshape(ctx.A_shape) + B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A) + B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) + B_grad = B_grad.reshape(ctx.B_shape) + + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) + bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) + else: + bias_grad = None + + return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None + + +def classifier_2p5d( + A: Tensor, + B: Tensor, + bias, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: + r"""Classifier. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + bias (:class:`torch.tensor`): matrix of bias. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Classifier2p5D.apply( + A, + B, + bias, + tesseract_dim, + out_shape, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + +class Matmul_AB_2p5D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + dep_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + # A: [b / dq, s, h / q] -> [(b * s) / dq, h / q] + # B: [h / dq, s / q] + # C: [b / dq, s, s / q] -> [(b * s) / dq, s / q] + + assert A.shape[-1] == B.shape[-2], "Invalid shapes: A={}, B={} for AB.".format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[-1]) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + B_list = [torch.empty_like(B) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = ( + tesseract_dim * row_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_b = ( + col_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + + opa = [None] * 2 + opb = [None] * 2 + + A_list[0].copy_(A) + B_list[0].copy_(B) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(tesseract_dim): + if i != tesseract_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast( + B_list[1 - cur], src=src_b + tesseract_dim, group=col_group, async_op=True + ) + + if opa[cur] is not None: + opa[cur].wait() + if opb[cur] is not None: + opb[cur].wait() + + torch.addmm(C, A_list[cur], B_list[cur], out=C) + cur = 1 - cur + src_a += 1 + src_b += tesseract_dim + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_ABT_2p5D.apply( + output_grad, + B, + ctx.tesseract_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2p5D.apply( + A, + output_grad, + ctx.tesseract_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ABT_2p5D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = AB^T`. + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + dep_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + assert A.shape[-1] == B.shape[-1], "Invalid shapes: A={}, B={} for ABT.".format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[0], B.shape[0]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + B_list = [torch.empty_like(B) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_b = ( + col_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + tesseract_dim * row_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + + opb = [None] * 2 + opr = [None] * 2 + + B_list[0].copy_(B) + opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) + cur = 0 + + for i in range(tesseract_dim): + if i != tesseract_dim - 1: + B_list[1 - cur].copy_(B) + opb[1 - cur] = dist.broadcast( + B_list[1 - cur], src=src_b + tesseract_dim, group=col_group, async_op=True + ) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == col_rank: + C.copy_(C_list[cur]) + + if opb[cur] is not None: + opb[cur].wait() + + torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True) + cur = 1 - cur + src_b += tesseract_dim + src_c += 1 + + for op in opr: + op.wait() + + if tesseract_dim - 2 == col_rank: + C.copy_(C_list[cur]) + if tesseract_dim - 1 == col_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_AB_2p5D.apply( + output_grad, + B, + ctx.tesseract_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2p5D.apply( + output_grad, + A, + ctx.tesseract_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class Matmul_ATB_2p5D(torch.autograd.Function): + r"""Matrix multiplication for :math:`C = A^TB` + + Args: + A (:class:`torch.tensor`): matrix :math:`A`. + B (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + out_shape (:class:`torch.size`): shape of output tensor. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + dep_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ): + assert A.shape[-2] == B.shape[-2], "Invalid shapes: A={}, B={} for ATB.".format(A.shape, B.shape) + + if ctx: + ctx.save_for_backward(A, B) + + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + C_shape = (A.shape[-1], B.shape[-1]) + C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + + # use circular buffer to store the communication tensor + # 2 is enough for all cases + A_list = [torch.empty_like(A) for _ in range(2)] + C_list = [torch.empty_like(C) for _ in range(2)] + + row_group = gpc.get_group(row_parallel_mode) + col_group = gpc.get_group(col_parallel_mode) + + src_a = ( + tesseract_dim * row_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + col_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + + opa = [None] * 2 + opr = [None] * 2 + + A_list[0].copy_(A) + opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) + cur = 0 + + for i in range(tesseract_dim): + if i != tesseract_dim - 1: + A_list[1 - cur].copy_(A) + opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) + + if opr[cur] is not None: + opr[cur].wait() + if i - 2 == row_rank: + C.copy_(C_list[cur]) + + if opa[cur] is not None: + opa[cur].wait() + + torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur]) + opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True) + cur = 1 - cur + src_a += 1 + src_c += tesseract_dim + + for op in opr: + op.wait() + + if tesseract_dim - 2 == row_rank: + C.copy_(C_list[cur]) + if tesseract_dim - 1 == row_rank: + C.copy_(C_list[1 - cur]) + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + with torch.no_grad(): + A_grad = Matmul_ABT_2p5D.apply( + B, + output_grad, + ctx.tesseract_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_AB_2p5D.apply( + A, + output_grad, + ctx.tesseract_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class _Add_Bias_2p5D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + input: Tensor, + bias: Tensor, + output_size_per_partition: int, + tesseract_dim: int, + row_rank: int, + col_rank: int, + dep_rank: int, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + if row_rank == 0: + bias_temp = bias.clone() + else: + bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) + src_rank = ( + col_rank + + dep_rank * tesseract_dim**2 + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode)) + + ctx.row_rank = row_rank + ctx.col_rank = col_rank + ctx.dep_rank = dep_rank + ctx.tesseract_dim = tesseract_dim + ctx.col_parallel_mode = col_parallel_mode + ctx.bias = skip_bias_add + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + + if skip_bias_add: + return bias_temp + else: + output = input + bias_temp + return output + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + row_rank = ctx.row_rank + col_rank = ctx.col_rank + dep_rank = ctx.dep_rank + tesseract_dim = ctx.tesseract_dim + col_parallel_mode = ctx.col_parallel_mode + data_parallel_rank = ctx.data_parallel_rank + pipeline_parallel_rank = ctx.pipeline_parallel_rank + pipeline_parallel_size = ctx.pipeline_parallel_size + tensor_parallel_size = ctx.tensor_parallel_size + + if ctx.bias: + dst_rank = ( + col_rank + + dep_rank * (tesseract_dim**2) + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) + if row_rank == 0: + return ( + None, + output_grad, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + else: + grad_tmp = torch.zeros_like(output_grad) + return ( + None, + grad_tmp, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + else: + reduce_dim = tuple(range(output_grad.ndim - 1)) + reduce = torch.sum(output_grad, dim=reduce_dim) + dst_rank = ( + col_rank + + dep_rank * (tesseract_dim**2) + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) + if row_rank == 0: + return ( + output_grad, + reduce, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + else: + reduce_tmp = torch.zeros_like(reduce) + return ( + output_grad, + reduce_tmp, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def add_bias_2p5d( + input: Tensor, + bias: Tensor, + output_size_per_partition: int, + tesseract_dim: int, + row_rank: int, + col_rank: int, + dep_rank: int, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: + r"""Matrix add bias: :math:`C = A + b`. + + Args: + input (:class:`torch.tensor`): matrix :math:`A`. + bias (:class:`torch.tensor`): matrix :math:`B`. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. + output_size_per_partition (int): output size in each partition. + row_rank (int): the rank of row. + col_rank (int): the rank of column. + dep_rank (int): the rank of depth. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion. + data_parallel_rank (int): data parallel rank. + pipeline_parallel_rank (int): pipeline parallel rank + pipeline_parallel_size (int): pipeline parallel size. + tensor_parallel_size (int): tensor parallel size. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Add_Bias_2p5D.apply( + input, + bias, + output_size_per_partition, + tesseract_dim, + row_rank, + col_rank, + dep_rank, + col_parallel_mode, + skip_bias_add, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + +class _Layernorm2p5D(torch.autograd.Function): + r"""Layernorm. + + Args: + input (:class:`torch.tensor`): input matrix. + E_x (:class:`torch.tensor`): mean. + Var_x (:class:`torch.tensor`): variance. + hidden_size (int): hidden size. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward( + ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode + ) -> Tensor: + input = input - E_x + # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) + ctx.hidden_size = hidden_size + output = input * Var_x + ctx.save_for_backward(output, Var_x) + ctx.row_parallel_mode = row_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + row_parallel_mode = ctx.row_parallel_mode + x, Var_x = ctx.saved_tensors + # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x + with torch.no_grad(): + output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_sum, group=get_parallel_group(row_parallel_mode)) + output_grad_sum /= ctx.hidden_size + + output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True) + torch.distributed.all_reduce(output_grad_mul_x_sum, group=get_parallel_group(row_parallel_mode)) + output_grad_mul_x_sum /= ctx.hidden_size + + input_grad = output_grad.clone() + input_grad -= x * output_grad_mul_x_sum + input_grad -= output_grad_sum + input_grad *= Var_x + + return input_grad, None, None, None, None, None, None + + +def layernorm_2p5d( + input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode +) -> Tensor: + r"""Layernorm. + + Args: + input (:class:`torch.tensor`): input matrix. + E_x (:class:`torch.tensor`): mean. + Var_x (:class:`torch.tensor`): variance. + hidden_size (int): hidden size. + row_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): row parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _Layernorm2p5D.apply(input, E_x, Var_x, hidden_size, row_parallel_mode) + + +class _AllGatherTensor2p5D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: + ctx.dim = dim + ctx.col_parallel_mode = col_parallel_mode + + outputs = all_gather(inputs, dim, col_parallel_mode) + return outputs + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + grad = reduce_scatter(output_grad, ctx.dim, ctx.col_parallel_mode) + return grad.contiguous(), None, None + + +def all_gather_tensor_2p5d(inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: + r"""all gather the weight of 2.5D parallelism. + + Args: + inputs (:class:`torch.tensor`): input tensor. + dim (int): dimension of all-gather. + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _AllGatherTensor2p5D.apply(inputs, dim, col_parallel_mode) + + +class SplitFirst(torch.autograd.Function): + r""" + + Args: + inputs (:class:`torch.tensor`): input tensor. + tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism + col_parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): column parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor: + ctx.tesseract_dim = tesseract_dim + ctx.batch_size = inputs.size(0) + ctx.para_mode = col_parallel_mode + row_rank = gpc.get_local_rank(col_parallel_mode) + + outputs = inputs.chunk(tesseract_dim, dim=0)[row_rank] + return outputs + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + grad_shape = (ctx.batch_size,) + output_grad.shape[1:] + grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device()) + dist.all_gather( + list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode) + ) + return grad, None, None + + +def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: + """Splits 2P5D tensor in specified dimension across cols. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + + Returns: + :class:`torch.tensor`: The tensor has been split. + """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + + if world_size <= 1: + return input_ + + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})." + + return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), dim=dim)[ + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + ].contiguous() + + +class _ReduceTensor2p5D(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the input. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _ReduceTensor2p5D.apply(input_, parallel_mode) + + +class _ReduceScatterTensor2p5D(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None + + +def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Reduce-scatter the input. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to reduce. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode tensor used. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})." + + return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode) + + +class _RreduceByBatch2p5D(torch.autograd.Function): + @staticmethod + def symbolic(graph, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + return output / reduce_size + return output + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, input_, reduce_mean: bool = False): + output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size + return output.clone() + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None + else: + return output_grad, None + + +def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor: + r"""All-reduce the input from the model parallel region. + + Args: + input_ (:class:`torch.tensor`): input matrix. + reduce_mean (bool, optional): + If set to ``True``, it will divide the output by column parallel size, default to False. + """ + return _RreduceByBatch2p5D.apply(input_, reduce_mean) diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8cda15aed2a70e9c2c4c73b5edac3aba37dd52aa --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py @@ -0,0 +1,30 @@ +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env + + +def get_tesseract_dim_dep_from_env(): + try: + tesseract_dim = env.tesseract_dim + tesseract_dep = env.tesseract_dep + assert tesseract_dim > 0, "TESSERACT_DIM must be larger than zero" + assert tesseract_dep > 0, "TESSERACT_DEP must be larger than zero" + return tesseract_dim, tesseract_dep + + except KeyError: + raise EnvironmentError( + "TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, " + "please make sure that you have used the correct process group initializer" + ) + + +def assert_tesseract_initialization(): + assert ( + gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL) + and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) + and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) + and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ) + ), ( + "Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ " + "must be initialized by the process group initializer" + ) diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..fc2e35f36cbc2b337c4f0416686e775d99c0f588 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py @@ -0,0 +1,1176 @@ +import math +from collections import OrderedDict +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter + +from colossalai.legacy.communication import broadcast +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.registry import LAYERS +from colossalai.legacy.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) +from colossalai.nn import init as init +from colossalai.utils.cuda import get_current_device + +from ..base_layer import ParallelLayer +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple +from ._operation import ( + Matmul_AB_2p5D, + Matmul_ABT_2p5D, + add_bias_2p5d, + all_gather_tensor_2p5d, + classifier_2p5d, + layernorm_2p5d, + reduce_scatter_tensor_2p5d, + split_batch_2p5d, +) +from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env + + +@LAYERS.register_module +class Linear2p5D(ParallelLayer): + r"""Linear layer for 2.5D parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.skip_bias_add = skip_bias_add + + # parallel setting + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(in_features, self.tesseract_dim) + self.hidden_size_per_partition = divide(out_features, self.tesseract_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + self.weight = Parameter( + torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) + ) + + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs)) + else: + self.register_parameter("bias", None) + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # broadcast in dep groups + if ( + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0 + and gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0 + ): + broadcast_state_dict(local_state, ParallelMode.PARALLEL_2P5D_DEP) + # partition in column groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + # partition in row groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0: + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in row groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + # gather in column groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/dq, n/q, k/q] + # output: [m/dq, n/q, h/q] + out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + + output = Matmul_AB_2p5D.apply( + x, + self.weight, + self.tesseract_dim, + out_shape, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + + if self.bias is not None: + if self.skip_bias_add: + bias = add_bias_2p5d( + None, + self.bias, + self.hidden_size_per_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + return output, bias + else: + output = add_bias_2p5d( + output, + self.bias, + self.hidden_size_per_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + return output + else: + return output + + +@LAYERS.register_module +class LayerNorm2p5D(ParallelLayer): + r"""Layer Normalization for 2.5D parallelism. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None): + super().__init__() + + # layer norm config + self.normalized_shape = normalized_shape + self.variance_epsilon = eps + + # parallel setting + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * + + # create parameters + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + + self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs)) + else: + self.bias = None + + self._set_tensor_parallel_attribute() + + def _set_tensor_parallel_attribute(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, x: Tensor) -> Tensor: + with torch.no_grad(): + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + E_x /= self.normalized_shape + + # Var_x in the block below is the sum of input^2 + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + Var_x /= self.normalized_shape + + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + # this time 1/sqrt(Var_x + epsilon) + Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) + + output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW) + scale = add_bias_2p5d( + None, + self.weight, + self.partitioned_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + if self.bias is not None: + bias = add_bias_2p5d( + None, + self.bias, + self.partitioned_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + output = torch.addcmul(bias, scale, output) + else: + output = torch.mul(scale, output) + return output + + +@LAYERS.register_module +class PatchEmbedding2p5D(ParallelLayer): + r"""2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + assert_tesseract_initialization() + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.embed_size = embed_size + self.embed_size_per_partition = embed_size // self.tesseract_dim**2 + + with seed(ParallelMode.TENSOR): + self.weight = Parameter( + torch.empty( + (self.embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype, + ) + ) + self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + + self.cls_token = Parameter( + torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) + self.pos_embed = Parameter( + torch.zeros( + (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + ) + ) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + self._set_tensor_parallel_attribute() + + def _set_tensor_parallel_attribute(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2) + set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): + with seed(ParallelMode.TENSOR): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + fan_out = self.embed_size + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + # cls token + cls_token = state_dict.pop(cls_token_key, None) + if cls_token is not None: + local_state[cls_token_key] = cls_token + # pos embed + pos_embed = state_dict.pop(pos_embed_key, None) + if pos_embed is not None: + local_state[pos_embed_key] = pos_embed + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + local_state = OrderedDict( + {weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed} + ) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2p5d(input_, 0) + + B, C, H, W = input_.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + + weight = all_gather_tensor_2p5d(self.weight, 0, ParallelMode.PARALLEL_2P5D_COL) + bias = all_gather_tensor_2p5d(self.bias, 0, ParallelMode.PARALLEL_2P5D_COL) + + output = F.conv2d(input_, weight, bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL) + pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL) + cls_token = cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + pos_embed + + return output + + +@LAYERS.register_module +class Embedding2p5D(ParallelLayer): + r"""Embedding for 2.5D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_ + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): + super().__init__() + + assert_tesseract_initialization() + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = embedding_dim // self.tesseract_dim**2 + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_2p5d(input_, 0) + + weight = all_gather_tensor_2p5d(self.weight, -1, ParallelMode.PARALLEL_2P5D_COL) + + output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding2p5D(ParallelLayer): + """Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + assert_tesseract_initialization() + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + self.num_embeddings_per_partition = divide(self.num_embeddings, self.tesseract_dim) + self.embed_dim_per_partition = divide(self.embed_dim, self.tesseract_dim) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype, + ) + ) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and self.vocab_start_index <= self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = reduce_scatter_tensor_2p5d(output_parallel, 0, ParallelMode.PARALLEL_2P5D_COL) + return output + + +@LAYERS.register_module +class Classifier2p5D(ParallelLayer): + r"""Classifier for 2.5D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + ) + self.has_weight = True + if bias: + self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_COL)[0] + row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_ROW)[0] + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL) + broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + out_shape = input_.shape[:-1] + (self.num_classes,) + + return classifier_2p5d( + input_, + self.weight, + self.bias, + self.tesseract_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + + +@LAYERS.register_module +class VocabParallelClassifier2p5D(ParallelLayer): + r"""Vocab parallel classifier layer for 2.5D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + + self.in_features = in_features + self.num_classes = num_classes + + # parallel setting + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(in_features, self.tesseract_dim) + self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs) + ) + self.has_weight = True + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs)) + else: + self.bias = None + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_ROW, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2P5D_COL, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/dq, n/q, k/q] + # output: [m/dq, n/q, h/q] + out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + + output = Matmul_ABT_2p5D.apply( + x, + self.weight, + self.tesseract_dim, + out_shape, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + + if self.bias is not None: + output = add_bias_2p5d( + output, + self.bias, + self.hidden_size_per_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + return output diff --git a/colossalai/legacy/nn/layer/parallel_3d/__init__.py b/colossalai/legacy/nn/layer/parallel_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d38f6a5687412856997528cf121942c82a478c1 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_3d/__init__.py @@ -0,0 +1,23 @@ +from ._operation import reduce_by_batch_3d, split_batch_3d, split_tensor_3d +from .layers import ( + Classifier3D, + Embedding3D, + LayerNorm3D, + Linear3D, + PatchEmbedding3D, + VocabParallelClassifier3D, + VocabParallelEmbedding3D, +) + +__all__ = [ + "reduce_by_batch_3d", + "split_tensor_3d", + "split_batch_3d", + "Linear3D", + "LayerNorm3D", + "PatchEmbedding3D", + "Classifier3D", + "Embedding3D", + "VocabParallelEmbedding3D", + "VocabParallelClassifier3D", +] diff --git a/colossalai/legacy/nn/layer/parallel_3d/_operation.py b/colossalai/legacy/nn/layer/parallel_3d/_operation.py new file mode 100755 index 0000000000000000000000000000000000000000..fe42d8e28111bcd838fee78df421544f9bc2dff5 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_3d/_operation.py @@ -0,0 +1,591 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + +from colossalai.legacy.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter +from colossalai.legacy.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc + +from ._utils import get_parallel_mode_from_env, push_async_grad + + +class _Linear3D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + weight_id: int, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.weight_id = weight_id + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + + input_ = all_gather(input_, 0, input_parallel_mode) + weight = all_gather(weight, 0, weight_parallel_mode) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight) + output = reduce_scatter(output, 0, output_parallel_mode) + + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) + + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) + + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]) + ) + weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + + input_op.wait() + + return input_grad, weight_grad, None, None, None, None + + +def linear_3d( + input_: Tensor, + weight: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: + r"""Linear layer for 3D parallelism. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Linear3D.apply( + input_, + weight, + id(weight), + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +class _Classifier3D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + weight_id: int, + bias_id: Optional[int], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.use_bias = bias is not None + ctx.weight_id = weight_id + + src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)] + weight = broadcast(weight, src_rank, input_parallel_mode) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight.transpose(0, 1)) + output = all_reduce(output, output_parallel_mode) + + if bias is not None: + ctx.bias_id = bias_id + output += bias + + ctx.src_rank = src_rank + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + weight_grad = torch.matmul( + output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]) + ) + weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode) + if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): + weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + else: + weight_grad = None + + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) + bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) + bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + else: + bias_grad = None + + input_grad = torch.matmul(output_grad, weight) + + return input_grad, weight_grad, bias_grad, None, None, None, None, None + + +def classifier_3d( + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: + r"""3D parallel classifier. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + bias (:class:`torch.tensor`): matrix of bias. + input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Classifier3D.apply( + input_, + weight, + bias, + id(weight), + id(bias) if bias is not None else None, + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +class _VocabParallelClassifier3D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + weight_id: int, + bias_id: Optional[int], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.use_bias = bias is not None + ctx.weight_id = weight_id + + input_ = all_gather(input_, 0, input_parallel_mode) + weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight) + output = reduce_scatter(output, 0, output_parallel_mode) + + if bias is not None: + ctx.bias_id = bias_id + output += bias + + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) + + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) + + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]) + ) + weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + else: + bias_grad = None + + input_op.wait() + + return input_grad, weight_grad, bias_grad, None, None, None, None, None + + +def vocab_parallel_classifier_3d( + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: + r"""3D vocab parallel classifier. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + bias (:class:`torch.tensor`): matrix of bias. + input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _VocabParallelClassifier3D.apply( + input_, + weight, + bias, + id(weight), + id(bias) if bias is not None else None, + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +@torch.jit.script +def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float): + mu = x - mean + var = sqr_mean - mean**2 + sigma = torch.sqrt(var + eps) + z = mu / sigma + output = weight * z + bias + + return output, mu, sigma + + +@torch.jit.script +def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor): + # dbias, dweight = grad, grad * mu / sigma + dz = grad * weight + dmu = dz / sigma + dvar = dz * mu * (-0.5) * sigma ** (-3) + dmean = -dmu + dvar = torch.sum(dvar, -1, keepdim=True) + dmean = torch.sum(dmean, -1, keepdim=True) + + return dmu, dmean, dvar + + +class _Layernorm3D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Tensor, + weight_id: int, + bias_id: int, + normalized_shape: int, + eps: float, + output_parallel_mode: ParallelMode, + input_x_weight_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.weight_id = weight_id + ctx.bias_id = bias_id + + sum_ = torch.sum(input_, dim=-1, keepdim=True) + sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True) + mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape + + output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps) + + ctx.save_for_backward(mu, sigma, weight) + + ctx.normalized_shape = normalized_shape + ctx.output_parallel_mode = output_parallel_mode + ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode + + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + mu, sigma, weight = ctx.saved_tensors + + bias_grad, weight_grad = output_grad, output_grad * mu / sigma + bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1])) + weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + + dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight) + dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode) + input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape + + return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None + + +def layernorm_3d( + input_: Tensor, + weight: Tensor, + bias: Tensor, + normalized_shape: int, + eps: float, + output_parallel_mode: ParallelMode, + input_x_weight_parallel_mode: ParallelMode, +) -> Tensor: + r"""3D parallel Layernorm. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + bias (:class:`torch.tensor`): matrix of bias. + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability + output_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): output parallel mode. + input_x_weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input x weight parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _Layernorm3D.apply( + input_, + weight, + bias, + id(weight), + id(bias), + normalized_shape, + eps, + output_parallel_mode, + input_x_weight_parallel_mode, + ) + + +def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Splits 3D parallel tensor in specified dimension. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): Parallel mode. + + Returns: + :class:`torch.tensor`: The tensor has been split. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, ( + f"The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) + if tensor.size(dim) <= 1: + return tensor + output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), dim=dim)[ + gpc.get_local_rank(parallel_mode) + ].contiguous() + return output + + +def split_batch_3d( + input_: Tensor, + dim: int = 0, + input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, + weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT, +) -> Tensor: + r"""Splits 3D tensor in batch. + + Args: + input_ (:class:`torch.tensor`): Input tensor. + dim (int): Specified dimension in which to split. + input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): input parallel mode. + weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): weight parallel mode. + + Returns: + :class:`torch.tensor`: The tensor has been split. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + if input_.size(dim) <= 1: + return input_ + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_world_size = gpc.get_world_size(weight_parallel_mode) + input_world_size = gpc.get_world_size(input_parallel_mode) + output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() + output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() + return output + + +class _ReduceTensor3D(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the input + + Args: + tensor (:class:`torch.tensor`): Input tensor. + parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): Parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _ReduceTensor3D.apply(tensor, parallel_mode) + + +class _AllGatherTensor3D(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + output = all_gather(input_, dim, parallel_mode) + return output + + @staticmethod + def backward(ctx, output_grad): + input_grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode) + return input_grad, None, None + + +def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""All-reduce the gradient in backward pass. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to gather. + parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): Parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_. + """ + return _AllGatherTensor3D.apply(tensor, dim, parallel_mode) + + +class _ReduceScatterTensor3D(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode) + return input_grad, None, None + + +def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + r"""Reduce-scatter the input. + + Args: + tensor (:class:`torch.tensor`): Input tensor. + dim (int): Dimension to scatter. + parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): Parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size})." + + return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode) + + +class _ReduceByBatch3D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward( + ctx, + input_: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + reduce_mean: bool = False, + ) -> Tensor: + output = all_reduce(input_, input_parallel_mode) + output = all_reduce(output, weight_parallel_mode) + ctx.reduce_mean = reduce_mean + if reduce_mean: + reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode) + ctx.reduce_size = reduce_size + return output.clone() / reduce_size + return output.clone() + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + if ctx.reduce_mean: + return output_grad / ctx.reduce_size, None, None, None + else: + return output_grad, None, None, None + + +def reduce_by_batch_3d( + tensor: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, reduce_mean: bool = False +) -> Tensor: + r"""All-reduce the input from the model parallel region. + + Args: + input_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`): weight parallel mode. + reduce_mean (bool, optional): If set to ``True``, it will divide the output by + (input parallel size * weight parallel size), default to False. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean) diff --git a/colossalai/legacy/nn/layer/parallel_3d/_utils.py b/colossalai/legacy/nn/layer/parallel_3d/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c967da74e67373b3c4d4d7759a822c55e79f01a --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_3d/_utils.py @@ -0,0 +1,110 @@ +from collections import OrderedDict +from functools import partial + +import torch +from torch import Tensor + +from colossalai.legacy.constants import ( + INPUT_GROUP_3D, + INPUT_X_WEIGHT_3D, + OUTPUT_GROUP_3D, + OUTPUT_X_WEIGHT_3D, + WEIGHT_GROUP_3D, +) +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env + + +def get_depth_from_env() -> int: + try: + depth = env.depth_3d + assert depth > 0, "DEPTH must be greater than zero" + return depth + + except KeyError: + raise EnvironmentError( + "DEPTH is not found in the current environment, " + "please make sure that you have used the correct process group initializer" + ) + + +def get_parallel_mode_from_env(group): + assert group in [ + INPUT_GROUP_3D, + WEIGHT_GROUP_3D, + OUTPUT_GROUP_3D, + INPUT_X_WEIGHT_3D, + OUTPUT_X_WEIGHT_3D, + ], f"{group} is not valid for 3D tensor parallelism." + return getattr(env, group) + + +def swap_in_out_group(): + env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d + env.input_x_weight_group_3d, env.output_x_weight_group_3d = ( + env.output_x_weight_group_3d, + env.input_x_weight_group_3d, + ) + + +def dbg_check_shape(tensor: Tensor, shape: tuple): + rank = gpc.get_global_rank() + if rank == 0: + print(tensor.shape) + assert tensor.shape == shape, "{} does not match {}".format(tensor.shape, shape) + + +class AsyncGradientBucket(object): + def __init__(self): + self.bucket = OrderedDict() + + def __len__(self): + return len(self.bucket) + + def push(self, async_op, grad_tensor, param_id): + self.bucket[param_id] = tuple((async_op, grad_tensor)) + return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device) + + def pop(self, param_id): + grad = None + if param_id in self.bucket: + op, grad = self.bucket.pop(param_id) + if op is not None: + op.wait() + return grad + + def synchronize(self, params): + for p in params: + i = id(p) + if i in self.bucket: + op, grad = self.bucket.pop(i) + if op is not None: + op.wait() + p.grad.add_(grad) + + +_async_grad_bucket = AsyncGradientBucket() + + +def push_async_grad(op, grad, param_id): + return _async_grad_bucket.push(op, grad, param_id) + + +def pop_async_grad(param_id): + return _async_grad_bucket.pop(param_id) + + +def _async_grad_hook(grad, param_id): + grad.add_(pop_async_grad(param_id)) + return grad + + +def register_async_grad_hook(param): + param.register_hook(partial(_async_grad_hook, param_id=id(param))) + + +def synchronize(params=list()): + _async_grad_bucket.synchronize(params) + torch.cuda.default_stream().synchronize() + if len(_async_grad_bucket) > 0: + raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.") diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..19667999419735bd146dd9b5d8ae706efd1c0cab --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py @@ -0,0 +1,1131 @@ +import math +from collections import OrderedDict +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter + +from colossalai.legacy.communication import all_reduce, broadcast +from colossalai.legacy.constants import ( + INPUT_GROUP_3D, + INPUT_X_WEIGHT_3D, + OUTPUT_GROUP_3D, + OUTPUT_X_WEIGHT_3D, + WEIGHT_GROUP_3D, +) +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.nn.layer.base_layer import ParallelLayer +from colossalai.legacy.registry import LAYERS +from colossalai.legacy.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) +from colossalai.nn import init as init +from colossalai.utils.cuda import get_current_device + +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple +from ._operation import ( + all_gather_tensor_3d, + classifier_3d, + layernorm_3d, + linear_3d, + reduce_scatter_tensor_3d, + split_batch_3d, + split_tensor_3d, + vocab_parallel_classifier_3d, +) +from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group + + +@LAYERS.register_module +class LayerNorm3D(ParallelLayer): + r"""Layer Normalization for 3D parallelism. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None): + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) + self.depth = get_depth_from_env() + self.normalized_shape = normalized_shape + self.normalized_shape_per_partition = divide(normalized_shape, self.depth) + + self.weight = Parameter( + torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + ) + if bias: + self.bias = Parameter( + torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + ) + else: + self.bias = None + self.variance_epsilon = eps + self.reset_parameters() + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + + def reset_parameters(self) -> None: + init.ones_()(self.weight) + register_async_grad_hook(self.weight) + if self.bias is not None: + init.zeros_()(self.bias) + register_async_grad_hook(self.bias) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0, bias_key: 0}, + partition_states={ + weight_key: True, + bias_key: True, + }, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + return layernorm_3d( + input_, + self.weight, + self.bias, + self.normalized_shape, + self.variance_epsilon, + self.output_parallel_mode, + self.input_x_weight_parallel_mode, + ) + + +@LAYERS.register_module +class Linear3D(ParallelLayer): + r"""Linear layer for 3D parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) + self.depth = get_depth_from_env() + self.skip_bias_add = skip_bias_add + self.in_features_per_partition = divide(in_features, self.depth**2) + self.out_features_per_partition = divide(out_features, self.depth) + self.bias_features_per_partition = divide(out_features, self.depth) + + self.weight = Parameter( + torch.empty( + self.in_features_per_partition, + self.out_features_per_partition, + device=get_current_device(), + dtype=dtype, + ) + ) + if bias: + self.bias = Parameter( + torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + ) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + swap_in_out_group() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode) + return grad + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.out_features + + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + register_async_grad_hook(self.weight) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast( + self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode, + ) + self.bias.register_hook(self._sync_grad_hook) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + # partition in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + # partition in weight groups + local_state = partition_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in weight groups + local_state = gather_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + # gather in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + output = linear_3d( + input_, + self.weight, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias + + +@LAYERS.register_module +class Classifier3D(ParallelLayer): + r"""Classifier for 3D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.depth = get_depth_from_env() + self.in_features_per_partition = divide(in_features, self.depth) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype) + ) + self.has_weight = True + if bias: + self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode) + + register_async_grad_hook(self.weight) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR) + register_async_grad_hook(self.bias) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + return classifier_3d( + input_, + self.weight, + self.bias, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) + + +@LAYERS.register_module +class VocabParallelClassifier3D(ParallelLayer): + r"""Vocab parallel classifier layer for 3D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) + self.depth = get_depth_from_env() + self.in_features_per_partition = divide(in_features, self.depth) + self.out_features_per_partition = divide(num_classes, self.depth**2) + self.bias_features_per_partition = divide(num_classes, self.depth) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty( + self.out_features_per_partition, + self.in_features_per_partition, + device=get_current_device(), + dtype=dtype, + ) + ) + self.has_weight = True + if bias: + self.bias = Parameter( + torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + ) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + swap_in_out_group() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self) -> None: + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + register_async_grad_hook(self.weight) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast( + self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode, + ) + register_async_grad_hook(self.bias) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + # partition in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) + # partition in weight groups + local_state = partition_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in weight groups + local_state = gather_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + # gather in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + return vocab_parallel_classifier_3d( + input_, + self.weight, + self.bias, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) + + +@LAYERS.register_module +class PatchEmbedding3D(ParallelLayer): + r"""2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): + super().__init__() + self.depth = get_depth_from_env() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.embed_size = embed_size + embed_size_per_partition = embed_size // self.depth + self.flatten = flatten + + self.weight = nn.Parameter( + torch.empty( + (embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype + ) + ) + self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype)) + + self.cls_token = nn.Parameter( + torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) + self.pos_embed = nn.Parameter( + torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth) + set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth) + + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) + return grad + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + fan_out = self.embed_size + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0] + broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode) + broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode) + broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode) + + self.weight.register_hook(self._sync_grad_hook) + self.bias.register_hook(self._sync_grad_hook) + self.cls_token.register_hook(self._sync_grad_hook) + self.pos_embed.register_hook(self._sync_grad_hook) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + # cls token + cls_token = state_dict.pop(cls_token_key, None) + if cls_token is not None: + local_state[cls_token_key] = cls_token + # pos embed + pos_embed = state_dict.pop(pos_embed_key, None) + if pos_embed is not None: + local_state[pos_embed_key] = pos_embed + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + local_state = OrderedDict( + {weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed} + ) + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_3d( + input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode + ) + output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = self.cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + self.pos_embed + + return output + + +@LAYERS.register_module +class Embedding3D(ParallelLayer): + r"""Embedding for 3D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_ + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): + super().__init__() + self.depth = get_depth_from_env() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, self.depth) + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = nn.Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + + def _set_tensor_parallel_attributes(self) -> None: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) + return grad + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + broadcast( + self.weight, gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode + ) + self.weight.register_hook(self._sync_grad_hook) + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + # broadcast in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = broadcast_state_dict(local_state, self.input_parallel_mode) + # broadcast in weight groups + local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + local_state = OrderedDict({weight_key: self.weight}) + + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_batch_3d( + input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode + ) + output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + return output + + +@LAYERS.register_module +class VocabParallelEmbedding3D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.depth = get_depth_from_env() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2) + self.embed_dim_per_partition = divide(self.embed_dim, self.depth) + vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode) + self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition * self.depth + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth + + self.weight = Parameter( + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype, + ) + ) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + "weight" + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + # partition in weight groups + local_state = partition_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + "weight" + local_state = OrderedDict({weight_key: self.weight}) + + # gather in weight groups + local_state = gather_tensor_parallel_state_dict( + local_state, + self.weight_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in input groups + if gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.input_parallel_mode, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in output groups + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + self.output_parallel_mode, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) + + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + weight = all_gather_tensor_3d(self.weight, 0, self.weight_parallel_mode) + + output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + output_parallel[input_mask, :] = 0.0 + output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode) + + return output diff --git a/colossalai/legacy/nn/layer/parallel_sequence/__init__.py b/colossalai/legacy/nn/layer/parallel_sequence/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d64aba6bafe4f1caddb933e50c471648d2d151b0 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_sequence/__init__.py @@ -0,0 +1,4 @@ +from ._operation import RingAV, RingQK +from .layers import TransformerSelfAttentionRing + +__all__ = ["TransformerSelfAttentionRing", "RingAV", "RingQK"] diff --git a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..24d5499e3a5fcd246cb5ddaa7b6ef69c05c868be --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +from torch import distributed as dist +from torch.cuda.amp import custom_bwd, custom_fwd + +from colossalai.legacy.communication import ring_forward +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range +from colossalai.utils import get_current_device + + +class RingQK(torch.autograd.Function): + """ + Calculate QK in a ring-exchange style + """ + + @staticmethod + @custom_fwd + def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length): + # save tensor for backward + ctx.save_for_backward(sub_q, sub_k) + ctx.sub_seq_length = sub_seq_length + + # create local segment of attention score + attention_score = torch.empty( + batch_size * num_attention_heads, + sub_seq_length, + sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), + dtype=sub_q.dtype, + device=get_current_device(), + ) + + # compute local QK^T + part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + start_idx = local_rank * sub_seq_length + end_idx = (local_rank + 1) * sub_seq_length + attention_score[:, :, start_idx:end_idx] = part_a + + # compute QK^T in ring-all-reduce style + for i in range(local_world_size - 1): + sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length) + part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) + attention_score[:, :, start_idx:end_idx] = part_a + + return attention_score + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + ( + sub_q, + sub_k, + ) = ctx.saved_tensors + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + + # calculate gradient of sub_k + grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q) + + dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE)) + grad_k = grad_k[:, local_rank * ctx.sub_seq_length : (local_rank + 1) * ctx.sub_seq_length] + grad_k /= local_world_size + + # calculate gradient for sub_q + grad_q = torch.zeros_like( + sub_q, + dtype=sub_q.dtype, + device=get_current_device(), + ) + + # compute with local sub_k + start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length) + grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k) + + # compute QK^T in ring-all-reduce style + for i in range(local_world_size - 1): + sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length) + grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k) + + grad_q /= local_world_size + + return grad_q, grad_k, None, None, None + + +class RingAV(torch.autograd.Function): + """ + Calculate AV in a ring-exchange style + """ + + @staticmethod + @custom_fwd + def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attention_head_size, sub_seq_length): + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length) + + sub_attention_result = torch.zeros( + batch_size * num_attention_heads, + sub_seq_length, + attention_head_size, + device=get_current_device(), + dtype=attention_score.dtype, + ) + + # save tensors for backward + ctx.save_for_backward(attention_score, sub_v) + ctx.sub_seq_length = sub_seq_length + + # compute local AV + part_av = torch.matmul(attention_score[:, :, local_start_idx:local_end_idx], sub_v) + sub_attention_result += part_av + + # compute AV in ring - all - reduce style + for i in range(local_world_size - 1): + sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length) + + # compute QK^T + part_av = torch.matmul(attention_score[:, :, start_idx:end_idx], sub_v) + sub_attention_result += part_av + return sub_attention_result + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) + local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + local_start_idx, local_end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length) + attention_scores, sub_v = ctx.saved_tensors + + # calculate gradient of v + grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output) + dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE)) + grad_v = grad_v[:, local_start_idx:local_end_idx] + grad_v /= local_world_size + + # calculate gradient for attention score + grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device()) + + # compute with local sub_k + grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) + + # compute QK^T in ring-all-reduce style + for i in range(local_world_size - 1): + sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE) + start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length) + + # compute grad_q + grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) + + return grad_attention_score, grad_v, None, None, None, None diff --git a/colossalai/nn/layer/parallel_sequence/_utils.py b/colossalai/legacy/nn/layer/parallel_sequence/_utils.py similarity index 100% rename from colossalai/nn/layer/parallel_sequence/_utils.py rename to colossalai/legacy/nn/layer/parallel_sequence/_utils.py diff --git a/colossalai/legacy/nn/layer/parallel_sequence/layers.py b/colossalai/legacy/nn/layer/parallel_sequence/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..063b0cd8e2b2bc93c2e23e7d0851c3843defee48 --- /dev/null +++ b/colossalai/legacy/nn/layer/parallel_sequence/layers.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter + +from colossalai.kernel import FusedScaleMaskSoftmax +from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType +from colossalai.legacy.context import seed +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK +from colossalai.legacy.registry import LAYERS + + +@LAYERS.register_module +class TransformerSelfAttentionRing(nn.Module): + """Parallel self-attention layer abstract class. + Self-attention layer takes input with size [b, s, h] + and returns output of the same size. + + Args: + hidden_size (int): hidden size. + num_attention_heads (int): number of attention heads. + attention_dropout (float): dropout probability for attention layer. + attention_mask_func (:class:`typing.Callable`): Mask function to be applied. + layer_number (int): number of layers. + + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + attention_dropout, + attention_mask_func, + layer_number, + apply_query_key_layer_scaling: bool = False, + convert_fp16_to_fp32_in_softmax: bool = False, + attn_mask_type=AttnMaskType.padding, + masked_softmax_fusion=True, + fp16=False, + bf16=False, + ): + super().__init__() + self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_mask_func = attention_mask_func + self.layer_number = layer_number + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.attn_mask_type = attn_mask_type + assert self.layer_number > 0 + self.attention_dropout = attention_dropout + + if self.apply_query_key_layer_scaling: + self.convert_fp16_to_fp32_in_softmax = True + + assert ( + self.hidden_size % self.num_attention_heads == 0 + ), "hidden size is not divisible by the number of attention heads" + + self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads + + self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE) + + # Strided linear layer. + self.query_key_value = _Linear( + hidden_size, + 3 * self.hidden_size, + ) + + self.coeff = None + self.norm_factor = math.sqrt(self.hidden_size) + + if self.apply_query_key_layer_scaling: + self.coeff = layer_number + self.norm_factor *= self.coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + fp16, + bf16, + self.attn_mask_type, + masked_softmax_fusion, + self.attention_mask_func, + self.convert_fp16_to_fp32_in_softmax, + self.coeff, + ) + + self.attention_dropout = nn.Dropout(attention_dropout) + + # Output. + self.dense = _Linear(hidden_size, hidden_size, bias=True, skip_bias_add=True) + + def forward(self, hidden_states, attention_mask): + # hidden_states: [sub_seq_len, batch_size, hidden_size] + # attention_mask: [batch_size, 1, sub_seq_len, seq_len] + sub_seq_length, batch_size, hidden_size = hidden_states.size() + + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads shape change: + # [sub_seq_len, batch_size, hidden_size] --> [sub_seq_len, batch_size, (3 * head_size * num_heads)] + mixed_x_layer = self.query_key_value(hidden_states) + + # [sub_seq_len, batch_size, num_heads, 3 * head_size] --> 3 [sub_seq_len, batch_size, num_heads, head_size] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # split into query, key and value + last_dim = mixed_x_layer.dim() - 1 + last_dim_value = mixed_x_layer.size(-1) + assert last_dim_value % 3 == 0, ( + "the last dimension is not a multiple of 3, " "cannot be divided into query, key and value" + ) + partition_size = last_dim_value // 3 + (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, partition_size, dim=last_dim) + + # attention scores: [batch_size, num_heads, sub_seq_len, seq_len] + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0) * self.world_size, + ) + + # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] + key_layer = key_layer.view(key_layer.size(0), output_size[0] * output_size[1], -1) + + # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len] + attention_scores = RingQK.apply( + query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size] + key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size], + batch_size, + self.num_attention_heads, + sub_seq_length, + ) + + attention_scores /= self.norm_factor + + # change view to [batch_size, num_heads, sub_seq_len, seq_len] + attention_scores = attention_scores.view(*output_size) + + # change shape to [batch_size, num_heads, sub_seq_len, seq_len] + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with seed(ParallelMode.TENSOR): + attention_probs = self.attention_dropout(attention_probs) + + # context layer shape: [batch_size, num_heads, sub_seq_len, head_size] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + + # change view [sub_seq_len, batch_size * num_heads, head_size] + value_layer = value_layer.contiguous().view(value_layer.size(0), output_size[0] * output_size[1], -1) + + # # change view [b * num_heads, sub_seq_len, seq_len] + attention_probs = attention_probs.view( + attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3) + ) + + # matmul: [batch_size * num_heads, sub_seq_len, head_size] + context_layer = RingAV.apply( + attention_probs, + value_layer.transpose(0, 1).contiguous(), + batch_size, + self.num_attention_heads, + self.hidden_size_per_attention_head, + sub_seq_length, + ) + + # change view [batch_size, num_heads, sub_seq_len, head_size] + context_layer = context_layer.view(*output_size) + + # [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, num_heads, head_size] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size] + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_attention_head * self.num_attention_heads, + ) + context_layer = context_layer.view(*new_context_layer_shape) + + output, bias = self.dense(context_layer) + + return output, bias + + def __repr__(self): + return ( + f"TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, " + f"layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, " + f"attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, " + f"hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, " + f"convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})" + ) + + +class _Linear(nn.Module): + """Linear layer with column parallelism. + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimizations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + """ + + def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): + super(_Linear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size, + ) + ) + nn.init.xavier_normal_(self.weight) + + if bias: + self.bias = Parameter(torch.empty(self.output_size)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter("bias", None) + + def forward(self, input_): + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output = F.linear(input_, self.weight, bias) + + if self.skip_bias_add: + return output, self.bias + else: + return output + + def __repr__(self): + return ( + f"Linear(in_features={self.input_size}, out_features={self.output_size}, " + + f"bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})" + ) diff --git a/colossalai/legacy/nn/layer/utils/__init__.py b/colossalai/legacy/nn/layer/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e78b228eb4f56797bb1d861378d76ed5ca24fa7 --- /dev/null +++ b/colossalai/legacy/nn/layer/utils/__init__.py @@ -0,0 +1,21 @@ +from .common import ( + ACT2FN, + CheckpointModule, + _ntuple, + divide, + get_tensor_parallel_mode, + set_tensor_parallel_attribute_by_partition, + set_tensor_parallel_attribute_by_size, + to_2tuple, +) + +__all__ = [ + "CheckpointModule", + "divide", + "ACT2FN", + "set_tensor_parallel_attribute_by_size", + "set_tensor_parallel_attribute_by_partition", + "get_tensor_parallel_mode", + "_ntuple", + "to_2tuple", +] diff --git a/colossalai/legacy/nn/layer/utils/common.py b/colossalai/legacy/nn/layer/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..fd6a5b38d60a4204d1991e630268ef4ce599d2c8 --- /dev/null +++ b/colossalai/legacy/nn/layer/utils/common.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import collections.abc +from itertools import repeat + +import numpy as np +import torch +from torch import Tensor, nn + +from colossalai.legacy.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.utils import checkpoint + + +class CheckpointModule(nn.Module): + def __init__(self, checkpoint: bool = True, offload: bool = False): + super().__init__() + self.checkpoint = checkpoint + self._use_checkpoint = checkpoint + self._offload = offload + + def _forward(self, *args, **kwargs): + raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward") + + def forward(self, *args, **kwargs): + if self._use_checkpoint: + return checkpoint(self._forward, self._offload, *args, **kwargs) + else: + return self._forward(*args, **kwargs) + + def train(self, mode: bool = True): + self._use_checkpoint = self.checkpoint + return super().train(mode=mode) + + def eval(self): + self._use_checkpoint = False + return super().eval() + + +def divide(numerator, denominator): + """Only allow exact division. + + Args: + numerator (int): Numerator of the division. + denominator (int): Denominator of the division. + + Returns: + int: the result of exact division. + """ + assert denominator != 0, "denominator can not be zero" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) + return numerator // denominator + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +def set_tensor_parallel_attribute_by_size(param, size): + setattr(param, IS_TENSOR_PARALLEL, True) + setattr(param, NUM_PARTITIONS, size // np.prod(param.shape)) + + +def set_tensor_parallel_attribute_by_partition(param, num_partitions): + setattr(param, IS_TENSOR_PARALLEL, True) + setattr(param, NUM_PARTITIONS, num_partitions) + + +def get_tensor_parallel_mode(): + return env.mode + + +# From PyTorch internals + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) diff --git a/colossalai/legacy/nn/layer/vanilla/__init__.py b/colossalai/legacy/nn/layer/vanilla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5785bbef33d7b11c4e5b561df3bece108c431632 --- /dev/null +++ b/colossalai/legacy/nn/layer/vanilla/__init__.py @@ -0,0 +1,19 @@ +from .layers import ( + DropPath, + VanillaClassifier, + VanillaLayerNorm, + VanillaLinear, + VanillaPatchEmbedding, + WrappedDropout, + WrappedDropPath, +) + +__all__ = [ + "VanillaLayerNorm", + "VanillaPatchEmbedding", + "VanillaClassifier", + "DropPath", + "WrappedDropout", + "WrappedDropPath", + "VanillaLinear", +] diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..12965a4a6409b1e6a5e484bbd1d74024f6e3b257 --- /dev/null +++ b/colossalai/legacy/nn/layer/vanilla/layers.py @@ -0,0 +1,350 @@ +import math +from typing import Callable + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch import nn as nn +from torch.nn.parameter import Parameter + +from colossalai.legacy.context import seed +from colossalai.legacy.registry import LAYERS +from colossalai.nn import init as init +from colossalai.utils.cuda import get_current_device + +from ..utils import to_2tuple + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + Args: + drop_prob (float, optional): probability of dropping path, defaults 0.0. + training (bool, optional): whether in training progress, defaults False. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + + Args: + drop_prob (float, optional): probability of dropping path, defaults None. + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class WrappedDropout(nn.Module): + r"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes + some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each + channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of + 1/(1-p) during training. This means that during evaluation the module simply computes an identity function. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): + super().__init__() + if p < 0 or p > 1: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + self.p = p + self.inplace = inplace + if mode is None: + self.func = self.nonefunc + else: + self.func = self.normalfunc + self.mode = mode + + def nonefunc(self, inputs): + return F.dropout(inputs, self.p, self.training, self.inplace) + + def normalfunc(self, inputs): + with seed(self.mode): + return F.dropout(inputs, self.p, self.training, self.inplace) + + def forward(self, inputs): + return self.func(inputs) + + +class WrappedDropPath(nn.Module): + r"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + Here, it is wrapped with the context of seed manager. + + Args: + p (float, optional): probability of dropping path, defaults 0.0. + mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + + def __init__(self, p: float = 0.0, mode=None): + super().__init__() + self.p = p + self.mode = mode + if self.mode is None: + self.func = self.nonefunc + else: + self.func = self.normalfunc + self.mode = mode + + def nonefunc(self, inputs): + return drop_path(inputs, self.p, self.training) + + def normalfunc(self, inputs): + with seed(self.mode): + return drop_path(inputs, self.p, self.training) + + def forward(self, inputs): + return self.func(inputs) + + +@LAYERS.register_module +class VanillaPatchEmbedding(nn.Module): + r""" + 2D Image to Patch Embedding + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about initializer please refer to + `init `_. + """ + + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.weight = nn.Parameter( + torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype) + ) + self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) + self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) + self.pos_embed = nn.Parameter( + torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype) + ) + + self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) + + def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): + fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight) + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + bias_initializer(self.bias, fan_in=fan_in) + position_embed_initializer(self.pos_embed) + + def forward(self, input_: Tensor) -> Tensor: + B, C, H, W = input_.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) + if self.flatten: + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + + cls_token = self.cls_token.expand(output.shape[0], -1, -1) + output = torch.cat((cls_token, output), dim=1) + output = output + self.pos_embed + return output + + +@LAYERS.register_module +class VanillaClassifier(nn.Module): + r"""Dense linear classifier. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about initializer please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = nn.Parameter( + torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype) + ) + self.has_weight = True + if bias: + self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + + def reset_parameters(self, weight_initializer, bias_initializer): + fan_in, fan_out = self.in_features, self.num_classes + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tensor: + return F.linear(input_, self.weight, self.bias) + + +@LAYERS.register_module +class VanillaLayerNorm(nn.Module): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + super().__init__() + + self.normalized_shape = (normalized_shape,) + self.variance_epsilon = eps + + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + + self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) + if bias: + self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs)) + else: + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon) + + +@LAYERS.register_module +class VanillaLinear(nn.Module): + """Linear layer. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + skip_bias_add: bool (optional, default to be false). + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.skip_bias_add = skip_bias_add + factory_kwargs = {"device": get_current_device(), "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + weight_initializer(self.weight, fan_in=in_features, fan_out=out_features) + if self.bias is not None: + bias_initializer(self.bias, fan_in=in_features) + + def forward(self, input: Tensor) -> Tensor: + if not self.skip_bias_add: + return F.linear(input, self.weight, self.bias) + else: + return F.linear(input, self.weight), self.bias diff --git a/colossalai/legacy/nn/layer/wrapper/__init__.py b/colossalai/legacy/nn/layer/wrapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f3a336453443ba8a1f8b85599ed903b4d83cb56 --- /dev/null +++ b/colossalai/legacy/nn/layer/wrapper/__init__.py @@ -0,0 +1,3 @@ +from .pipeline_wrapper import PipelineSharedModuleWrapper + +__all__ = ["PipelineSharedModuleWrapper"] diff --git a/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..55445eb4d35a85abed3497899505af9e764eca21 --- /dev/null +++ b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py @@ -0,0 +1,49 @@ +from typing import List, Tuple, Union + +import torch.distributed as dist +import torch.nn as nn + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc + + +class PipelineSharedModuleWrapper: + def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None: + assert len(pipeline_ranks) > 1, f"Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}" + self.pipeline_ranks = pipeline_ranks + self.group = None + self.ranks_in_group = None + self._init_group() + + def _init_group(self): + world_size = gpc.get_world_size(ParallelMode.GLOBAL) + dp_size = gpc.get_world_size(ParallelMode.DATA) + pp_size = gpc.get_world_size(ParallelMode.PIPELINE) + rank = gpc.get_global_rank() + num_dp_groups = world_size // dp_size + num_pp_stages = num_dp_groups // pp_size + for i in range(dp_size): + for j in range(num_pp_stages): + pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages)) + sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks] + group = dist.new_group(sub_ranks) + if rank in sub_ranks: + self.group = group + self.ranks_in_group = sub_ranks + + def register_module(self, module: nn.Module): + assert ( + self.ranks_in_group is not None + ), f"Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}" + src = self.ranks_in_group[self.pipeline_ranks[0]] + for p in module.parameters(): + setattr(p, "pipeline_shared_module_pg", self.group) + dist.broadcast(p, src, group=self.group) + + def register_parameter(self, param: nn.Parameter): + assert ( + self.ranks_in_group is not None + ), f"Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}" + src = self.ranks_in_group[self.pipeline_ranks[0]] + setattr(param, "pipeline_shared_module_pg", self.group) + dist.broadcast(param, src, group=self.group) diff --git a/colossalai/legacy/nn/loss/__init__.py b/colossalai/legacy/nn/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43e5a5a2e2aa8cde10ade57be4bfa11d4bc0a682 --- /dev/null +++ b/colossalai/legacy/nn/loss/__init__.py @@ -0,0 +1,40 @@ +from torch import nn +from torch.nn.modules.loss import * +from torch.nn.modules.loss import _Loss + +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode + +from .loss_1d import VocabParallelCrossEntropyLoss1D +from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D +from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D +from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D + +_parallel_cross_entropy = { + "2d": CrossEntropyLoss2D, + "2.5d": CrossEntropyLoss2p5D, + "3d": CrossEntropyLoss3D, +} + +_vocab_parallel_cross_entropy = { + "1d": VocabParallelCrossEntropyLoss1D, + "2d": VocabParallelCrossEntropyLoss2D, + "2.5d": VocabParallelCrossEntropyLoss2p5D, + "3d": VocabParallelCrossEntropyLoss3D, +} + + +class CrossEntropyLoss(_Loss): + def __init__(self, reduction: bool = True, *args, **kwargs): + super().__init__() + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel is not None and env.vocab_parallel: + self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) + elif tensor_parallel is None or tensor_parallel == "1d": + reduction = "mean" if reduction else "none" + self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs) + else: + self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) + + def forward(self, *args): + return self.loss(*args) diff --git a/colossalai/nn/loss/loss_1d.py b/colossalai/legacy/nn/loss/loss_1d.py similarity index 89% rename from colossalai/nn/loss/loss_1d.py rename to colossalai/legacy/nn/loss/loss_1d.py index 2fabd954f8fb7daeb7b457f0c2b216d191e6ed69..fae9c929b78827c916c2b201dcc66d9313faa801 100644 --- a/colossalai/nn/loss/loss_1d.py +++ b/colossalai/legacy/nn/loss/loss_1d.py @@ -1,105 +1,104 @@ -import torch -import torch.distributed as dist -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.registry import LOSSES -from torch.cuda.amp import custom_bwd, custom_fwd -from torch.nn.modules.loss import _Loss - - -class _VocabParallelCrossEntropy1D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, vocab_parallel_logits, targets, process_group): - if process_group is None: - process_group = gpc.get_group(ParallelMode.PARALLEL_1D) - - # Maximum value along vocab dimension across all GPUs. - logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group) - # Subtract the maximum value. - vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) - - # Get the partition's vocab indecies - partition_vocab_size = vocab_parallel_logits.size()[-1] - rank = dist.get_rank(process_group) - vocab_start_index = partition_vocab_size * rank - vocab_end_index = vocab_start_index + partition_vocab_size - - # Create a mask of valid vocab ids (1 means it needs to be masked). - target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index) - masked_target = targets.clone() - vocab_start_index - masked_target[target_mask] = 0 - - # Get predicted-logits = logits[target]. - # For Simplicity, we convert logits to a 2-D tensor with size - # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] - predicted_logits_1d = predicted_logits_1d.clone().contiguous() - predicted_logits = predicted_logits_1d.view_as(targets) - predicted_logits[target_mask] = 0.0 - # All reduce is needed to get the chunks from other GPUs. - torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) - - # Sum of exponential of logits along vocab dimension across all GPUs. - exp_logits = torch.exp(vocab_parallel_logits) - sum_exp_logits = exp_logits.sum(dim=-1) - torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) - - # Loss = log(sum(exp(logits))) - predicted-logit. - loss = torch.log(sum_exp_logits) - predicted_logits - # Store softmax, target-mask and masked-target for backward pass. - exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) - ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) - return loss - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - - # Retreive tensors from the forward path. - softmax, target_mask, masked_target_1d = ctx.saved_tensors - - # All the inputs have softmax as thier gradient. - grad_input = softmax - # For simplicity, work with the 2D gradient. - partition_vocab_size = softmax.size()[-1] - grad_2d = grad_input.view(-1, partition_vocab_size) - - # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) - - # Finally elementwise multiplication with the output gradients. - grad_input.mul_(grad_output.unsqueeze(dim=-1)) - - return grad_input, None, None - - -@LOSSES.register_module -class VocabParallelCrossEntropyLoss1D(_Loss): - """Vocab parallel cross entropy loss for 1D parallelism. - - Args: - reduction (bool, optional): whether to average the loss, defaults to True. - """ - - def __init__(self, reduction=True): - super().__init__() - self.reduction_mean = reduction - - def forward(self, logits, targets, process_group=None): - """Calculate loss between logits and targets. - - Args: - logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. - """ - loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group) - if self.reduction_mean: - loss = loss.mean() - return loss +import torch +import torch.distributed as dist +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.modules.loss import _Loss + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.registry import LOSSES + + +class _VocabParallelCrossEntropy1D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, vocab_parallel_logits, targets, process_group): + if process_group is None: + process_group = gpc.get_group(ParallelMode.PARALLEL_1D) + + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group) + # Subtract the maximum value. + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + + # Get the partition's vocab indices + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = dist.get_rank(process_group) + vocab_start_index = partition_vocab_size * rank + vocab_end_index = vocab_start_index + partition_vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index) + masked_target = targets.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(targets) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = torch.exp(vocab_parallel_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + return loss + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + # Retrieve tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss1D(_Loss): + """Vocab parallel cross entropy loss for 1D parallelism. + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + """ + + def __init__(self, reduction=True): + super().__init__() + self.reduction_mean = reduction + + def forward(self, logits, targets, process_group=None): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. + """ + loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group) + if self.reduction_mean: + loss = loss.mean() + return loss diff --git a/colossalai/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py similarity index 87% rename from colossalai/nn/loss/loss_2d.py rename to colossalai/legacy/nn/loss/loss_2d.py index cb12e723c3232446bf2b911730ff31f7274ea6e2..44f39a6db262e0ad101b4544c72f331637357fbe 100644 --- a/colossalai/nn/loss/loss_2d.py +++ b/colossalai/legacy/nn/loss/loss_2d.py @@ -1,15 +1,16 @@ import torch import torch.distributed as dist -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d -from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization -from colossalai.registry import LOSSES -from colossalai.utils import get_current_device from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d +from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization +from colossalai.legacy.registry import LOSSES +from colossalai.utils import get_current_device + @LOSSES.register_module class CrossEntropyLoss2D(_Loss): @@ -49,7 +50,7 @@ class CrossEntropyLoss2D(_Loss): float: the loss between logits and targets. """ targets = split_batch_2d(targets) - loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() loss = reduce_by_batch_2d(loss, True) @@ -68,9 +69,9 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function): # vocab_parallel_logits: [b/q, s, v/q] # target: [b/q, s] logits_max = torch.max(logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, - op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW) + ) # Subtract the maximum value. # vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) logits = logits - logits_max.unsqueeze(dim=-1) @@ -89,7 +90,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function): end=logits.size()[0], ) predicted_logits = logits[arange_1d, masked_target] - predicted_logits[target_mask] = 0. + predicted_logits[target_mask] = 0.0 dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) exp_logits = torch.exp(logits) @@ -106,7 +107,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function): @staticmethod @custom_bwd def backward(ctx, output_grad): - # Retreive tensors from the forward path. + # Retrieve tensors from the forward path. softmax, target_mask, masked_target = ctx.saved_tensors # All the inputs have softmax as their gradient. @@ -118,7 +119,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) - grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(output_grad.unsqueeze(dim=-1)) diff --git a/colossalai/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py similarity index 86% rename from colossalai/nn/loss/loss_2p5d.py rename to colossalai/legacy/nn/loss/loss_2p5d.py index f8e3324fc5ff8fe3d28ea25798e33ff9aeb26d25..c57bf26e913963296822be7f3bc5cd7b63013f18 100644 --- a/colossalai/nn/loss/loss_2p5d.py +++ b/colossalai/legacy/nn/loss/loss_2p5d.py @@ -1,15 +1,16 @@ import torch import torch.distributed as dist -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d -from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization -from colossalai.registry import LOSSES -from colossalai.utils import get_current_device from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d +from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization +from colossalai.legacy.registry import LOSSES +from colossalai.utils import get_current_device + @LOSSES.register_module class CrossEntropyLoss2p5D(_Loss): @@ -46,7 +47,7 @@ class CrossEntropyLoss2p5D(_Loss): targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. """ targets = split_batch_2p5d(targets) - loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() loss = reduce_by_batch_2p5d(loss, True) @@ -63,9 +64,9 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function): # loss: [b/dq] # targets: [b/dq, h/q] logits_max = torch.max(logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, - op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW) + ) # Subtract the maximum value. logits = logits - logits_max.unsqueeze(dim=-1) @@ -83,7 +84,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function): end=logits.size()[0], ) predicted_logits = logits[arange_1d, masked_target] - predicted_logits[target_mask] = 0. + predicted_logits[target_mask] = 0.0 dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) exp_logits = torch.exp(logits) @@ -100,7 +101,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function): @staticmethod @custom_bwd def backward(ctx, output_grad): - # Retreive tensors from the forward path. + # Retrieve tensors from the forward path. softmax, target_mask, masked_target = ctx.saved_tensors # All the inputs have softmax as their gradient. @@ -112,7 +113,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) - grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(output_grad.unsqueeze(dim=-1)) diff --git a/colossalai/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py similarity index 89% rename from colossalai/nn/loss/loss_3d.py rename to colossalai/legacy/nn/loss/loss_3d.py index e76439191fdbc8fc31737e21b5c2b17e6685059b..988317cae3ebebad090b1594ff0cbde5faed30d2 100644 --- a/colossalai/nn/loss/loss_3d.py +++ b/colossalai/legacy/nn/loss/loss_3d.py @@ -1,15 +1,16 @@ import torch import torch.distributed as dist -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D -from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d -from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env -from colossalai.registry import LOSSES -from colossalai.utils import get_current_device from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d +from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env +from colossalai.legacy.registry import LOSSES +from colossalai.utils import get_current_device + @LOSSES.register_module class CrossEntropyLoss3D(_Loss): @@ -48,7 +49,7 @@ class CrossEntropyLoss3D(_Loss): """ targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) targets = split_tensor_3d(targets, 0, self.input_parallel_mode) - loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True) @@ -82,7 +83,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function): arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device()) predicted_logits = logits[arange_1d, masked_target] predicted_logits = predicted_logits.clone().contiguous().view_as(targets) - predicted_logits[target_mask] = 0. + predicted_logits[target_mask] = 0.0 dist.all_reduce(predicted_logits, group=gpc.get_group(output_parallel_mode)) # Loss = log(sum(exp(logits))) - predicted-logit. @@ -99,10 +100,10 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function): @staticmethod @custom_bwd def backward(ctx, output_grad): - # Retreive tensors from the forward path. + # Retrieve tensors from the forward path. softmax, target_mask, masked_target = ctx.saved_tensors - # All the inputs have softmax as thier gradient. + # All the inputs have softmax as their gradient. input_grad = softmax # For simplicity, work with the 2D gradient. partition_vocab_size = softmax.size()[-1] @@ -110,7 +111,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) - grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() input_grad.mul_(output_grad.unsqueeze(dim=-1)) return input_grad, None, None, None diff --git a/colossalai/legacy/nn/metric/__init__.py b/colossalai/legacy/nn/metric/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc2b2c5d02547dd67b6f4ad4237575f862556663 --- /dev/null +++ b/colossalai/legacy/nn/metric/__init__.py @@ -0,0 +1,27 @@ +from torch import nn + +from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode + +from ._utils import calc_acc +from .accuracy_2d import Accuracy2D +from .accuracy_2p5d import Accuracy2p5D +from .accuracy_3d import Accuracy3D + +_parallel_accuracy = { + "2d": Accuracy2D, + "2.5d": Accuracy2p5D, + "3d": Accuracy3D, +} + + +class Accuracy(nn.Module): + def __init__(self): + super().__init__() + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel not in _parallel_accuracy: + self.acc = calc_acc + else: + self.acc = _parallel_accuracy[tensor_parallel]() + + def forward(self, *args): + return self.acc(*args) diff --git a/colossalai/legacy/nn/metric/_utils.py b/colossalai/legacy/nn/metric/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8706ffc101b0e3f5c007d3b08e4ebe0f1aef6e72 --- /dev/null +++ b/colossalai/legacy/nn/metric/_utils.py @@ -0,0 +1,7 @@ +import torch + + +def calc_acc(logits, targets): + preds = torch.argmax(logits, dim=-1) + correct = torch.sum(targets == preds) + return correct diff --git a/colossalai/nn/metric/accuracy_2d.py b/colossalai/legacy/nn/metric/accuracy_2d.py similarity index 84% rename from colossalai/nn/metric/accuracy_2d.py rename to colossalai/legacy/nn/metric/accuracy_2d.py index a86832973cfda4ffe89f414d9ae7342191293aac..59ddd5d66e2094ac89202f4eae376c489a5dbab4 100644 --- a/colossalai/nn/metric/accuracy_2d.py +++ b/colossalai/legacy/nn/metric/accuracy_2d.py @@ -1,13 +1,13 @@ import torch -from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d from torch import nn +from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d + from ._utils import calc_acc class Accuracy2D(nn.Module): - """Accuracy for 2D parallelism - """ + """Accuracy for 2D parallelism""" def __init__(self): super().__init__() diff --git a/colossalai/nn/metric/accuracy_2p5d.py b/colossalai/legacy/nn/metric/accuracy_2p5d.py similarity index 83% rename from colossalai/nn/metric/accuracy_2p5d.py rename to colossalai/legacy/nn/metric/accuracy_2p5d.py index 3044da065de136b5e2d5f73b2690893f3dc8e240..948eae989d4872a1e25cb1a7ae2723d0eef64738 100644 --- a/colossalai/nn/metric/accuracy_2p5d.py +++ b/colossalai/legacy/nn/metric/accuracy_2p5d.py @@ -1,13 +1,13 @@ import torch -from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d from torch import nn +from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d + from ._utils import calc_acc class Accuracy2p5D(nn.Module): - """Accuracy for 2p5D parallelism - """ + """Accuracy for 2p5D parallelism""" def __init__(self): super().__init__() diff --git a/colossalai/nn/metric/accuracy_3d.py b/colossalai/legacy/nn/metric/accuracy_3d.py similarity index 75% rename from colossalai/nn/metric/accuracy_3d.py rename to colossalai/legacy/nn/metric/accuracy_3d.py index 5506fc1d2ffcf918cbb9f079bfe0f07f92b9bc7f..aee6118413ef2f965da32109a8a0e05c1ba74ce5 100644 --- a/colossalai/nn/metric/accuracy_3d.py +++ b/colossalai/legacy/nn/metric/accuracy_3d.py @@ -1,33 +1,34 @@ -import torch -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D -from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d -from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env -from torch import nn - -from ._utils import calc_acc - - -class Accuracy3D(nn.Module): - """Accuracy for 3D parallelism - """ - def __init__(self): - super().__init__() - self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) - self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - - def forward(self, logits, targets): - """Calculate the accuracy of predicted labels. - - Args: - logits (:class:`torch.tensor`): Predicted labels. - targets (:class:`torch.tensor`): True labels from data. - - Returns: - float: the accuracy of prediction. - """ - with torch.no_grad(): - targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) - targets = split_tensor_3d(targets, 0, self.input_parallel_mode) - correct = calc_acc(logits, targets) - correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode) - return correct +import torch +from torch import nn + +from colossalai.legacy.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d +from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env + +from ._utils import calc_acc + + +class Accuracy3D(nn.Module): + """Accuracy for 3D parallelism""" + + def __init__(self): + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + + def forward(self, logits, targets): + """Calculate the accuracy of predicted labels. + + Args: + logits (:class:`torch.tensor`): Predicted labels. + targets (:class:`torch.tensor`): True labels from data. + + Returns: + float: the accuracy of prediction. + """ + with torch.no_grad(): + targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) + targets = split_tensor_3d(targets, 0, self.input_parallel_mode) + correct = calc_acc(logits, targets) + correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode) + return correct diff --git a/colossalai/legacy/nn/parallel/__init__.py b/colossalai/legacy/nn/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19ad8404de1816dc17f000cfe491b7d77491b0c3 --- /dev/null +++ b/colossalai/legacy/nn/parallel/__init__.py @@ -0,0 +1,5 @@ +from .data_parallel import ColoDDP + +__all__ = [ + "ColoDDP", +] diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/legacy/nn/parallel/data_parallel.py similarity index 81% rename from colossalai/nn/parallel/data_parallel.py rename to colossalai/legacy/nn/parallel/data_parallel.py index f839d6b2844491a870b38a68459c4bedaf9d760a..9634cb46a12ad530e960daed1e2be40e32ae5c9a 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/legacy/nn/parallel/data_parallel.py @@ -5,7 +5,7 @@ from typing import Iterable, Optional, Set import torch import torch.distributed as dist -from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.legacy.tensor import ProcessGroup as ColoProcessGroup from colossalai.utils import is_ddp_ignored from .reducer import Reducer @@ -34,8 +34,8 @@ class ColoDDP(torch.nn.Module): """Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now. Example: - >>> from colossalai.core import global_context as gpc - >>> from colossalai.context import ParallelMode + >>> from colossalai.legacy.core import global_context as gpc + >>> from colossalai.legacy.context import ParallelMode >>> model = torch.nn.Linear(20, 1) >>> pg = ProcessGroup(tp_degree = world_size//2) >>> model = ColoDDP(model, pg) @@ -49,11 +49,13 @@ class ColoDDP(torch.nn.Module): If it's None, the default data parallel group will be used. Defaults to None. """ - def __init__(self, - module: torch.nn.Module, - process_group: ColoProcessGroup, - bucket_cap_mb: int = 25, - rebuild_bucket: bool = True) -> None: + def __init__( + self, + module: torch.nn.Module, + process_group: ColoProcessGroup, + bucket_cap_mb: int = 25, + rebuild_bucket: bool = True, + ) -> None: assert not isinstance(module, ColoDDP) super().__init__() self.module = module @@ -74,19 +76,18 @@ class ColoDDP(torch.nn.Module): def parameters(self, recurse: bool = True): return self.module.parameters(recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True): + def named_parameters(self, prefix: str = "", recurse: bool = True): return self.module.named_parameters(prefix, recurse) - def named_buffers(self, prefix: str = '', recurse: bool = True): + def named_buffers(self, prefix: str = "", recurse: bool = True): return self.module.named_buffers(prefix, recurse) def named_children(self): return self.module.named_children() - def named_modules(self, - memo: Optional[Set[torch.nn.Module]] = None, - prefix: str = '', - remove_duplicate: bool = True): + def named_modules( + self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): return self.module.named_modules(memo, prefix, remove_duplicate) def forward(self, *args, **kwargs): @@ -114,9 +115,9 @@ class ColoDDP(torch.nn.Module): grad = grad / self.dp_world_size self.comm_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.comm_stream): - self.reducer.all_reduce_async(grad, - group=self.process_group.dp_process_group(), - callback_fn=partial(self._save_grad, p)) + self.reducer.all_reduce_async( + grad, group=self.process_group.dp_process_group(), callback_fn=partial(self._save_grad, p) + ) grad.record_stream(self.comm_stream) else: ColoDDP._save_grad(p, grad) @@ -130,7 +131,7 @@ class ColoDDP(torch.nn.Module): @staticmethod def _save_grad(p, grad): - if hasattr(p, '_saved_grad'): + if hasattr(p, "_saved_grad"): p._saved_grad.add_(grad) else: p._saved_grad = grad @@ -138,7 +139,7 @@ class ColoDDP(torch.nn.Module): def zero_grad(self, set_to_none: bool = False) -> None: self.module.zero_grad(set_to_none=True) for p in self.module.parameters(): - if getattr(p, '_saved_grad', None) is not None: + if getattr(p, "_saved_grad", None) is not None: if set_to_none: p._saved_grad = None else: @@ -167,8 +168,8 @@ class ColoDDP(torch.nn.Module): for p in params_to_ignore: p._ddp_to_ignore = True - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True): return self.module.load_state_dict(state_dict, strict) diff --git a/colossalai/legacy/nn/parallel/layers/__init__.py b/colossalai/legacy/nn/parallel/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2663076c6992555fac8273b597cb15bc6123fbf2 --- /dev/null +++ b/colossalai/legacy/nn/parallel/layers/__init__.py @@ -0,0 +1,33 @@ +from .cache_embedding import ( + CachedEmbeddingBag, + CachedParamMgr, + EvictionStrategy, + LimitBuffIndexCopyer, + ParallelCachedEmbeddingBag, + ParallelCachedEmbeddingBagTablewise, + ParallelCachedEmbeddingBagTablewiseSpiltCache, + TablewiseEmbeddingBagConfig, +) +from .colo_module import ColoModule +from .embedding import ColoEmbedding +from .linear import ColoLinear +from .module_utils import check_colo_module, get_colo_module, init_colo_module, is_colo_module, register_colo_module + +__all__ = [ + "ColoModule", + "register_colo_module", + "is_colo_module", + "get_colo_module", + "init_colo_module", + "check_colo_module", + "ColoLinear", + "ColoEmbedding", + "CachedEmbeddingBag", + "ParallelCachedEmbeddingBag", + "CachedParamMgr", + "LimitBuffIndexCopyer", + "EvictionStrategy", + "ParallelCachedEmbeddingBagTablewise", + "TablewiseEmbeddingBagConfig", + "ParallelCachedEmbeddingBagTablewiseSpiltCache", +] diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aad6dcc5d7d8318257df8dd85b78fc3f64a8e060 --- /dev/null +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py @@ -0,0 +1,18 @@ +from .cache_mgr import CachedParamMgr, EvictionStrategy +from .cached_embedding import CachedEmbeddingBag +from .copyer import LimitBuffIndexCopyer +from .embedding_config import TablewiseEmbeddingBagConfig +from .parallel_cached_embedding import ParallelCachedEmbeddingBag +from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise +from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache + +__all__ = [ + "CachedParamMgr", + "LimitBuffIndexCopyer", + "CachedEmbeddingBag", + "ParallelCachedEmbeddingBag", + "EvictionStrategy", + "ParallelCachedEmbeddingBagTablewise", + "TablewiseEmbeddingBagConfig", + "ParallelCachedEmbeddingBagTablewiseSpiltCache", +] diff --git a/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py similarity index 78% rename from colossalai/nn/parallel/layers/cache_embedding/base_embedding.py rename to colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py index 705835a0ed22ef5c9f6ecc388ddcf8d4e2ea3073..3f825f11fe51febe1280ac3552b6db283e39e0e9 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py @@ -1,19 +1,19 @@ import abc + import torch.nn as nn class BaseEmbeddingBag(abc.ABC, nn.Module): - def __init__( self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, - norm_type=2., + norm_type=2.0, scale_grad_by_freq=False, sparse=False, - mode='mean', + mode="mean", include_last_offset=False, ): super(BaseEmbeddingBag, self).__init__() @@ -21,9 +21,9 @@ class BaseEmbeddingBag(abc.ABC, nn.Module): self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert padding_idx < self.num_embeddings, "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert padding_idx >= -self.num_embeddings, "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.max_norm = max_norm diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py similarity index 83% rename from colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py rename to colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py index da043df368ae17ed1b95cc42b90e58994c84f9aa..e23864071e66d86ad5830b6adaf5912b682cc880 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -1,12 +1,14 @@ +import sys +from contextlib import contextmanager +from enum import Enum +from typing import List, Optional + import numpy as np import torch -from torch.profiler import record_function -from typing import List, Optional from contexttimer import Timer +from torch.profiler import record_function + from .copyer import LimitBuffIndexCopyer -from enum import Enum -import sys -from contextlib import contextmanager class EvictionStrategy(Enum): @@ -20,8 +22,8 @@ def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None: return torch.cuda.current_stream().wait_stream(stream) # As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, - # PyTorch uses the "caching allocator" for memroy allocation for tensors. When a tensor is - # freed, its memory is likely to be reused by newly constructed tenosrs. By default, + # PyTorch uses the "caching allocator" for memory allocation for tensors. When a tensor is + # freed, its memory is likely to be reused by newly constructed tensors. By default, # this allocator traces whether a tensor is still in use by only the CUDA stream where it # was created. When a tensor is used by additional CUDA streams, we need to call record_stream # to tell the allocator about all these streams. Otherwise, the allocator might free the @@ -35,7 +37,7 @@ def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None: class CachedParamMgr(torch.nn.Module): """ Manage Embedding Weights on CPU and CUDA memory uses a software cache. - CPU maintains the entire original weight. + CPU maintains the entire original weight. CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`. During training, GPU needs to transmit embedding rows between CPU and GPU. Args: @@ -81,15 +83,16 @@ class CachedParamMgr(torch.nn.Module): if self._async_copy: self._memcpy_stream = torch.cuda.Stream() - print('use async copy') + print("use async copy") if self._evict_strategy == EvictionStrategy.LFU: # cache_row_idx -> frequency, freq of the cache rows. # classic lfu cache. evict the minimal freq value row in cuda cache. - self.register_buffer("freq_cnter", - torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), - dtype=torch.long).fill_(sys.maxsize), - persistent=False) + self.register_buffer( + "freq_cnter", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), dtype=torch.long).fill_(sys.maxsize), + persistent=False, + ) self._elapsed_dict = {} self._show_cache_miss = True self._reset_comm_stats() @@ -115,7 +118,7 @@ class CachedParamMgr(torch.nn.Module): self._elapsed_dict[name] += t.elapsed def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor: - """_find_evict_gpu_idxs + """_find_evict_gpu_idxs Find the gpu idxs to be evicted, according to their freq. Args: evict_num (int): how many rows has to be evicted @@ -140,10 +143,10 @@ class CachedParamMgr(torch.nn.Module): if self.cuda_row_num > 0: # Enable cache with introducing auxiliary data structures self.cuda_cached_weight = torch.nn.Parameter( - torch.zeros(self.cuda_row_num, - self.embedding_dim, - device=torch.cuda.current_device(), - dtype=weight.dtype)) + torch.zeros( + self.cuda_row_num, self.embedding_dim, device=torch.cuda.current_device(), dtype=weight.dtype + ) + ) # pin memory cpu for higher CPU-GPU copy bandwidth self.weight = weight.pin_memory() if self.pin_weight else weight @@ -156,17 +159,19 @@ class CachedParamMgr(torch.nn.Module): ) # cached_idx_map: gpu_row_idx -> cpu_row_idx - self.register_buffer("cached_idx_map", - torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), - dtype=torch.long).fill_(-1), - persistent=False) + self.register_buffer( + "cached_idx_map", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), dtype=torch.long).fill_(-1), + persistent=False, + ) # cpu_row_id -> gpu_row_idx. # gpu_row_idx as -1 means cpu_row_id not in CUDA. - self.register_buffer("inverted_cached_idx", - torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), - dtype=torch.long).fill_(-1), - persistent=False) + self.register_buffer( + "inverted_cached_idx", + torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), dtype=torch.long).fill_(-1), + persistent=False, + ) self.evict_backlist = torch.tensor([], device=torch.cuda.current_device()) @@ -189,9 +194,11 @@ class CachedParamMgr(torch.nn.Module): torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D. """ - return self.weight.data.view(-1).narrow(0, - int(row_idx) * self.embedding_dim, - self.embedding_dim).view(1, self.embedding_dim) + return ( + self.weight.data.view(-1) + .narrow(0, int(row_idx) * self.embedding_dim, self.embedding_dim) + .view(1, self.embedding_dim) + ) @property def cuda_available_row_num(self): @@ -202,7 +209,7 @@ class CachedParamMgr(torch.nn.Module): """reorder reorder the weight according to ids' frequency in dataset before training. Execute only once before training, also known as warmup phase. - + Note: If you would like to use the DATASET as the eviction strategy, you must call this function. Note: @@ -236,15 +243,18 @@ class CachedParamMgr(torch.nn.Module): preload_cpu_ids = torch.arange(preload_row_num) preload_cuda_row_idxs = preload_cpu_ids.cuda() if self.buffer_size > 0: - self.limit_buff_index_copyer.index_copy(0, - src_index=preload_cpu_ids, - tgt_index=preload_cuda_row_idxs, - src=self.weight.view(self.num_embeddings, -1), - tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + self.limit_buff_index_copyer.index_copy( + 0, + src_index=preload_cpu_ids, + tgt_index=preload_cuda_row_idxs, + src=self.weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1), + ) else: preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda() - self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs, - preload_rows) + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_( + 0, preload_cuda_row_idxs, preload_rows + ) # update auxiliary info self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda() @@ -258,7 +268,7 @@ class CachedParamMgr(torch.nn.Module): else: self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda() - print(f'Cache warmup finished cost {timer.elapsed} sec.') + print(f"Cache warmup finished cost {timer.elapsed} sec.") def flush(self): """flush all CUDA rows to CPU. @@ -288,18 +298,18 @@ class CachedParamMgr(torch.nn.Module): print( f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem" ) - print(f'cuda_to_cpu_elapse {elapsed} sec') + print(f"cuda_to_cpu_elapse {elapsed} sec") if self._cpu_to_cuda_numel > 0 and "5_evict_in" in self._elapsed_dict: elapsed = self._elapsed_dict["5_evict_in"] print( f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem" ) - print(f'cpu_to_cuda_elpase {elapsed} sec') + print(f"cpu_to_cuda_elapse {elapsed} sec") for k, v in self._elapsed_dict.items(): - print(f'{k}: {v}') + print(f"{k}: {v}") - print(f'cache miss ratio {self._cache_miss / self._total_cache}') + print(f"cache miss ratio {self._cache_miss / self._total_cache}") @torch.no_grad() def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor: @@ -334,10 +344,11 @@ class CachedParamMgr(torch.nn.Module): else: cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True) - assert len(cpu_row_idxs) <= self.cuda_row_num, \ - f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \ - f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \ + assert len(cpu_row_idxs) <= self.cuda_row_num, ( + f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " + f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " f"Please increase cuda_row_num or decrease the training batch size." + ) self.evict_backlist = cpu_row_idxs tmp = torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True) comm_cpu_row_idxs = cpu_row_idxs[tmp] @@ -384,8 +395,9 @@ class CachedParamMgr(torch.nn.Module): # move evict in rows to gpu if self._async_copy: if self.buffer_size == 0: - evict_in_rows_gpu = self.weight.view(self.num_embeddings, - -1).index_select(0, cpu_row_idxs_copy).pin_memory() + evict_in_rows_gpu = ( + self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() + ) with torch.cuda.stream(self._memcpy_stream): evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True) else: @@ -407,9 +419,10 @@ class CachedParamMgr(torch.nn.Module): # move evict out rows to cpu if self._async_copy: - evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, - -1).index_select(0, evict_gpu_row_idxs) - evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) + evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select( + 0, evict_gpu_row_idxs + ) + evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device="cpu", pin_memory=True) with torch.cuda.stream(None): evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) @@ -423,9 +436,10 @@ class CachedParamMgr(torch.nn.Module): evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) if self._async_copy: - evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, - -1).index_select(0, evict_gpu_row_idxs) - evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) + evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select( + 0, evict_gpu_row_idxs + ) + evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device="cpu", pin_memory=True) with torch.cuda.stream(None): evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) @@ -436,11 +450,13 @@ class CachedParamMgr(torch.nn.Module): with self.timer("3_evict_out") as timer: if self.buffer_size > 0: - self.limit_buff_index_copyer.index_copy(0, - src_index=evict_gpu_row_idxs, - tgt_index=evict_info.cpu(), - src=self.cuda_cached_weight.view(self.cuda_row_num, -1), - tgt=self.weight.view(self.num_embeddings, -1)) + self.limit_buff_index_copyer.index_copy( + 0, + src_index=evict_gpu_row_idxs, + tgt_index=evict_info.cpu(), + src=self.cuda_cached_weight.view(self.cuda_row_num, -1), + tgt=self.weight.view(self.num_embeddings, -1), + ) else: # allocate tmp memory on CPU and copy rows on CUDA to CPU. # TODO async gpu -> cpu @@ -448,8 +464,9 @@ class CachedParamMgr(torch.nn.Module): _wait_for_data(evict_out_rows_cpu, None) else: with self.timer("3_1_evict_out_index_select") as timer: - evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, - -1).index_select(0, evict_gpu_row_idxs) + evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select( + 0, evict_gpu_row_idxs + ) with self.timer("3_2_evict_out_gpu_to_cpu_copy") as timer: evict_out_rows_cpu = evict_out_rows_cpu.cpu() @@ -467,17 +484,19 @@ class CachedParamMgr(torch.nn.Module): # slots of cuda weight to evict in with self.timer("4_identify_cuda_slot") as timer: - slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()] + slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[: cpu_row_idxs.numel()] # TODO wait for optimize with self.timer("5_evict_in") as timer: # Here also allocate extra memory on CUDA. #cpu_row_idxs if self.buffer_size > 0: - self.limit_buff_index_copyer.index_copy(0, - src_index=cpu_row_idxs_copy, - tgt_index=slots, - src=self.weight.view(self.num_embeddings, -1), - tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + self.limit_buff_index_copyer.index_copy( + 0, + src_index=cpu_row_idxs_copy, + tgt_index=slots, + src=self.weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1), + ) else: if self._async_copy: _wait_for_data(evict_in_rows_gpu, self._memcpy_stream) @@ -486,8 +505,9 @@ class CachedParamMgr(torch.nn.Module): # narrow index select to a subset of self.weight # tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1) # evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu()) - evict_in_rows_gpu = self.weight.view(self.num_embeddings, - -1).index_select(0, cpu_row_idxs_copy).pin_memory() + evict_in_rows_gpu = ( + self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() + ) with self.timer("5_2_evict_in_gpu_to_cpu_copy") as timer: evict_in_rows_gpu = evict_in_rows_gpu.cuda() @@ -516,7 +536,7 @@ class CachedParamMgr(torch.nn.Module): """ deprecated evict one row from cuda to cpu. - Returns: + Returns: (int) : the slot id be evicted. """ mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1) @@ -535,8 +555,9 @@ class CachedParamMgr(torch.nn.Module): self.cached_idx_map.index_copy_(0, idx, buf) with Timer() as timer: - cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim, - self.embedding_dim).view(1, self.embedding_dim) + cuda_tensor = torch.narrow( + self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim, self.embedding_dim + ).view(1, self.embedding_dim) self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor) # update inverted_cached_idx, min_slot_id is evicted from cuda @@ -568,8 +589,9 @@ class CachedParamMgr(torch.nn.Module): slot_offset = slot_id # copy payload from cpu to cuda with Timer() as timer: - cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim, - self.embedding_dim).view(1, self.embedding_dim) + cuda_tensor = torch.narrow( + self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim, self.embedding_dim + ).view(1, self.embedding_dim) cuda_tensor.data.copy_(self.cpu_weight_data(row_id)) # update the inverted_cached_idx diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..03667857b1ac2335c9ba84b46747564dcf4547d2 --- /dev/null +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py @@ -0,0 +1,186 @@ +from typing import Iterator, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from .base_embedding import BaseEmbeddingBag +from .cache_mgr import CachedParamMgr, EvictionStrategy + + +class CachedEmbeddingBag(BaseEmbeddingBag): + """CachedEmbeddingBag + + Cached Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space. + It can leverage the id's frequency statistics of the target dataset, by passing a frequency list to param `ids_freq_mapping`. + You can also apply a naive LFU cache eviction strategy by setting `evict_strategy` as EvictionStrategy.LFU. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”. For a newly constructed EmbeddingBag, the embedding vector at padding_idx will default to all zeros, but can be updated to another value to be used as the padding vector. Note that the embedding vector at padding_idx is excluded from the reduction. + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm + norm_type (str, optional): The p of the p-norm to compute for the max_norm option. Defaults to 2. + scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. Note: this option is not supported when mode="max". Defaults to False. + sparse (bool, optional): if True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. Note: this option is not supported when mode="max".. Defaults to False. + _weight (torch.Tensor, optional): an embedding weight tensor. Concatenate multiple tables in a embedding bag as a single one. Defaults to None. + mode (str, optional): "sum", "mean" or "max". Specifies the way to reduce the bag. "sum" computes the weighted sum, taking per_sample_weights into consideration. "mean" computes the average of the values in the bag, "max" computes the max value over each bag. Default: "mean". Defaults to 'mean'. + include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format.. Defaults to False. + dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32. + device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu. + cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row + ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occurs in dataset. Defaults to None. + warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7. + buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0. + pin_weight (bool, optional): pin the cpu weight. Defaults to False. + evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + max_norm: float = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[torch.Tensor] = None, + mode: str = "mean", + include_last_offset: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + cache_ratio: float = 0.01, + ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None, + warmup_ratio: float = 0.7, + buffer_size: int = 0, + pin_weight: bool = False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU, + ): + super(CachedEmbeddingBag, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + mode, + include_last_offset, + ) + + assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0" + self.evict_strategy = evict_strategy + if _weight is None: + _weight = self._weight_alloc(dtype, device) + cuda_row_num = int(num_embeddings * cache_ratio) + # configure weight & cache + self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight) + self.cache_op = True + + def set_cache_mgr_async_copy(self, flag): + self.cache_weight_mgr._async_copy = flag + + def _weight_alloc(self, dtype, device): + weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device) + with torch.no_grad(): + weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) + if self.padding_idx is not None: + weight[self.padding_idx].fill_(0) + return weight + + def _preprocess( + self, + weight, + cuda_row_num: int, + ids_freq_mapping: Optional[List[int]] = None, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + ): + """ + Called after initialized. + Reorder the weight rows according to the ids_freq_mapping. + Then, let the weights of the Module be managed by a CachedParamMgr. + + Args: + cuda_row_num (int): number of rows can be hosted in CUDA memory + ids_freq_mapping (List[int]): a list, idx is id number, value is freq + warmup_ratio (float): the amount of rows preloaded in cuda cache + """ + self.cache_weight_mgr = CachedParamMgr( + weight, cuda_row_num, buffer_size, pin_weight, evict_strategy=self.evict_strategy + ) + self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) + + def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None): + if self.cache_op: + with torch.no_grad(): + input = self.cache_weight_mgr.prepare_ids(input) + + embeddings = F.embedding_bag( + input.cuda(), + self.cache_weight_mgr.cuda_cached_weight, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + if shape_hook is not None: + embeddings = shape_hook(embeddings) + return embeddings + + @property + def weight(self): + return self.cache_weight_mgr.weight + + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + yield "weight", self.cache_weight_mgr.cuda_cached_weight + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + yield self.cache_weight_mgr.cuda_cached_weight + + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + + ############################# Perf Log ################################### + + @property + def num_hits_history(self): + return self.cache_weight_mgr.num_hits_history + + @property + def num_miss_history(self): + return self.cache_weight_mgr.num_miss_history + + @property + def num_write_back_history(self): + return self.cache_weight_mgr.num_write_back_history + + @property + def swap_in_bandwidth(self): + if self.cache_weight_mgr._cpu_to_cuda_numel > 0: + return ( + self.cache_weight_mgr._cpu_to_cuda_numel + * self.cache_weight_mgr.elem_size_in_byte + / 1e6 + / self.cache_weight_mgr._cpu_to_cuda_elapse + ) + else: + return 0 + + @property + def swap_out_bandwidth(self): + if self.cache_weight_mgr._cuda_to_cpu_numel > 0: + return ( + self.cache_weight_mgr._cuda_to_cpu_numel + * self.cache_weight_mgr.elem_size_in_byte + / 1e6 + / self.cache_weight_mgr._cuda_to_cpu_elapse + ) + return 0 diff --git a/colossalai/nn/parallel/layers/cache_embedding/copyer.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py similarity index 89% rename from colossalai/nn/parallel/layers/cache_embedding/copyer.py rename to colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py index b586be1dc6d98ed7df2cffd1c326ca25ab837f33..5e3a8df05cfe05b8e78f97cd98708f73c32b5837 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/copyer.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py @@ -3,7 +3,7 @@ from torch import LongTensor class LimitBuffIndexCopyer(object): - """LimitBuffIndexCopyer + """LimitBuffIndexCopyer Index Copy using limited temp buffer on CUDA. Args: @@ -15,9 +15,9 @@ class LimitBuffIndexCopyer(object): @torch.no_grad() def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor): - """copy + """copy src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index] - The valid rows in the src tensor are continous, while rows in tgt tensor is scattered. + The valid rows in the src tensor are continuous, while rows in tgt tensor is scattered. Args: dim (int): dimension along which to index @@ -39,7 +39,7 @@ class LimitBuffIndexCopyer(object): for begin_pos in range(0, dim_size, self._buff_size): cur_len = min(self._buff_size, dim_size - begin_pos) src_idx_piece = src_index.narrow(0, begin_pos, cur_len) - if src_device.type == 'cpu' and tgt_device.type == 'cuda': + if src_device.type == "cpu" and tgt_device.type == "cuda": cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory() tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device) tmp_buffer.copy_(cpu_tmp_buffer) diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ceaa9081c724014fddaa21a0f43c65cf25f9abc6 --- /dev/null +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py @@ -0,0 +1,29 @@ +import torch + + +class TablewiseEmbeddingBagConfig: + """ + example: + def prepare_tablewise_config(args, cache_ratio, ...): + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] + ... + return embedding_bag_config_list + """ + + def __init__( + self, + num_embeddings: int, + cuda_row_num: int, + assigned_rank: int = 0, + buffer_size=50_000, + ids_freq_mapping=None, + initial_weight: torch.tensor = None, + name: str = "", + ): + self.num_embeddings = num_embeddings + self.cuda_row_num = cuda_row_num + self.assigned_rank = assigned_rank + self.buffer_size = buffer_size + self.ids_freq_mapping = ids_freq_mapping + self.initial_weight = initial_weight + self.name = name diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..ee739935fef2f3db69ba21a891e8ea6cfc1fe65d --- /dev/null +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py @@ -0,0 +1,174 @@ +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from colossalai.legacy.nn._ops._utils import dual_all_to_all +from colossalai.legacy.tensor import ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec +from colossalai.tensor import ColoTensor + +from .cache_mgr import EvictionStrategy +from .cached_embedding import CachedEmbeddingBag + + +def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: + if world_size == 1: + return 0, embedding_dim, True + + assert embedding_dim >= world_size, ( + f"Embedding dimension {embedding_dim} must be larger than the world size " f"{world_size} of the process group" + ) + chunk_size = embedding_dim // world_size + threshold = embedding_dim % world_size + # if embedding dim is divisible by world size + if threshold == 0: + return rank * chunk_size, (rank + 1) * chunk_size, True + + # align with the split strategy of torch.tensor_split + size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)] + offset = sum(size_list[:rank]) + return offset, offset + size_list[rank], False + + +class ParallelCachedEmbeddingBag(CachedEmbeddingBag): + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode="mean", + include_last_offset=False, + dtype=None, + device=None, + cache_ratio=0.01, + ids_freq_mapping=None, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, + ): + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + self.partition_start_index, self.partition_end_index, divisible = get_partition( + embedding_dim, self.rank, self.world_size + ) + self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index + + super(ParallelCachedEmbeddingBag, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + mode, + include_last_offset, + dtype, + device, + cache_ratio, + ids_freq_mapping, + warmup_ratio, + buffer_size, + pin_weight, + evict_strategy, + ) + self.cache_op = True + + def _weight_alloc(self, dtype, device): + weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype) + with torch.no_grad(): + weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) + if self.padding_idx is not None: + weight[self.padding_idx].fill_(0) + colo_tensor_spec = ColoTensorSpec( + pg=ProcessGroup(tp_degree=self.world_size), + dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]), + compute_attr=ComputePattern.TP1D, + ) + return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec) + + def forward( + self, + indices, + offsets=None, + per_sample_weights=None, + shape_hook=None, + scatter_dim=0, + gather_dim=-1, + ): + if self.cache_op: + with torch.no_grad(): + indices = self.cache_weight_mgr.prepare_ids(indices) + output_shard = F.embedding_bag( + indices.cuda(), + self.cache_weight_mgr.cuda_cached_weight, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + if shape_hook is not None: + output_shard = shape_hook(output_shard) + output_full = dual_all_to_all( + output_shard, self.weight.get_process_group(), scatter_dim=scatter_dim, gather_dim=gather_dim + ) + return output_full + + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + + @classmethod + def from_pretrained( + cls, + embedding: torch.Tensor, + freeze: bool = True, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + mode: str = "mean", + include_last_offset: bool = False, + cuda_row_num: int = 100_000, + ids_freq_mapping: Optional[List[int]] = None, + warmup_ratio: float = 0.7, + buffer_size: int = 0, + ) -> "ParallelCachedEmbeddingBag": + rows, cols = embedding.shape + embedding_bag = cls( + rows, + cols, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + embedding, + mode, + include_last_offset, + cuda_row_num=cuda_row_num, + ids_freq_mapping=ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=buffer_size, + ) + embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze + return embedding_bag + + def print_comm_stats_(self): + self.cache_weight_mgr.print_comm_stats() + + def element_size(self): + return self.weight.element_size() diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py new file mode 100644 index 0000000000000000000000000000000000000000..7d21f5b68ce617f64bebd92ada96255ffa540326 --- /dev/null +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py @@ -0,0 +1,229 @@ +from typing import List + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise +from colossalai.legacy.tensor import ProcessGroup + +from .cache_mgr import EvictionStrategy +from .cached_embedding import CachedEmbeddingBag +from .embedding_config import TablewiseEmbeddingBagConfig + + +class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag): + """ + all tables assigned to this class instance are managed by a single CachedEmbeddingBag. + Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight. + """ + + def __init__( + self, + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], + embedding_dim: int, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode="mean", + include_last_offset=False, + dtype=None, + device=None, + cache_ratio=0.01, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU, + ): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] + self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list] + self.global_tables_num = len(embedding_bag_config_list) + self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda() + self.assigned_table_list: List[int] = [] + self.pg = ProcessGroup(tp_degree=self.world_size) + self.num_embeddings = 0 + for i, rank in enumerate(self.rank_of_tables): + if rank == self.rank: + self.assigned_table_list.append(i) + self.num_embeddings += self.global_table_num_embeddings_list[i] + self.include_last_offset = include_last_offset + + ids_freq_mapping = [] + for config in embedding_bag_config_list: + if config.assigned_rank == self.rank: + if config.ids_freq_mapping != None: + ids_freq_mapping.extend(config.ids_freq_mapping) + else: + ids_freq_mapping = None + break + self.cache_ratio = cache_ratio + # table-associate cache + int(cache_ratio * self.num_embeddings) + super(ParallelCachedEmbeddingBagTablewise, self).__init__( + self.num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + mode, + include_last_offset, + dtype, + device, + cache_ratio, + ids_freq_mapping, + warmup_ratio, + buffer_size, + pin_weight, + evict_strategy, + ) + + # for assigned tables reconnection: + self.idx_offset_list = [] + offset_cumsum = 0 + for table_i, table_num_embeddings in enumerate(self.global_table_num_embeddings_list): + if self.rank_of_tables[table_i] == self.rank: + self.idx_offset_list.append(offset_cumsum) + else: + offset_cumsum += table_num_embeddings + + # prepare list shape for all_to_all output + self.embedding_dim_per_rank = [0 for i in range(self.world_size)] + for rank in self.rank_of_tables: + self.embedding_dim_per_rank[rank] += embedding_dim + + self.cache_op = True + + def forward( + self, + indices: torch.Tensor, + offsets: torch.Tensor = None, + per_sample_weights=None, + shape_hook=None, + already_split_along_rank=True, + ): + if not already_split_along_rank: + # not recommanded. it takes time. + batch_size = (offsets.shape[0]) // self.global_tables_num + local_indices, local_offsets, local_per_sample_weights = self.split_along_rank( + batch_size, indices, offsets, per_sample_weights + ) + else: + # recommanded. + batch_size = (offsets.shape[0]) // len(self.assigned_table_list) + local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights + if self.cache_op: + with torch.no_grad(): + indices = self.cache_weight_mgr.prepare_ids(local_indices) + local_output = F.embedding_bag( + indices.cuda(), + self.cache_weight_mgr.cuda_cached_weight, + local_offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + local_per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + local_output = torch.cat(local_output.split(batch_size), 1) + remains = batch_size % self.world_size + scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)] + output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank) + if shape_hook is not None: + output_full = shape_hook(output_full) + return output_full + + def split_along_rank( + self, batch_size, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None + ): + """ + if input indices and offsets haven't been splitted along assigned rank, this function will do it. + it takes time. please consider splitting data during batch loading. + """ + local_indices_list: List(torch.Tensor) = [] + local_offsets_list: List(torch.Tensor) = [] + if per_sample_weights != None: + local_per_sample_weights_list: List(torch.Tensor) = [] + + offset_pre_end = 0 # local_offsets trick + for i, handle_table in enumerate(self.assigned_table_list): + indices_start_position = offsets[batch_size * handle_table] + if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): + # till-the-end special case + indices_end_position = indices.shape[0] + else: + indices_end_position = offsets[batch_size * (handle_table + 1)] + # alternative approach: reduce malloc + """ + # 1. local_indices_list: + local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position) + torch.sub(local_indices, self.idx_offset_list[i], out=local_indices) + local_indices_list.append(local_indices) + # 2. local_offsets_list: + if i + 1 == len(self.assigned_table_list): + # till-the-end special case + if not self.include_last_offset: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size) + else: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1) + torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets) + local_offsets_list.append(local_offsets) + else: + temp_holder = offsets[batch_size * handle_table].item() + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size) + torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets) + offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder + local_offsets_list.append(local_offsets) + """ + # 1. local_indices_list: + local_indices_list.append( + indices.narrow(0, indices_start_position, indices_end_position - indices_start_position).sub( + self.idx_offset_list[i] + ) + ) + # 2. local_offsets_list: + if i + 1 == len(self.assigned_table_list): + # till-the-end special case + if not self.include_last_offset: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size).add( + offset_pre_end - offsets[batch_size * (handle_table)] + ) + else: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).add( + offset_pre_end - offsets[batch_size * (handle_table)] + ) + local_offsets_list.append(local_offsets) + else: + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).add( + offset_pre_end - offsets[batch_size * (handle_table)] + ) + offset_pre_end = local_offsets[-1] + local_offsets_list.append(local_offsets[:-1]) + # 3. local_per_sample_weights_list: + if per_sample_weights != None: + local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position]) + local_indices = torch.cat(local_indices_list, 0) + local_offsets = torch.cat(local_offsets_list, 0) + local_per_sample_weights = None + if per_sample_weights != None: + local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0) + return local_indices, local_offsets, local_per_sample_weights + + def set_cache_op(self, cache_op: bool = True): + self.cache_op = cache_op + + def print_comm_stats_(self): + self.cache_weight_mgr.print_comm_stats() + + def element_size(self): + return self.weight.element_size() diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..94a27a8673da475c4ad9c9fea2ce03fb8dbf3a59 --- /dev/null +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py @@ -0,0 +1,147 @@ +import abc +from typing import List + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.profiler import record_function + +from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise +from colossalai.legacy.tensor import ProcessGroup + +from .cache_mgr import EvictionStrategy +from .cached_embedding import CachedEmbeddingBag +from .embedding_config import TablewiseEmbeddingBagConfig + + +class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): + """ + every table assigned to this class instance is managed by a CachedEmbeddingBag. + """ + + def __init__( + self, + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], + embedding_dim: int, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + mode="mean", + include_last_offset=False, + dtype=None, + device=None, + warmup_ratio=0.7, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU, + ): + super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__() + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] + self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list] + self.global_tables_num = len(embedding_bag_config_list) + self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda() + + self.assigned_table_list: List[int] = [] + for i, rank in enumerate(self.rank_of_tables): + if rank == self.rank: + self.assigned_table_list.append(i) + self.include_last_offset = include_last_offset + self.pg = ProcessGroup(tp_degree=self.world_size) + + # prepare CachedEmbeddingBag list + + self.cached_embedding_bag_list: nn.ModuleList = nn.ModuleList() + for config in embedding_bag_config_list: + if config.assigned_rank != self.rank: + continue + self.cached_embedding_bag_list.append( + CachedEmbeddingBag( + num_embeddings=config.num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=config.initial_weight, + mode=mode, + include_last_offset=include_last_offset, + dtype=dtype, + device=device, + cuda_row_num=config.cuda_row_num, + ids_freq_mapping=config.ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=config.buffer_size, + pin_weight=pin_weight, + evict_strategy=evict_strategy, + ) + ) + + # prepare list shape for all_to_all output + self.embedding_dim_per_rank = [0 for i in range(self.world_size)] + for rank in self.rank_of_tables: + self.embedding_dim_per_rank[rank] += embedding_dim + + def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None): + # determine indices to handle + batch_size = (offsets.shape[0]) // self.global_tables_num + local_output_list = [] + for i, handle_table in enumerate(self.assigned_table_list): + with record_function("(tablewise) prepare indices and offsets"): + with record_function("part 1"): + indices_start_position = offsets[batch_size * handle_table] + if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): + # till the end special case + indices_end_position = indices.shape[0] + else: + indices_end_position = offsets[batch_size * (handle_table + 1)] + with record_function("part 2"): + # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table] + local_indices = indices.narrow( + 0, indices_start_position, indices_end_position - indices_start_position + ).sub(self.global_tables_offsets[handle_table]) + if self.include_last_offset: + # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)] + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).sub( + offsets[batch_size * (handle_table)] + ) + else: + # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)] + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size).sub( + offsets[batch_size * (handle_table)] + ) + local_per_sample_weights = None + if per_sample_weights != None: + local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position] + with record_function("(tablewise) tablewise forward"): + local_output_list.append( + self.cached_embedding_bag_list[i](local_indices, local_offsets, local_per_sample_weights) + ) + + # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim)) + local_output = torch.cat(local_output_list, 1) + # then concatenate those local_output on the second dimension. + # use all_to_all + remains = batch_size % self.world_size + scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)] + output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank) + if shape_hook is not None: + output_full = shape_hook(output_full) + return output_full + + def element_size(self): + if len(self.assigned_table_list) == 0: + return 0 + return self.cached_embedding_bag_list[0].cache_weight_mgr.weight.element_size() + + def print_comm_stats_(self): + cuda_to_cpu_elem_num = 0 + cpu_to_cuda_elem_num = 0 + for cached_embedding_bag in self.cached_embedding_bag_list: + cuda_to_cpu_elem_num += cached_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel + cpu_to_cuda_elem_num += cached_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel + print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem") + print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem") diff --git a/colossalai/legacy/nn/parallel/layers/colo_module.py b/colossalai/legacy/nn/parallel/layers/colo_module.py new file mode 100644 index 0000000000000000000000000000000000000000..df0b324eeeb8c52cf5e5f00b658b29e6f42f17ca --- /dev/null +++ b/colossalai/legacy/nn/parallel/layers/colo_module.py @@ -0,0 +1,46 @@ +from typing import Dict, List + +from colossalai.legacy.tensor import ComputePattern +from colossalai.legacy.tensor.distspec import _DistSpec + + +class ColoModule(object): + def __init__(self): + self._shard_params: List[str] = [] + self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {} + + def _register_shard_params(self, params: List[str]): + self._shard_params = params + + def _register_allowed_patterns( + self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], mode="default" + ): + assert ( + list(dist_specs.keys()).sort() == self._shard_params.sort() + ), "Every registered param should have dist_spec." + if not compute_pattern in self._allowed_patterns: + self._allowed_patterns[compute_pattern] = {} + self._allowed_patterns[compute_pattern][mode] = dist_specs + + def _set_default(self, compute_pattern: ComputePattern, target_mode): + self._allowed_patterns[compute_pattern]["default"] = self._allowed_patterns[compute_pattern][target_mode] + + def has_compute_pattern(self, compute_pattern: ComputePattern): + return compute_pattern in self._allowed_patterns + + def get_dist_specs(self, compute_pattern: ComputePattern): + assert self.has_compute_pattern(compute_pattern) + return self._allowed_patterns[compute_pattern] + + def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode="default"): + return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern] + + def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode="default"): + assert self.has_compute_pattern_with_mode(compute_pattern, mode) + return self._allowed_patterns[compute_pattern][mode] + + def get_param_names(self): + return self._shard_params + + def register(self, compute_pattern, pg): + raise NotImplementedError diff --git a/colossalai/legacy/nn/parallel/layers/embedding.py b/colossalai/legacy/nn/parallel/layers/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..f204f3fb71f04e78f887ff774c356d56082b845b --- /dev/null +++ b/colossalai/legacy/nn/parallel/layers/embedding.py @@ -0,0 +1,36 @@ +from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec + +from .colo_module import ColoModule + + +class ColoEmbedding(ColoModule): + def __init__(self): + super(ColoEmbedding, self).__init__() + self._register_shard_params(["weight"]) + + def register(self, compute_pattern, pg: ProcessGroup): + if not compute_pattern in self._allowed_patterns: + if ComputePattern.TP1D == compute_pattern: + self._set_TP1D(pg) + + def _set_TP1D(self, pg: ProcessGroup): + # TP1D Row Linear + _compute_pattern = ComputePattern.TP1D + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + "weight": ShardSpec([0], [pg.tp_world_size()]), + }, + mode="row", + ) + + # TP1D Col Linear + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={ + "weight": ShardSpec([-1], [pg.tp_world_size()]), + }, + mode="col", + ) + + self._set_default(compute_pattern=_compute_pattern, target_mode="row") diff --git a/colossalai/legacy/nn/parallel/layers/linear.py b/colossalai/legacy/nn/parallel/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b6df1ec9da6b102bac410889af439aaaa4f2c0 --- /dev/null +++ b/colossalai/legacy/nn/parallel/layers/linear.py @@ -0,0 +1,32 @@ +from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec + +from .colo_module import ColoModule + + +class ColoLinear(ColoModule): + def __init__(self): + super(ColoLinear, self).__init__() + self._register_shard_params(["weight", "bias"]) + + def register(self, compute_pattern, pg: ProcessGroup): + if not compute_pattern in self._allowed_patterns: + if ComputePattern.TP1D == compute_pattern: + self._set_TP1D(pg) + + def _set_TP1D(self, pg): + # TP1D Row Linear + _compute_pattern = ComputePattern.TP1D + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={"weight": ShardSpec([-1], [pg.tp_world_size()]), "bias": None}, + mode="row", + ) + + # TP1D Col Linear + self._register_allowed_patterns( + compute_pattern=_compute_pattern, + dist_specs={"weight": ShardSpec([0], [pg.tp_world_size()]), "bias": ShardSpec([0], [pg.tp_world_size()])}, + mode="col", + ) + + self._set_default(compute_pattern=_compute_pattern, target_mode="row") diff --git a/colossalai/nn/parallel/layers/module_utils.py b/colossalai/legacy/nn/parallel/layers/module_utils.py similarity index 84% rename from colossalai/nn/parallel/layers/module_utils.py rename to colossalai/legacy/nn/parallel/layers/module_utils.py index 38d128cc705e6bfa3db68cb6c59b66450ce1223e..4dbce7e09f379c1450c07ded7130686d2a94823f 100644 --- a/colossalai/nn/parallel/layers/module_utils.py +++ b/colossalai/legacy/nn/parallel/layers/module_utils.py @@ -1,9 +1,12 @@ from typing import Dict -from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup -from colossalai.tensor import distspec -from . import ColoModule + import torch +from colossalai.legacy.tensor import ComputeSpec, ProcessGroup +from colossalai.tensor import ColoParameter + +from . import ColoModule + _COLOSSAL_MODULES: Dict[type, ColoModule] = {} @@ -38,7 +41,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) for param_name in param_names: param = module.get_parameter(param_name) if not isinstance(param, ColoParameter): - raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.') + raise Exception(f"Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.") if param.has_compute_spec(): cur_compute_pattern = param.compute_spec.compute_pattern if compute_pattern is None: @@ -46,7 +49,8 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) else: if cur_compute_pattern != compute_pattern: raise Exception( - f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.') + f"Invalid ColoParameter spec: Params in {module} have different compute_pattern." + ) else: continue @@ -54,7 +58,8 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) colo_module.register(compute_pattern, pg) if not colo_module.has_compute_pattern(compute_pattern): raise Exception( - f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.') + f"Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed." + ) match_specs = False allowed_specs = colo_module.get_dist_specs(compute_pattern) @@ -74,17 +79,15 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) match_specs = True break if match_specs == False: - raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.') + raise Exception(f"Invalid ColoParameter spec: Params in {module} are incorrectly sharded.") if recursive == True: for submodule in module.children(): check_colo_module(submodule, pg=pg, recursive=True) -def init_colo_module(module: torch.nn.Module, - compute_spec: ComputeSpec, - pg: ProcessGroup, - recursive=True, - mode='default'): +def init_colo_module( + module: torch.nn.Module, compute_spec: ComputeSpec, pg: ProcessGroup, recursive=True, mode="default" +): compute_pattern = compute_spec.compute_pattern if is_colo_module(module): # for each param diff --git a/colossalai/nn/parallel/reducer.py b/colossalai/legacy/nn/parallel/reducer.py similarity index 92% rename from colossalai/nn/parallel/reducer.py rename to colossalai/legacy/nn/parallel/reducer.py index 5687055819fe1fd0177e507f1a27d3bab8b5b1b5..7b3d283e47dd4b6ded2aeb747e398b03b9f429d7 100644 --- a/colossalai/nn/parallel/reducer.py +++ b/colossalai/legacy/nn/parallel/reducer.py @@ -13,7 +13,6 @@ from torch.distributed import ProcessGroup class Bucket: - def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup): self.buffer = torch.zeros(size, dtype=dtype, device=device) self.group = group @@ -26,7 +25,7 @@ class Bucket: assert len(self.callbacks) == 0 return # reduce-scatter bucket - dist.all_reduce(self.buffer[:self.offset], group=self.group) + dist.all_reduce(self.buffer[: self.offset], group=self.group) # execute post-reduction callbacks for callback_fn in self.callbacks: @@ -37,24 +36,22 @@ class Bucket: self.buffer = torch.zeros_like(self.buffer) def alloc(self) -> None: - if self.buffer.storage().size() == 0: self.buffer.storage().resize_(self.buffer.numel()) def free(self) -> None: - assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown" self.buffer.storage().resize_(0) def append(self, tensor: Tensor, callback_fn: Callable): tensor_size = tensor.numel() offset = self.offset - self.buffer[offset:offset + tensor_size].copy_(tensor.flatten()) + self.buffer[offset : offset + tensor_size].copy_(tensor.flatten()) self.offset += tensor_size # callback will be given the reduced result if callback_fn is not None: - result_view = self.buffer[offset:offset + tensor_size].view(tensor.shape) + result_view = self.buffer[offset : offset + tensor_size].view(tensor.shape) self.callbacks.append(functools.partial(callback_fn, result_view)) @property @@ -63,7 +60,6 @@ class Bucket: class Reducer: - def __init__(self, bucket_size_mb: int = 25): self.bucket_size_mb = bucket_size_mb self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {} @@ -101,7 +97,7 @@ class Reducer: @functools.lru_cache() def _get_bucket_size(self, element_size: int) -> int: - if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. + if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. return 0 MB = 1024 * 1024 bucket_size = self.bucket_size_mb * MB / element_size diff --git a/colossalai/legacy/pipeline/__init__.py b/colossalai/legacy/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1a5ec7fd1f696e51139759bf3a116748baa678 --- /dev/null +++ b/colossalai/legacy/pipeline/__init__.py @@ -0,0 +1,4 @@ +from .layer_spec import LayerSpec +from .pipelinable import PipelinableContext, PipelinableModel + +__all__ = ["PipelinableModel", "PipelinableContext", "LayerSpec"] diff --git a/colossalai/pipeline/layer_spec.py b/colossalai/legacy/pipeline/layer_spec.py similarity index 91% rename from colossalai/pipeline/layer_spec.py rename to colossalai/legacy/pipeline/layer_spec.py index 7e9169efff78bad7d30f42e9896dc3f61ecaf7fd..825816e1c032187c46500b2b0e57f6b17ab0c334 100644 --- a/colossalai/pipeline/layer_spec.py +++ b/colossalai/legacy/pipeline/layer_spec.py @@ -1,10 +1,10 @@ import torch + from colossalai.utils.model.utils import call_to_str + class LayerSpec: - """ - - """ + """ """ def __init__(self, typename, *module_args, **module_kwargs): self.typename = typename @@ -14,7 +14,7 @@ class LayerSpec: self._param_count = 0 if not issubclass(typename, torch.nn.Module): - raise RuntimeError('LayerSpec only supports torch.nn.Module types.') + raise RuntimeError("LayerSpec only supports torch.nn.Module types.") def __repr__(self): return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs) @@ -52,4 +52,4 @@ class LayerSpec: return self._param_count def reset_param_count(self): - self._param_count = 0 \ No newline at end of file + self._param_count = 0 diff --git a/colossalai/legacy/pipeline/middleware/__init__.py b/colossalai/legacy/pipeline/middleware/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a678b7b4c8725a939970b701053cb34add58567 --- /dev/null +++ b/colossalai/legacy/pipeline/middleware/__init__.py @@ -0,0 +1,3 @@ +from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo + +__all__ = ["Topo", "Partition", "PartitionOutputVal", "PartitionInputVal"] diff --git a/colossalai/legacy/pipeline/middleware/adaptor/__init__.py b/colossalai/legacy/pipeline/middleware/adaptor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2b18670a767dbf454b5de65a7f34694fcfd89d --- /dev/null +++ b/colossalai/legacy/pipeline/middleware/adaptor/__init__.py @@ -0,0 +1,3 @@ +from .fx import get_topology as get_fx_topology + +__all__ = ["get_fx_topology"] diff --git a/colossalai/pipeline/middleware/adaptor/fx.py b/colossalai/legacy/pipeline/middleware/adaptor/fx.py similarity index 85% rename from colossalai/pipeline/middleware/adaptor/fx.py rename to colossalai/legacy/pipeline/middleware/adaptor/fx.py index 8437c519476218dec90c968a47a73440ed71f519..34b21f8be1bb53fc13560f458225dd1aa11b0805 100644 --- a/colossalai/pipeline/middleware/adaptor/fx.py +++ b/colossalai/legacy/pipeline/middleware/adaptor/fx.py @@ -1,6 +1,8 @@ -from torch.fx.graph_module import GraphModule -from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo import torch +from torch.fx.graph_module import GraphModule + +from colossalai.legacy.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo + def partition_name_to_id(partition_name, is_input=False, is_output=False): if is_input: @@ -8,10 +10,11 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False): elif is_output: partition_id = 1 else: - prefix = 'submod_' + prefix = "submod_" partition_id = int(partition_name.split(prefix)[-1]) + 2 return partition_id + # There are two kinds of def in fx.graph # 1. non direct_use & non direct_def, which means the output is used by next partition with a temporary mid value. # e.g. submod1 = call_module(...) @@ -20,12 +23,14 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False): # 2. direct_use & direct_def, which means the output is used by next partition directly. # e.g. submod1 = call_module(...) # submod2 = call_module(submod1, ...) + + def find_input_in_partition(node, partitions, input_partitions=None): p_input_val = None - direct_def = not node.name.startswith('getitem') + direct_def = not node.name.startswith("getitem") # search in input if direct_def and input_partitions is not None: - partition_id = partition_name_to_id('', is_input=True) + partition_id = partition_name_to_id("", is_input=True) for i, input_node in enumerate(input_partitions): if input_node == node: p_input_val = PartitionInputVal(partition_id=partition_id, offset=i) @@ -45,13 +50,14 @@ def find_input_in_partition(node, partitions, input_partitions=None): partition_id = partition_name_to_id(partition.name) p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset) return p_input_val - + return p_input_val - + + def find_output_in_partition(node, partitions, output_partitions=None): p_output_val = PartitionOutputVal() for user in node.users: - direct_use = not user.name.startswith('getitem') + direct_use = not user.name.startswith("getitem") # user is mid partition for partition in partitions: # direct call @@ -70,13 +76,13 @@ def find_output_in_partition(node, partitions, output_partitions=None): if arg == user: p_output_val.add(partition_id=partition_id, offset=i) break - + # user is output if output_partitions is not None: output_node = output_partitions[0] if user.op == output_node.op: output_keys = {} - partition_id = partition_name_to_id('', is_output=True) + partition_id = partition_name_to_id("", is_output=True) torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n)) for i, arg in enumerate(output_keys): if arg == node: @@ -84,19 +90,20 @@ def find_output_in_partition(node, partitions, output_partitions=None): break return p_output_val + def get_topology(gm: GraphModule): topo = Topo() topo_output_partition = Partition() - + input_partitions = [] partitions = [] output_partitions = [] for node in gm.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": input_partitions.append(node) - elif node.name.startswith('submod_'): + elif node.name.startswith("submod_"): partitions.append(node) - elif node.op == 'output': + elif node.op == "output": output_partitions.append(node) else: continue @@ -109,7 +116,7 @@ def get_topology(gm: GraphModule): topo_input_partition.add_output_val(p_output_val) topo.set_partitions(partition_id=0, partition=topo_input_partition) topo.set_input_partition_id(partition_id=0) - + for i, partition in enumerate(partitions): topo_mid_partition = Partition() # set input for submodule @@ -120,7 +127,7 @@ def get_topology(gm: GraphModule): # set output for submodule direct_use = True for user in partition.users: - if user.name.startswith('getitem'): + if user.name.startswith("getitem"): direct_use = False break if direct_use: @@ -131,15 +138,17 @@ def get_topology(gm: GraphModule): for user in partition.users: cur_node = user p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) - topo_mid_partition.add_output_val(p_output_val) - topo.set_partitions(partition_id=i+2, partition=topo_mid_partition) - + topo_mid_partition.add_output_val(p_output_val) + topo.set_partitions(partition_id=i + 2, partition=topo_mid_partition) + # set input for output_partition for partition in output_partitions: topo_output_partition = Partition() - torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val( - find_input_in_partition(n, partitions, input_partitions))) + torch.fx.graph.map_arg( + partition.args[0], + lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)), + ) topo.set_partitions(partition_id=1, partition=topo_output_partition) topo.set_output_partition_id(partition_id=1) - return topo \ No newline at end of file + return topo diff --git a/colossalai/pipeline/middleware/topo.py b/colossalai/legacy/pipeline/middleware/topo.py similarity index 81% rename from colossalai/pipeline/middleware/topo.py rename to colossalai/legacy/pipeline/middleware/topo.py index e798e2ed9cab0cd3036b2dc39169c2f81470bae8..d0e3d2c3dedf3b36d92d4519861dcebc66bf5224 100644 --- a/colossalai/pipeline/middleware/topo.py +++ b/colossalai/legacy/pipeline/middleware/topo.py @@ -1,77 +1,81 @@ -from typing import Dict, List from dataclasses import dataclass +from typing import Dict, List # This file includes data structure used by Pipeline Middleware. + @dataclass class ValPosition: partition_id: int offset: int - + def __str__(self) -> str: - res = f'[partition_id:{self.partition_id},offset:{self.offset}]' + res = f"[partition_id:{self.partition_id},offset:{self.offset}]" return res - + def __repr__(self) -> str: return self.__str__() + class PartitionInputVal(object): def __init__(self, partition_id, offset) -> None: # every input from which partition_id and which offset val_pos = ValPosition(partition_id, offset) self._from_partition_and_offset: ValPosition = val_pos - + def get(self): return self._from_partition_and_offset - + def __str__(self) -> str: - res = '' - res += f'<-({self._from_partition_and_offset})' + res = "" + res += f"<-({self._from_partition_and_offset})" return res - + def __repr__(self) -> str: return self.__str__() - + + class PartitionOutputVal(object): def __init__(self) -> None: # every output to which partition_id and which offset self._to_partition_and_offset: List[ValPosition] = [] - + def add(self, partition_id, offset): val_pos = ValPosition(partition_id, offset) self._to_partition_and_offset.append(val_pos) - + def get(self): return self._to_partition_and_offset - + def __str__(self) -> str: - res = '' - res += '->(' + res = "" + res += "->(" for val_pos in self._to_partition_and_offset: - res += f'{val_pos},' - res += ')' + res += f"{val_pos}," + res += ")" return res - + def __repr__(self) -> str: return self.__str__() + class Partition(object): def __init__(self) -> None: self._input_vals: List[PartitionInputVal] = [] self._output_vals: List[PartitionOutputVal] = [] - + def add_input_val(self, input_val: PartitionInputVal): self._input_vals.append(input_val) - + def add_output_val(self, output_val: PartitionOutputVal): self._output_vals.append(output_val) - + def get_input_vals(self): return self._input_vals - + def get_output_vals(self): return self._output_vals - + # get the output offsets sent to dst_partition_id def get_output_offsets(self, dst_partition_id): res = [] @@ -80,9 +84,9 @@ class Partition(object): for val_pos in outputs: if val_pos.partition_id == dst_partition_id: res.append(offset) - + return res - + # get all input dst partition_ids def get_input_partition_ids(self): res = [] @@ -91,7 +95,7 @@ class Partition(object): if val_pos.partition_id not in res: res.append(val_pos.partition_id) return res - + # get all output dst partition_ids def get_output_partition_ids(self): res = [] @@ -101,24 +105,25 @@ class Partition(object): if val_pos.partition_id not in res: res.append(val_pos.partition_id) return res - + def __str__(self) -> str: - res = '' - res += f' input:\n' - res += f' length:{len(self._input_vals)}\n' + res = "" + res += f" input:\n" + res += f" length:{len(self._input_vals)}\n" for i, input_val in enumerate(self._input_vals): - res += f' offset={i}:{input_val}\n' - - res += f' output:\n' - res += f' length:{len(self._output_vals)}\n' + res += f" offset={i}:{input_val}\n" + + res += f" output:\n" + res += f" length:{len(self._output_vals)}\n" for i, output_val in enumerate(self._output_vals): - res += f' offset={i}:{output_val}\n' - + res += f" offset={i}:{output_val}\n" + return res - + def __repr__(self) -> str: return self.__str__() + # This class is a middleware between partition splitter # and Pipeline Scheduler. It records the graph info about # partition input/output and provides it to scheduler. @@ -136,38 +141,38 @@ class Topo(object): self._partitions: Dict[int, Partition] = {} self._input_partition_id = input_partition_id self._output_partition_id = output_partition_id - + def set_input_partition_id(self, partition_id: int): self._input_partition_id = partition_id - + def set_output_partition_id(self, partition_id: int): self._output_partition_id = partition_id - + def get_input_partition_id(self): return self._input_partition_id - + def get_output_partition_id(self): return self._output_partition_id - + def set_partitions(self, partition_id: int, partition: Partition): self._partitions[partition_id] = partition - + def get_mid_partitions(self): - res = {} #{partition_id: Partition} + res = {} # {partition_id: Partition} for partition_id, partition in self._partitions.items(): if self._input_partition_id == partition_id or self._output_partition_id == partition_id: continue res[partition_id] = partition return res - + def get_mid_partition_ids(self): return list(self.get_mid_partitions().keys()) - + def get_input_partition(self): if self._input_partition_id is not None: return self._partitions[self._input_partition_id] return None - + def get_output_partition(self): if self._output_partition_id is not None: return self._partitions[self._output_partition_id] @@ -175,32 +180,31 @@ class Topo(object): def get_partition_by_id(self, partition_id): return self._partitions[partition_id] - + def __str__(self) -> str: - res = '' + res = "" if len(self._partitions) == 0: - return 'Empty Topo Graph.' + return "Empty Topo Graph." input_part = self.get_input_partition() if input_part is not None: - res += '{\n' - res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}' - res += '}\n' - + res += "{\n" + res += f"InputPartition:\n partition_id={self._input_partition_id}\n{input_part}" + res += "}\n" + mid_parts = self.get_mid_partitions() for i, (partition_id, part) in enumerate(mid_parts.items()): - res += '{\n' - res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}' - res += '}\n' - + res += "{\n" + res += f"SubPartition_{i}:\n partition_id={partition_id}\n {part}" + res += "}\n" + output_part = self.get_output_partition() if output_part is not None: - res += '{\n' - res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}' - res += '}\n' - + res += "{\n" + res += f"OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}" + res += "}\n" + return res - + def __repr__(self) -> str: return self.__str__() - \ No newline at end of file diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/legacy/pipeline/pipelinable.py similarity index 87% rename from colossalai/pipeline/pipelinable.py rename to colossalai/legacy/pipeline/pipelinable.py index 9731530a6e15755c9d152697fec0ae0cfd102328..82ccdb554527a027e1638464f0400afc4ea2779f 100644 --- a/colossalai/pipeline/pipelinable.py +++ b/colossalai/legacy/pipeline/pipelinable.py @@ -1,15 +1,20 @@ import torch -import inspect -from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses -from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, \ - build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, \ - call_module, customized_partition -from colossalai.nn.layer.utils import CheckpointModule +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.utils import CheckpointModule from colossalai.tensor import ColoParameter -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode +from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses + from .layer_spec import LayerSpec +from .utils import ( + build_kwargs_for_module, + call_module, + customized_partition, + exec_funcs_with_kwargs, + partition_balanced, + partition_uniform, +) class PipelinableContext(InsertPostInitMethodToModuleSubClasses): @@ -83,7 +88,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): for k, v in kwargs.items(): if isinstance(v, torch.nn.Module): v = self._layer_spec_dict[id(v)] - # (lyl)TODO: analyse ColoTensor as well + # (lyl)TODO: analyze ColoTensor as well modified_kwargs[k] = v # keep track of the module children @@ -117,7 +122,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): def to_layer_list(self, exec_seq=None): """ Create a layer spec list and func list with execution sequence given by user. - If exec_seq is None, we will take the module initizing order as execution order. + If exec_seq is None, we will take the module initializing order as execution order. """ self._exec_seq = exec_seq @@ -126,8 +131,10 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): children_name = [] for child in self._root_children: layer_spec = self._layer_spec_dict[id(child)] - if layer_spec.typename in (torch.nn.modules.container.ModuleList, - torch.nn.modules.container.Sequential): + if layer_spec.typename in ( + torch.nn.modules.container.ModuleList, + torch.nn.modules.container.Sequential, + ): for child_in_container in layer_spec.children: self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)]) for name, module in self._model.named_modules(): @@ -146,9 +153,11 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): named_modules = dict(self._model.named_modules()) for index, element in enumerate(exec_seq): if isinstance(element, str): - if element == 'SPLIT_NODE': + if element == "SPLIT_NODE": continue - assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.' + assert ( + element in named_modules + ), f"Found invalid module name {element}, please check if you spell the module name correctly." # get the layer spec based on the module ID module = named_modules[element] @@ -177,7 +186,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): def partition(self, num_chunks, pipeline_size, rank): """ - Partitioned model will be built respect to partion policy. + Partitioned model will be built respect to partition policy. The real module instance will be built in this method. """ if isinstance(self._policy, str): @@ -189,11 +198,13 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): param_counts.append(layer_spec.count_params()) parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank] elif self._policy == "customized": - assert self._exec_seq is not None, f'An explicit exec_seq must be defined by user in customized policy mode.' + assert ( + self._exec_seq is not None + ), f"An explicit exec_seq must be defined by user in customized policy mode." self.customized_parts = customized_partition(self._exec_seq) assert len(self.customized_parts) == gpc.get_world_size( ParallelMode.PIPELINE - ), f'World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partions is {len(self.customized_parts)}' + ), f"World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}" parts = self.customized_parts[rank] else: raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].") @@ -216,14 +227,14 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): elif (layer, "behind") in self._func_dict: behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")] module_list_in_partition = torch.nn.ModuleList(module_list_in_partition) - pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition, - behind_func_dict_in_partition) + pipeline_model = PipelinableModel( + module_list_in_partition, front_func_dict_in_partition, behind_func_dict_in_partition + ) return pipeline_model class PipelinableModel(torch.nn.Module): - def __init__(self, module_list, front_func_dict, behind_func_dict): super().__init__() self._module_list = module_list @@ -232,7 +243,6 @@ class PipelinableModel(torch.nn.Module): def forward(self, *input_tensor, **kwargs): for module in self._module_list: - if id(module) in self._front_func_dict: input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs) diff --git a/colossalai/pipeline/pipeline_process_group.py b/colossalai/legacy/pipeline/pipeline_process_group.py similarity index 87% rename from colossalai/pipeline/pipeline_process_group.py rename to colossalai/legacy/pipeline/pipeline_process_group.py index c61d97ebabfa354ced39b952708188287d56cab3..2d0d5be87cac69de69b05fd3f52e076a37c5ecd7 100644 --- a/colossalai/pipeline/pipeline_process_group.py +++ b/colossalai/legacy/pipeline/pipeline_process_group.py @@ -1,11 +1,10 @@ -from typing import List, Dict, Tuple -import os import threading +from typing import List -from torch.distributed import rpc import torch.distributed as dist +from torch.distributed import rpc -from colossalai.tensor import ProcessGroup +from colossalai.legacy.tensor import ProcessGroup class PipelineProcessGroup: @@ -14,14 +13,15 @@ class PipelineProcessGroup: def __init__(self) -> None: self.is_initialize = False - def set_global_info(self, - rank: int, - world_size: int, - dp_degree: int = 1, - tp_degree: int = 1, - num_worker_threads: int = 1, - device: str = "cuda") -> None: - + def set_global_info( + self, + rank: int, + world_size: int, + dp_degree: int = 1, + tp_degree: int = 1, + num_worker_threads: int = 1, + device: str = "cuda", + ) -> None: device_mesh_size = dp_degree * tp_degree assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!" self._num_worker_threads = num_worker_threads @@ -60,8 +60,8 @@ class PipelineProcessGroup: device = self.device world_size = self.get_world_size() rank = self.get_global_rank() - backend = 'nccl' if device == 'cuda' else 'gloo' - dist.init_process_group(backend, world_size=world_size, rank=rank, group_name='main_group') + backend = "nccl" if device == "cuda" else "gloo" + dist.init_process_group(backend, world_size=world_size, rank=rank, group_name="main_group") def _initialize_pp_process_group(self) -> None: rank = self.get_global_rank() @@ -71,9 +71,9 @@ class PipelineProcessGroup: options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=self._num_worker_threads) for pp_rank in self._pp_ranks: - options.set_device_map(f'work{pp_rank}', {rank: pp_rank}) + options.set_device_map(f"work{pp_rank}", {rank: pp_rank}) - rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options) + rpc.init_rpc(name=f"work{rank}", rank=rank, world_size=world_size, rpc_backend_options=options) def _initialize_tp_dp_process_group(self) -> None: rank = self.get_global_rank() @@ -147,10 +147,10 @@ class PipelineProcessGroup: def get_chimera_all_reduce_group(self, pp_rank: int): with self.chimera_lock: - if not hasattr(self, 'chimera_groups'): + if not hasattr(self, "chimera_groups"): world_size = self.get_world_size() stage_num = self.get_stage_num() - assert world_size % 2 == 0, 'world_size must be even in chimera!' + assert world_size % 2 == 0, "world_size must be even in chimera!" self.chimera_groups = {} for rank in range(world_size // 2): pair = [rank, world_size - 1 - rank] diff --git a/colossalai/legacy/pipeline/rpc/__init__.py b/colossalai/legacy/pipeline/rpc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..791b9d530673c9ad0ef88b27bbfecad0d4b37d7c --- /dev/null +++ b/colossalai/legacy/pipeline/rpc/__init__.py @@ -0,0 +1,4 @@ +from ._pipeline_schedule import ChimeraPipelineEngine, FillDrainPipelineEngine, OneFOneBPipelineEngine +from .utils import pytree_map + +__all__ = ["FillDrainPipelineEngine", "OneFOneBPipelineEngine", "ChimeraPipelineEngine", "pytree_map"] diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/legacy/pipeline/rpc/_pipeline_base.py similarity index 85% rename from colossalai/pipeline/rpc/_pipeline_base.py rename to colossalai/legacy/pipeline/rpc/_pipeline_base.py index 2d7e25c82e7b917ce7f0b9e4eaf1d15cc8a3cbdd..d203e1a11180a684ee4047a9864a35cc4347c13b 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/legacy/pipeline/rpc/_pipeline_base.py @@ -12,17 +12,9 @@ from torch import autograd, nn, optim from torch._C._distributed_rpc import PyRRef from torch.futures import Future -from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.pipeline.rpc.utils import ( - get_batch_lengths, - pyobj_map, - pytree_filter, - pytree_map, - split_batch, - tensor_shape_list, - type_detail, -) +from colossalai.legacy.pipeline.middleware import Partition, Topo +from colossalai.legacy.pipeline.pipeline_process_group import ppg +from colossalai.legacy.pipeline.rpc.utils import get_batch_lengths, pyobj_map, pytree_filter, pytree_map, split_batch class Phase(Enum): @@ -33,7 +25,7 @@ class Phase(Enum): class UniqueKey: - __slots__ = ('microbatch_id', 'phase') + __slots__ = ("microbatch_id", "phase") microbatch_id: int phase: Phase @@ -48,12 +40,22 @@ class UniqueKey: return tuple.__hash__((self.microbatch_id, self.phase)) def __repr__(self) -> str: - return f'Key(microbatch_id={self.microbatch_id}, phase={self.phase})' + return f"Key(microbatch_id={self.microbatch_id}, phase={self.phase})" class WorkItem: - __slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id', - 'num_microbatches', 'forward_only') + __slots__ = ( + "stage_id", + "phase", + "args", + "kwargs", + "output", + "refcount", + "microbatch_id", + "batch_id", + "num_microbatches", + "forward_only", + ) stage_id: int phase: Phase @@ -66,50 +68,45 @@ class WorkItem: num_microbatches: int forward_only: bool - def __init__(self, - stage_id, - phase, - args, - kwargs, - output, - microbatch_id, - batch_id, - num_microbatches, - forward_only, - refcount=0) -> None: + def __init__( + self, stage_id, phase, args, kwargs, output, microbatch_id, batch_id, num_microbatches, forward_only, refcount=0 + ) -> None: for attr_name in self.__slots__: setattr(self, attr_name, locals()[attr_name]) class BackwardCache: - __slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs') + __slots__ = ("checkpoint", "stage_input_args", "stage_input_kwargs", "stage_outputs") checkpoint: bool stage_input_args: Tuple[Any] stage_input_kwargs: Dict[Any, Any] stage_outputs: Tuple[Any] - def __init__(self, - stage_input_args: Tuple[Any], - stage_input_kwargs: Dict[Any, Any] = None, - stage_outputs: Tuple[Any] = None, - checkpoint: bool = False) -> None: + def __init__( + self, + stage_input_args: Tuple[Any], + stage_input_kwargs: Dict[Any, Any] = None, + stage_outputs: Tuple[Any] = None, + checkpoint: bool = False, + ) -> None: for arg_name in self.__slots__: setattr(self, arg_name, locals()[arg_name]) class WorkerBase(ABC): - - def __init__(self, - partition_fn: Callable, - partition_args: tuple, - pp_rank: int, - actual_stage_num: int, - num_microbatches: int, - device: str, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: + def __init__( + self, + partition_fn: Callable, + partition_args: tuple, + pp_rank: int, + actual_stage_num: int, + num_microbatches: int, + device: str, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: super().__init__() self.pp_rank = pp_rank @@ -123,7 +120,7 @@ class WorkerBase(ABC): self.device = device self._initialize_outstanding_range() - # variable and const for context managment + # variable and const for context management self.outstanding = 0 self.forward_times = 0 self.backward_times = 0 @@ -150,11 +147,11 @@ class WorkerBase(ABC): self._initialize_context_container() # main loop - self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True) + self.main_loop_thread = threading.Thread(target=self._work_loop, name=f"rank_{pp_rank}", daemon=True) self.main_loop_thread.start() def _get_future_by_device(self): - return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device]) + return torch.futures.Future(devices=None if self.device in (None, "cpu") else [self.device]) def _initialize_outstanding_range(self): outstanding_range = None @@ -199,12 +196,13 @@ class WorkerBase(ABC): # lifecycle management for DAG scheduler if output_work_item.phase == Phase.FORWARD: lifecycle = len(self.get_consumer_stage_ids()) - if self.is_model_output(): # an extra reference for scheduler collecting results + if self.is_model_output(): # an extra reference for scheduler collecting results lifecycle += 1 elif output_work_item.phase == Phase.BACKWARD: lifecycle = len(self.get_producer_stage_ids()) if self.is_model_input() and self._is_last_step( - output_work_item): # an extra reference for ensure_backward + output_work_item + ): # an extra reference for ensure_backward lifecycle += 1 else: lifecycle = 0 @@ -226,7 +224,7 @@ class WorkerBase(ABC): self.pp_rank_to_worker_rref = pp_rank_to_worker_rref # for some schedule need the other worker's info to initialise partition (like Chimera) - # construction of partition is executed after the registion of pp_rank_to_worker_rref + # construction of partition is executed after the registration of pp_rank_to_worker_rref self._initialize_partition() # res_use works for lifecycle counter, @@ -234,9 +232,9 @@ class WorkerBase(ABC): # offset supports get partial output to reduce comm costs. def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any: output = self._get_output_all(key, ref_use, rank) - if offsets is None: # get all for non iterable output + if offsets is None: # get all for non iterable output return output - else: # get part for iterable output + else: # get part for iterable output output = [output[i] for i in offsets] return output @@ -252,12 +250,12 @@ class WorkerBase(ABC): def get_partition(self): with self.partition_condition_lock: - self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition")) return self.module_partition def get_partition_state_dict(self): with self.partition_condition_lock: - self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition")) return self.module_partition.state_dict() def _make_args_kwargs(self, microbatch, merge=False): @@ -293,8 +291,17 @@ class WorkerBase(ABC): # make args and kwargs args, kwargs = self._make_args_kwargs(microbatch) - work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None, - self.num_microbatches, forward_only) + work_item = WorkItem( + self.pp_rank, + Phase.FORWARD, + args, + kwargs, + output, + microbatch_id, + None, + self.num_microbatches, + forward_only, + ) with self.work_list_condition_lock: self.work_list[key] = work_item self.work_list_condition_lock.notify_all() @@ -314,15 +321,25 @@ class WorkerBase(ABC): for off in self_input_offsets: self_arg_lst.append(arg_lst[off]) - work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None, - self.num_microbatches, forward_only) + work_item = WorkItem( + self.pp_rank, + Phase.FORWARD, + self_arg_lst, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + forward_only, + ) with self.work_list_condition_lock: self.work_list[key] = work_item self.work_list_condition_lock.notify_all() # put input tensor which other nodes need into output_list as Phase.INPUT - work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, - self.num_microbatches, forward_only) + work_item_remote = WorkItem( + self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, self.num_microbatches, forward_only + ) with self.output_list_condition_lock: self.output_list[recv_input_key] = work_item_remote @@ -343,8 +360,17 @@ class WorkerBase(ABC): output = self._get_future_by_device() grad_wrt_loss = None - work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, - self.num_microbatches, False) + work_item = WorkItem( + self.pp_rank, + Phase.BACKWARD, + grad_wrt_loss, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + False, + ) self.work_list[key] = work_item self.work_list_condition_lock.notify_all() @@ -367,7 +393,7 @@ class WorkerBase(ABC): producer_stage_ids = self.get_producer_stage_ids() producer_num = len(producer_stage_ids) if self.need_model_input(): - producer_num += 1 # for input partition + producer_num += 1 # for input partition subscribe_forward_futures: List[Future] = [None] * producer_num # TODO(jiangziyue) get single value instead of the whole output @@ -376,9 +402,9 @@ class WorkerBase(ABC): producer_output_key = UniqueKey(microbatch_id, Phase.INPUT) producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] offsets = self._get_input_offsets_by_index(target_index=0) - subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, - rank=self.pp_rank, - offsets=offsets) + subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key( + producer_output_key, rank=self.pp_rank, offsets=offsets + ) for i in range(0, producer_num - 1): producer_stage_id = producer_stage_ids[i] @@ -386,11 +412,12 @@ class WorkerBase(ABC): producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] target_index = i + 1 offsets = self._get_input_offsets_by_index(target_index=target_index) - if offsets is not None and len(offsets) == 0: # no need to do rpc + if offsets is not None and len(offsets) == 0: # no need to do rpc subscribe_forward_futures[target_index] = [] else: subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key( - producer_output_key, rank=self.pp_rank, offsets=offsets) + producer_output_key, rank=self.pp_rank, offsets=offsets + ) else: for i in range(producer_num): @@ -399,14 +426,24 @@ class WorkerBase(ABC): producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] target_index = i offsets = self._get_input_offsets_by_index(target_index=target_index) - if offsets is not None and len(offsets) == 0: # no need to do rpc + if offsets is not None and len(offsets) == 0: # no need to do rpc subscribe_forward_futures[target_index] = [] else: subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key( - producer_output_key, rank=self.pp_rank, offsets=offsets) - - work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output, - microbatch_id, None, self.num_microbatches, forward_only) + producer_output_key, rank=self.pp_rank, offsets=offsets + ) + + work_item_from_producer = WorkItem( + stage_id, + Phase.FORWARD, + subscribe_forward_futures, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + forward_only, + ) return work_item_from_producer @@ -418,7 +455,7 @@ class WorkerBase(ABC): # On current PP middleware design for DAG, get_output_by_key used by _subscribe_producer # can only be executed once for every producer-consumer stage pair, which is necessary # to count the lifecycle of work_item. So, keeping the _subscribe_producer in the same - # lock of work_item queue operation gurantees the consistency of lifecycle counter. + # lock of work_item queue operation guarantees the consistency of lifecycle counter. work_item_from_producer = self._subscribe_producer(microbatch_id, forward_only) self.work_list[key] = work_item_from_producer self.work_list_condition_lock.notify_all() @@ -441,15 +478,25 @@ class WorkerBase(ABC): consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id] target_index = i offsets = self._get_output_offsets_by_index(target_index=target_index) - if offsets is not None and len(offsets) == 0: # no need to do rpc + if offsets is not None and len(offsets) == 0: # no need to do rpc subscribe_backward_futures[target_index] = [] else: subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key( - consumer_output_key, rank=self.pp_rank, offsets=offsets) + consumer_output_key, rank=self.pp_rank, offsets=offsets + ) # flatten args - work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output, - microbatch_id, None, self.num_microbatches, False) + work_item_from_consumer = WorkItem( + stage_id, + Phase.BACKWARD, + subscribe_backward_futures, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + False, + ) return work_item_from_consumer @@ -460,7 +507,7 @@ class WorkerBase(ABC): # On current PP middleware design for DAG, get_output_by_key used by subscribe_consumer # can only be executed once for every producer-consumer stage pair, which is necessary # to count the lifecycle of work_item. So, keeping the subscribe_consumer in the same - # lock of work_item queue operation gurantees the consistency of lifecycle counter. + # lock of work_item queue operation guarantees the consistency of lifecycle counter. work_item_from_consumer = self._subscribe_consumer(microbatch_id) self.work_list[key] = work_item_from_consumer self.work_list_condition_lock.notify_all() @@ -508,7 +555,7 @@ class WorkerBase(ABC): assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed" assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed" - # should be aranged in order, the order of the input of current forward + # should be arranged in order, the order of the input of current forward self.producer_stage_ids = self.get_producer_stage_ids() self.consumer_stage_ids = self.get_consumer_stage_ids() @@ -524,8 +571,8 @@ class WorkerBase(ABC): def get_topo(self): with self.partition_condition_lock: - self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) - if hasattr(self.module_partition, '_topo'): + self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition")) + if hasattr(self.module_partition, "_topo"): return self.module_partition._topo else: return None @@ -564,12 +611,12 @@ class WorkerBase(ABC): if stage_id == src_stage_id: src_index += i break - else: # data from input partition + else: # data from input partition src_index = 0 # when output_len = 1, not iterable if target_index == src_index: if output_len == 1: - res = None # offset = None to get all outputs + res = None # offset = None to get all outputs return res else: res.append(src_offset) @@ -584,7 +631,6 @@ class WorkerBase(ABC): consumer_stage_ids = self.get_consumer_stage_ids() for val_list in output_vals: # An output may be passed to many down stages. - target = None for val_pos in val_list.get(): dst_partition_id = val_pos.partition_id dst_offset = val_pos.offset @@ -597,7 +643,7 @@ class WorkerBase(ABC): break if target_index == dst_index: if input_len == 1: - res = None # offset = None to get all outputs + res = None # offset = None to get all outputs return res else: res.append(dst_offset) @@ -623,7 +669,7 @@ class WorkerBase(ABC): flatten_args = [] if self.is_first_stage(): pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) - else: # get by offset + else: # get by offset topo: Topo = self.get_topo() self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) self_partition: Partition = topo.get_partition_by_id(self_partition_id) @@ -652,7 +698,7 @@ class WorkerBase(ABC): if stage_id == src_stage_id: src_index += i break - else: # data from input partition + else: # data from input partition src_index = 0 # when output_len = 1, not iterable if output_len == 1: @@ -679,7 +725,7 @@ class WorkerBase(ABC): else: for i, arg in enumerate(args_or_kwargs): args_or_kwargs[i] = arg.wait() - if args_or_kwargs is not None: # get by offset + if args_or_kwargs is not None: # get by offset flatten_args = [] topo: Topo = self.get_topo() self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) @@ -719,7 +765,7 @@ class WorkerBase(ABC): @abstractmethod def _get_work_item_key(self) -> UniqueKey: """ - this method control the order of the microbatch to consume + this method control the order of the microbatch to consume """ def is_first_stage(self): @@ -761,7 +807,7 @@ class WorkerBase(ABC): kwargs = work_item.kwargs microbatch_id = work_item.microbatch_id forward_only = work_item.forward_only - data_process_func = getattr(self, 'data_process_func', self._default_data_process_func) + data_process_func = getattr(self, "data_process_func", self._default_data_process_func) consume_result = None is_first_stage = self.is_first_stage() @@ -787,10 +833,12 @@ class WorkerBase(ABC): else: args_kwargs = self._get_real_args_kwargs_fwd(args) - args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU - args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device, - process_types=torch.device) # change devices from last stage to current device + args_kwargs = pyobj_map( + args_kwargs, fn=lambda x: x.to(self.device).detach(), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in GPU + args_kwargs = pyobj_map( + args_kwargs, fn=lambda x: self.device, process_types=torch.device + ) # change devices from last stage to current device args, kwargs = data_process_func(args_kwargs) @@ -851,16 +899,16 @@ class WorkerBase(ABC): use_checkpoint = False if not forward_only: - self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args, - stage_input_kwargs, - stage_outputs, - checkpoint=use_checkpoint) - consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in + self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache( + stage_input_args, stage_input_kwargs, stage_outputs, checkpoint=use_checkpoint + ) + consume_result = pyobj_map( + consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in # if not forward_only, do the backward if not forward_only: - if is_last_stage: # if it is the last stage, trigger backward automatic + if is_last_stage: # if it is the last stage, trigger backward automatic self._begin_backward(microbatch_id) elif phase == Phase.BACKWARD: @@ -872,7 +920,9 @@ class WorkerBase(ABC): self.backward_times += 1 self.outstanding -= 1 - assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache" + assert ( + microbatch_id in self.microbatch_id_to_backward_cache + ), f"microbatch_id {microbatch_id} not in backward cache" backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id) stage_outputs = backward_cache.stage_outputs @@ -906,8 +956,9 @@ class WorkerBase(ABC): filtered_grads.append(grad) stage_outputs = filtered_outputs - grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + grad_tensors = pyobj_map( + filtered_grads, fn=lambda x: x.to(self.device), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in GPU autograd.backward(stage_outputs, grad_tensors=grad_tensors) # collect grad of input tensor @@ -920,8 +971,8 @@ class WorkerBase(ABC): else: consume_result.append(None) consume_result = pyobj_map( - consume_result, fn=lambda x: x.to('cpu'), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in GPU else: raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}") @@ -929,7 +980,7 @@ class WorkerBase(ABC): return consume_result def _get_store_len(self): - return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}' + return f"work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}" def _get_parameter_grad_sum(self): grad_sum = 0 @@ -1014,19 +1065,20 @@ class WorkerBase(ABC): class PipelineEngineBase(ABC, nn.Module): - - def __init__(self, - worker_type, - partition_fn: Callable, - stage_num, - num_microbatches, - device: str, - use_1F1B=False, - chunk: int = 1, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: + def __init__( + self, + worker_type, + partition_fn: Callable, + stage_num, + num_microbatches, + device: str, + use_1F1B=False, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: super().__init__() self.worker_type = worker_type self.partition_fn: Callable = partition_fn @@ -1056,12 +1108,12 @@ class PipelineEngineBase(ABC, nn.Module): data_process_func = self.data_process_func if data_process_func is not None: assert callable(data_process_func), "data_process_func must be a function" - assert '' not in data_process_func.__repr__(), "data_process_func must be a global function" - assert '' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression" + assert "" not in data_process_func.__repr__(), "data_process_func must be a global function" + assert "" not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression" sig = inspect.signature(data_process_func) - assert len( - sig.parameters - ) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead" + assert ( + len(sig.parameters) == 2 + ), f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead" def _get_actual_stage_num(self) -> int: return self.stage_num if self.chunk == 1 else self.virtual_stage_num @@ -1104,19 +1156,33 @@ class PipelineEngineBase(ABC, nn.Module): partition_id = self.pp_rank_to_module_partition_id[pp_rank] partition_args = (partition_id, chunk, actual_stage_num) rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank] - if device[:4] == 'cuda': - device = f'cuda:{rpc_worker_id}' - self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id, - worker_type, - args=(partition_fn, partition_args, pp_rank, - actual_stage_num, num_microbatches, device, - criterion, metric, checkpoint, data_process_func)) + if device[:4] == "cuda": + device = f"cuda:{rpc_worker_id}" + self.pp_rank_to_worker_rref[pp_rank] = rpc.remote( + rpc_worker_id, + worker_type, + args=( + partition_fn, + partition_args, + pp_rank, + actual_stage_num, + num_microbatches, + device, + criterion, + metric, + checkpoint, + data_process_func, + ), + ) # let each worker know global worker rref (include itself) sync_futs = [] for pp_rank in self.pp_rank_to_worker_rref: - fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs( - self.pp_rank_to_worker_rref) + fut = ( + self.pp_rank_to_worker_rref[pp_rank] + .rpc_async(timeout=0) + .sync_global_worker_rrefs(self.pp_rank_to_worker_rref) + ) sync_futs.append(fut) for fut in sync_futs: @@ -1157,8 +1223,9 @@ class PipelineEngineBase(ABC, nn.Module): def get_output_pp_ranks(self) -> List[int]: return [self._get_actual_stage_num() - 1] - def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], - output_pp_ranks: List[int], ret_future): + def _consume_constraint( + self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future + ): actual_stage_num = self._get_actual_stage_num() use_1F1B = self.use_1F1B if microbatch_id >= actual_stage_num: @@ -1206,7 +1273,8 @@ class PipelineEngineBase(ABC, nn.Module): worker_rref = self.pp_rank_to_worker_rref[pp_rank] key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD) fut = worker_rref.rpc_async().get_output_by_key( - key, offsets=[]) # only ensure the res exists, no need for real data. + key, offsets=[] + ) # only ensure the res exists, no need for real data. backward_result.append(fut) for fut in backward_result: @@ -1244,11 +1312,14 @@ class PipelineEngineBase(ABC, nn.Module): if labels is not None and not forward_only: assert hasattr( - self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward" + self, "optimizer_class" + ), "call `initialize_optimizer` to initialize optimizer before forward_backward" num_microbatches = self.num_microbatches - assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal" + assert ( + batch_length >= num_microbatches + ), "num_microbatches is greater than the size of a batch, which is illegal" microbatch_size = math.ceil(batch_length / num_microbatches) device = self.device @@ -1285,10 +1356,10 @@ class PipelineEngineBase(ABC, nn.Module): # collect forward result forward_result = self._collect_forward_result(output_pp_ranks, ret_future) - if not forward_only and hasattr(self, 'optimizer_class'): + if not forward_only and hasattr(self, "optimizer_class"): self.step() - self._reset_worker() # reset worker attributes for next batch + self._reset_worker() # reset worker attributes for next batch return forward_result def initialize_optimizer(self, optimizer_class: type, **kwargs): diff --git a/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py b/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..56da2a9542255213e3772392b53a1acb98ae68d3 --- /dev/null +++ b/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py @@ -0,0 +1,377 @@ +import threading +from typing import Callable, Dict, List + +import torch +from torch._C._distributed_rpc import PyRRef +from torch.futures import Future + +from colossalai.legacy.pipeline.pipeline_process_group import ppg +from colossalai.legacy.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem + +# Implementation of different Pipeline schedule +# Worker defines the worker for each stage +# PipelineEngine is the class for use + + +class FillDrainWorker(WorkerBase): + def _get_work_item_key(self) -> UniqueKey: + # execute backward first (if backward phase in work_list) + num_microbatches = self.num_microbatches + + if self.forward_times < num_microbatches: + target_phase = Phase.FORWARD + target_microbatch_id = self.forward_times + else: + target_phase = Phase.BACKWARD + target_microbatch_id = self.backward_times + + target_key = UniqueKey(target_microbatch_id, target_phase) + + return target_key + + +class FillDrainPipelineEngine(PipelineEngineBase): + def __init__( + self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: + if chunk > 1: + assert ( + num_microbatches % stage_num == 0 + ), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" + use_1F1B = False + + super().__init__( + FillDrainWorker, + partition_fn, + stage_num, + num_microbatches, + device, + use_1F1B, + chunk, + criterion, + metric, + checkpoint, + data_process_func, + ) + + +class OneFOneBWorker(WorkerBase): + def _get_work_item_key(self) -> UniqueKey: + # execute backward first (if backward phase in work_list) + pp_rank = self.pp_rank + actual_stage_num = self.actual_stage_num + num_microbatches = self.num_microbatches + is_last_stage = pp_rank == actual_stage_num - 1 + + if self.outstanding <= self.outstanding_range[0]: + target_phase = Phase.FORWARD + target_microbatch_id = self.forward_times + elif self.outstanding >= self.outstanding_range[1]: + target_phase = Phase.BACKWARD + target_microbatch_id = self.backward_times + else: + raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]") + + target_key = UniqueKey(target_microbatch_id, target_phase) + + # change outstanding_range at: + # 1. forward times reach actual_stage_num, this is the end of continuous forward + # 2. forward times reach num_microbatches, this is the end of 1F1B mode + if not is_last_stage and target_key.phase == Phase.FORWARD: + if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2: + # Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2 + outstanding_min = actual_stage_num - pp_rank - 1 + outstanding_max = actual_stage_num - pp_rank + self.outstanding_range = (outstanding_min, outstanding_max) + if target_key.microbatch_id == num_microbatches - 1: + self.outstanding_range = (0, 0) + + return target_key + + +class OneFOneBPipelineEngine(PipelineEngineBase): + def __init__( + self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: + if chunk > 1: + assert ( + num_microbatches % stage_num == 0 + ), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" + # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk" + use_1F1B = True + + super().__init__( + OneFOneBWorker, + partition_fn, + stage_num, + num_microbatches, + device, + use_1F1B, + chunk, + criterion, + metric, + checkpoint, + data_process_func, + ) + + +class ChimeraWorker(WorkerBase): + def _get_producer_consumer(self) -> None: + rank = self.pp_rank + min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num + max_pp_rank = min_pp_rank + self.actual_stage_num - 1 + + assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed" + assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed" + + # should be arranged in order, the order of the input of current forward + self.producer_stage_ids = [] + self.consumer_stage_ids = [] + + # Just for demo + prev_rank = rank - 1 + next_rank = rank + 1 + if prev_rank >= min_pp_rank: + self.producer_stage_ids.append(prev_rank) + if next_rank <= max_pp_rank: + self.consumer_stage_ids.append(next_rank) + + def _get_work_item_key(self) -> UniqueKey: + pp_rank = self.pp_rank + stage_num = self.actual_stage_num + real_microbatch_num = self.num_microbatches // 2 + + forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num + forward_block_num = self.forward_times // forward_block_size + + if self.forward_times >= real_microbatch_num or ( + (pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times + ): + target_phase = Phase.BACKWARD + target_microbatch_id = self.backward_times + else: # others + target_phase = Phase.FORWARD + target_microbatch_id = self.forward_times + + # In up pipeline, microbatch_id to consume is 0, 2, 4 (2n) + # In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1) + real_target_microbatch_id = target_microbatch_id * 2 + if pp_rank >= stage_num: + real_target_microbatch_id += 1 + target_key = UniqueKey(real_target_microbatch_id, target_phase) + + with self.work_list_condition_lock: + self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list) + return target_key + + def _initialize_partition(self): + # In order to ensure the down pipeline share the same parameter + # with the up pipeline, partition of down partition will be copied + # from corresponding up stage + pp_rank = self.pp_rank + stage_num = self.actual_stage_num + self.device + if pp_rank < stage_num: + super()._initialize_partition() + else: + # if it is down pipeline, create partition by origin method + co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num] + # get the corresponding model state dict and wait for its init + state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict() + super()._initialize_partition() + self.module_partition.load_state_dict(state_dict) + + # init group for chimera in ppg + ppg.get_chimera_all_reduce_group(pp_rank) + + # lock for step sync + self.step_sync_lock = threading.Lock() + self.step_sync_lock.acquire() + + self.have_grad_lock = threading.Lock() + self.have_grad_lock.acquire() + + def _get_lock_gradient(self): + self.have_grad_lock.acquire() + grads = self.get_parameter_gradients() + self.step_sync_lock.release() + return grads + + def is_first_stage(self): + return (self.pp_rank % self.actual_stage_num) == 0 + + def is_last_stage(self): + return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1 + + def _is_last_step(self, work_item: WorkItem) -> bool: + if work_item.forward_only: + last_phase = Phase.FORWARD + else: + last_phase = Phase.BACKWARD + is_last_phase = work_item.phase == last_phase + last_microbatch_id = self.num_microbatches - 1 + if self.pp_rank < self.actual_stage_num: + last_microbatch_id -= 1 + is_last_microbatch = work_item.microbatch_id == last_microbatch_id + return is_last_phase and is_last_microbatch + + def _get_step_order(self) -> List[int]: + # TODO : If you want to extend it to multi head chimera, overwrite here + stage_num = self.actual_stage_num + pp_rank = self.pp_rank + # pp_rank in the same device + local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1] + local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2) + return local_device_pp_ranks + + def _hook_before_step(self): + self.have_grad_lock.release() + pp_rank = self.pp_rank + stage_num = self.actual_stage_num + co_pp_rank = (pp_rank + stage_num) % (2 * stage_num) + + # if current pp_rank is not the first to do step + # wait its previous pp_rank finish step + grads = self.get_parameter_gradients() + + # send + co_worker = self.pp_rank_to_worker_rref[co_pp_rank] + co_grads = co_worker.rpc_sync()._get_lock_gradient() + # sync + self.step_sync_lock.acquire() + for i in range(len(grads)): + grads[i] += co_grads[i] + + +class ChimeraPipelineEngine(PipelineEngineBase): + def __init__( + self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: + assert num_microbatches % stage_num == 0, "In Chimera, num_microbatches must be the multiply of stage_num!" + use_1F1B = False + chunk = 1 + + super().__init__( + ChimeraWorker, + partition_fn, + stage_num, + num_microbatches, + device, + use_1F1B, + chunk, + criterion, + metric, + checkpoint, + data_process_func, + ) + + def _consume_constraint( + self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future + ): + pass + + def _create_pp_rank_to_rpc_worker_id(self) -> None: + stage_num = self.stage_num + self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2) + for pp_rank in range(stage_num): + self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank + self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1 + + def _create_pp_rank_to_module_partition_id(self) -> None: + stage_num = self.stage_num + self.pp_rank_to_module_partition_id = [0] * (stage_num * 2) + for pp_rank in range(stage_num): + self.pp_rank_to_module_partition_id[pp_rank] = pp_rank + self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank + + def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]: + num_microbatches = self.num_microbatches + stage_num = self.stage_num + up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks} + down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks} + # merge up and down + return {**up_ret_future, **down_ret_future} + + def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool): + # offset is 0 for all the ranks in up pipeline + # offset is stage_num for all the ranks in down pipeline + offset = (microbatch_id % 2) * self.stage_num + for pp_rank in input_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] + worker_rref.remote().set_input(microbatch_id, microbatch, forward_only) + + def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels): + # offset is 0 for all the ranks in up pipeline + # offset is stage_num for all the ranks in down pipeline + offset = (microbatch_id % 2) * self.stage_num + for pp_rank in output_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] + worker_rref.remote().set_labels(microbatch_id, microlabels) + + def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]): + key = UniqueKey(microbatch_id, Phase.FORWARD) + offset = (microbatch_id % 2) * self.stage_num + for pp_rank in output_pp_ranks: + worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] + ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key) + + def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]): + stage_num = self.stage_num + num_microbatches = self.num_microbatches + if not forward_only: + for pp_rank in input_pp_ranks: + up_last_microbatch_id = num_microbatches - 2 + down_last_microbatch_id = num_microbatches - 1 + + up_worker_rref = self.pp_rank_to_worker_rref[pp_rank] + down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num] + + up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD) + down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD) + up_worker_rref.rpc_sync().get_output_by_key(up_key) + down_worker_rref.rpc_sync().get_output_by_key(down_key) + + def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[PyRRef, List[Future]]): + """Logic of collection of forward in Chimera. + Currently, only one input one output model is supported + """ + stage_num = self.stage_num + forward_result = [] + for pp_rank in output_pp_ranks: + worker_forward_result = [None] * self.num_microbatches + for microbatch_id in range(self.num_microbatches): + offset = (microbatch_id % 2) * stage_num + ret = ret_future[pp_rank + offset][microbatch_id].wait() + ret = [ret] if isinstance(ret, torch.Tensor) else ret + worker_forward_result[microbatch_id] = ret + + worker_forward_result = list(zip(*worker_forward_result)) + forward_result.extend(worker_forward_result) + + return forward_result diff --git a/colossalai/legacy/pipeline/rpc/utils.py b/colossalai/legacy/pipeline/rpc/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..808de301a2a0987babbe76cf1967950e0d4f42e6 --- /dev/null +++ b/colossalai/legacy/pipeline/rpc/utils.py @@ -0,0 +1,157 @@ +import argparse +import os +import warnings +from typing import Any, Callable, Tuple, Type, Union + +import torch +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp +from torch._C._distributed_rpc import _is_current_rpc_agent_set +from torch.futures import Future + +from colossalai.initialize import launch +from colossalai.legacy.pipeline.pipeline_process_group import ppg + + +def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any: + if isinstance(obj, process_types): + return fn(obj) + elif type(obj) is dict: + return {k: pyobj_map(obj[k], fn, process_types) for k in obj} + elif type(obj) is tuple: + return tuple(pyobj_map(o, fn, process_types) for o in obj) + elif type(obj) is list: + return list(pyobj_map(o, fn, process_types) for o in obj) + else: + return obj + + +def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: + """process object recursively, like pytree + + Args: + obj (:class:`Any`): object to process + fn (:class:`Callable`): a function to process subobject in obj + process_types (:class: `type | tuple[type]`): types to determine the type to process + map_all (:class: `bool`): if map_all is True, then any type of element will use fn + + Returns: + :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn` + """ + if isinstance(obj, dict): + return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj} + elif isinstance(obj, tuple): + return tuple(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, list): + return list(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, process_types): + return fn(obj) + else: + return fn(obj) if map_all else obj + + +def tensor_shape_list(obj): + return pytree_map(obj, fn=lambda x: x.shape, process_types=torch.Tensor) + + +def get_batch_lengths(batch): + lengths = [] + pytree_map(batch, fn=lambda x: lengths.append(len(x)), process_types=torch.Tensor) + return lengths + + +def split_batch(batch: Any, start, stop, device: str): + if device == "cuda": + fn = lambda x: x[start:stop].cuda() + else: + fn = lambda x: x[start:stop] + return pytree_map(batch, fn=fn, process_types=torch.Tensor) + + +def type_detail(obj): + return pytree_map(obj, lambda x: type(x), map_all=True) + + +def pytree_filter(fn, obj, process_types): + if obj is None: + return None + + filters = [] + + def condition_append(obj): + if fn(obj): + filters.append(obj) + + pytree_map(obj, fn=condition_append, process_types=process_types) + return filters + + +def get_real_args_kwargs(args_or_kwargs): + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + # TODO : combine producer and consumer + # by default, merge all args in the output args or kwargs + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + args_or_kwargs = flatten_args + + return args_or_kwargs + + +def run_worker(rank, args, master_func): + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port + + device = args.device + world_size = args.world_size + dp_degree = args.dp_degree + tp_degree = args.tp_degree + num_worker_threads = args.num_worker_threads + host = args.master_addr + port = args.master_port + backend = "nccl" if device == "cuda" else "gloo" + + launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + ppg.set_global_info( + rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device, + ) + ppg.args = args + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + master_func(args) + # barrier here + if _is_current_rpc_agent_set(): + rpc.shutdown() + else: + warnings.warn("RPC has not been initialized") + + +def rpc_run(args, master_func): + world_size = args.world_size + mp.spawn(run_worker, args=(args, master_func), nprocs=world_size) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--epoch", type=int, default=1) + parser.add_argument("--world_size", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--num_microbatches", type=int, default=2) + parser.add_argument("--chunk", type=int, default=1) + parser.add_argument("--use_checkpoint", action="store_true") + parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "RMSprop"], default="SGD") + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29020") + parser.add_argument("--num_worker_threads", type=int, default=128) + return parser.parse_args() diff --git a/colossalai/legacy/pipeline/utils.py b/colossalai/legacy/pipeline/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..182af677c0478a962dbf444bb68f583514c56207 --- /dev/null +++ b/colossalai/legacy/pipeline/utils.py @@ -0,0 +1,276 @@ +import heapq +import inspect +from collections import OrderedDict +from typing import List + +import torch + +from colossalai.legacy.nn.layer.utils import CheckpointModule +from colossalai.logging import get_dist_logger + + +def _binary_partition(weights: List, start: int, end: int): + """Returns the binary partition position of `weights`, given the start + position `st` and the end position `ed`. + + Args: + weights (list): A python list to be binary partitioned + start (int): the start position of the binary partition + end (int): the end position of the binary partition + + Returns: + int: the binary partition position of `weights` + """ + w_sum = weights[end - 1] + prefix = 0 + if start > 0: + w_sum -= weights[start - 1] + prefix = weights[start - 1] + minimum = float("inf") + for idx in range(start + 1, end): + front = weights[idx - 1] - prefix + diff = abs(w_sum - 2 * front) + if diff < minimum: + pos = idx + minimum = diff + + return start, pos, end + + +def _heap_addition(weights: List, intervals: int, add_cnt: int): + """ """ + + def _heap_push(heap, st, ed): + value = weights[ed - 1] + if st > 0: + value -= weights[st - 1] + heapq.heappush(heap, (-value, st, ed)) + + ret_intervals = [] + heap = [] + + for st, ed in intervals: + _heap_push(heap, st, ed) + + while add_cnt > 0: + _, st, ed = heapq.heappop(heap) + if ed - st == 1: + ret_intervals.append((st, ed)) + else: + l, m, r = _binary_partition(weights, st, ed) + _heap_push(heap, l, m) + _heap_push(heap, m, r) + add_cnt -= 1 + + while heap: + _, st, ed = heapq.heappop(heap) + ret_intervals.append((st, ed)) + + ret_intervals.sort() + return ret_intervals + + +def _calc_partitions(weights, value): + prev = 0 + prefix = 0 + num_block = 0 + intervals = [] + + for idx, w in enumerate(weights): + if weights[idx] - prefix > value: + intervals.append((prev, idx)) + prev = idx + prefix = weights[idx - 1] + num_block += 1 + + intervals.append((prev, len(weights))) + return num_block + 1, intervals + + +def _binary_search(weights, num): + length = len(weights) + prefix = [1 if w == 0 else w for w in weights] + for i in range(1, length): + prefix[i] += prefix[i - 1] + + lower_bound = max(weights) + upper_bound = prefix[length - 1] + + while upper_bound > lower_bound: + mid = (upper_bound + lower_bound) // 2 + number, _ = _calc_partitions(prefix, mid) + if number <= num: + upper_bound = mid + else: + lower_bound = mid + 1 + + num_block, intervals = _calc_partitions(prefix, upper_bound) + if num_block < num: + intervals = _heap_addition(prefix, intervals, num - num_block) + + return intervals + + +def partition_uniform(num_items, pipeline_parallel_size, num_chunks): + assert ( + num_items % num_chunks == 0 + ), "Layer length should be divided by the number of chunks, otherwise parameter method is recommended" + + logger = get_dist_logger() + parts = [[] for _ in range(pipeline_parallel_size)] + partition_items = num_items // num_chunks + for idx in range(num_chunks): + base_idx = idx * partition_items + chunk_size = partition_items // pipeline_parallel_size + left = pipeline_parallel_size - partition_items % pipeline_parallel_size + if chunk_size == 0: + logger.warning("Some nodes in Pipeline have no requests") + + for p in range(pipeline_parallel_size): + st = base_idx + base_idx += chunk_size + (p >= left) + parts[p].append((st, base_idx)) + + return parts + + +def partition_balanced(weights, pipeline_parallel_size, num_chunks): + num_total = pipeline_parallel_size * num_chunks + num_items = len(weights) + if num_items <= num_total: + return partition_uniform(num_items, pipeline_parallel_size, num_chunks) + + intervals = _binary_search(weights, num_total) + + current = 0 + parts = [[] for _ in range(pipeline_parallel_size)] + for inter in intervals: + parts[current].append(inter) + current = (current + 1) % pipeline_parallel_size + + return parts + + +def build_kwargs_for_module(function, input_tensor, kw_dict): + """ + Generally, the first argument of module.forward is an input tensor come from the previous layer. + Therefore, we just filter the kwargs from second element of the dictionary. + """ + sig = inspect.signature(function) + if input_tensor is None: + kwargs_offset = 0 + elif isinstance(input_tensor, torch.Tensor): + kwargs_offset = 1 + elif isinstance(input_tensor, (tuple, OrderedDict)): + # assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' + # Huggingface will take their own structures based on OrderedDict as the output + # between layers so we've to close this check. + kwargs_offset = len(input_tensor) + args_name_list = list(sig.parameters.keys()) + kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]} + if len(kw_dict) == 0: + return None + return kw_dict + + +def build_kwargs_for_function(function, kw_dict): + sig = inspect.signature(function) + kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters} + if len(kw_dict) == 0: + return None + return kw_dict + + +def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs): + """ + We suppose the callable object passed to to_layer_list method in two purpose: + a. use the callable object to modify input tensor, such as \ + lambda x: torch.flatten(x, 1) + b. use the callable object to modify kwargs value, such as \ + def foo(attention_mask=None): + if attention_mask is not None: + batch_size = input_ids.shape[0] + attention_mask = attention_mask.view(batch_size, -1) + return attention_mask + """ + + if kw_dict is not None: + rst = func(**kw_dict) + if isinstance(rst, tuple): + for i, k in enumerate(kw_dict.keys()): + kwargs[k] = rst[i] + else: + for k in kw_dict.keys(): + kwargs[k] = rst + return input_tensor + if isinstance(input_tensor, tuple): + assert len(input_tensor) > 0, f"input_tensor should not be empty, when kw_dict is None." + sig = inspect.signature(func) + func_args_num = len(sig.parameters) + assert func_args_num <= len( + input_tensor + ), f"func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}." + if func_args_num < len(input_tensor): + return func(*input_tensor[:func_args_num]) + else: + return func(*input_tensor) + assert isinstance(input_tensor, torch.Tensor), "input_tensor should be a type of torch.Tensor or tuple." + return func(input_tensor) + + +def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs): + assert func_key in func_dict, f"{func_key} is not in the function_dict." + funcs_to_exec = func_dict[func_key] + if isinstance(funcs_to_exec, list): + for f in funcs_to_exec: + f_kwargs = build_kwargs_for_function(f, kwargs) + input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs) + else: + f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs) + input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs) + + return input_tensor + + +def call_module(module, args=None, kwargs=None): + if args is None: + args = () + if kwargs is None: + kwargs = {} + if isinstance(module, CheckpointModule): + forward_func = module._forward + else: + forward_func = module.forward + sig = inspect.signature(forward_func) + param_nums = len(sig.parameters) + len(args) + len(kwargs) + args_needed_nums = param_nums - len(kwargs) + args_needed = args[:args_needed_nums] + if isinstance(module, CheckpointModule): + convert_kwargs_to_args = [] + for v in kwargs.values(): + convert_kwargs_to_args.append(v) + return module(*args_needed, *convert_kwargs_to_args) + else: + return module(*args_needed, **kwargs) + + +def customized_partition(exec_seq): + """ + This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an + annotation to note the partition point. + """ + customized_parts = {} + start = 0 + stop = 0 + rank = 0 + for element in exec_seq: + if isinstance(element, str): + if element == "SPLIT_NODE": + customized_parts[rank] = [(start, stop)] + start = stop + rank += 1 + else: + stop += 1 + customized_parts[rank] = [(start, stop)] + return customized_parts diff --git a/colossalai/registry/__init__.py b/colossalai/legacy/registry/__init__.py similarity index 100% rename from colossalai/registry/__init__.py rename to colossalai/legacy/registry/__init__.py diff --git a/colossalai/registry/registry.py b/colossalai/legacy/registry/registry.py similarity index 95% rename from colossalai/registry/registry.py rename to colossalai/legacy/registry/registry.py index 8a4173f7ab992079d180322245d25cbb9010b07c..43644f8a9e73dc402ab56b1d39db3d22faa1a3bf 100644 --- a/colossalai/registry/registry.py +++ b/colossalai/legacy/registry/registry.py @@ -6,7 +6,7 @@ from typing import List class Registry: - """This is a registry class used to register classes and modules so that a universal + """This is a registry class used to register classes and modules so that a universal object builder can be enabled. Args: @@ -42,7 +42,7 @@ class Registry: return module_class def get_module(self, module_name: str): - """Retrieves a module with name `module_name` and returns the module if it has + """Retrieves a module with name `module_name` and returns the module if it has already been registered before. Args: @@ -59,7 +59,7 @@ class Registry: for lib in self._third_party_lib: if hasattr(lib, module_name): return getattr(lib, module_name) - raise NameError(f'Module {module_name} not found in the registry {self.name}') + raise NameError(f"Module {module_name} not found in the registry {self.name}") def has(self, module_name: str): """Searches for a module with name `module_name` and returns a boolean value indicating diff --git a/colossalai/legacy/tensor/__init__.py b/colossalai/legacy/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a34870eba068b554ebe743cd324b8237e44df0d4 --- /dev/null +++ b/colossalai/legacy/tensor/__init__.py @@ -0,0 +1,17 @@ +from . import distspec +from .compute_spec import ComputePattern, ComputeSpec +from .dist_spec_mgr import DistSpecManager +from .distspec import ReplicaSpec, ShardSpec +from .process_group import ProcessGroup +from .tensor_spec import ColoTensorSpec + +__all__ = [ + "ComputePattern", + "ComputeSpec", + "distspec", + "DistSpecManager", + "ProcessGroup", + "ColoTensorSpec", + "ShardSpec", + "ReplicaSpec", +] diff --git a/colossalai/tensor/compute_spec.py b/colossalai/legacy/tensor/compute_spec.py similarity index 86% rename from colossalai/tensor/compute_spec.py rename to colossalai/legacy/tensor/compute_spec.py index 12f8f36bc61318910edfa3a0e5ece9cd81b6aafe..820aafab687fe09142eed8ae92272650f6d24023 100644 --- a/colossalai/tensor/compute_spec.py +++ b/colossalai/legacy/tensor/compute_spec.py @@ -23,7 +23,7 @@ class ComputeSpec(object): self.output_replicate = True def __repr__(self): - return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})' + return f"ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})" def set_output_replicate(self, flag: bool = True): self.output_replicate = flag diff --git a/colossalai/legacy/tensor/const.py b/colossalai/legacy/tensor/const.py new file mode 100644 index 0000000000000000000000000000000000000000..cbc2b29d66a8c1e01d9d928e74b333cbfd43cac0 --- /dev/null +++ b/colossalai/legacy/tensor/const.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class TensorType(Enum): + MODEL = 0 + NONMODEL = 1 # mainly activations diff --git a/colossalai/legacy/tensor/dist_spec_mgr.py b/colossalai/legacy/tensor/dist_spec_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..3942b5b7a33c4a2334bf56ef6bbe9b8071b1f98d --- /dev/null +++ b/colossalai/legacy/tensor/dist_spec_mgr.py @@ -0,0 +1,206 @@ +from contextlib import contextmanager + +import torch +import torch.distributed as dist +from numpy import prod + +from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec +from colossalai.legacy.tensor.process_group import ProcessGroup + + +# TODO(jiaruifang) circle import, move the divide to colossalai.commons. +# colossalai.legacy.tensor shall not import any submodule from colossal.nn +def divide(numerator, denominator): + """Only allow exact division. + + Args: + numerator (int): Numerator of the division. + denominator (int): Denominator of the division. + + Returns: + int: the result of exact division. + """ + assert denominator != 0, "denominator can not be zero" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) + return numerator // denominator + + +class TransformDistSpec(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func): + ctx.old_dist_spec = old_dist_spec + ctx.dist_spec = dist_spec + ctx.backward_trans_func = backward_trans_func + ctx.pg = pg + return forward_trans_func(tensor, old_dist_spec, dist_spec, pg) + + @staticmethod + def backward(ctx, grad_outputs): + return ( + ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, ctx.pg), + None, + None, + None, + None, + None, + ) + + +class DistSpecManager: + _use_autograd_function: bool = True + + @staticmethod + def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None: + pass + + @staticmethod + def _shard_as( + tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup + ) -> torch.Tensor: + """_shard_as: shard the tensor w.r.t a distributed specification. + Assuming the tensor passed in is a global (replicated) tensor. + Args: + tensor (torch.Tensor): a global (replicated) tensor before shard + dist_spec (_DistSpec): the distributed spec. to be sharded as. + pg (ProcessGroup): the process group of the corresponding colotensor + Returns: + torch.Tensor: a torch tensor after sharded. + """ + assert ( + old_dist_spec.placement.value == "r" + ), f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!" + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + + chunk = tensor + idx = pg.tp_local_rank() + num_parts = prod(dist_spec.num_partitions) + for i, dim in enumerate(dist_spec.dims): + num_parts //= dist_spec.num_partitions[i] + + chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i]) + chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size) + idx %= num_parts + return chunk.clone().detach().contiguous() + + @staticmethod + def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + """_gather gather sharded tensors to a replicated one. + Args: + tensor (torch.Tensor): a shared torch tensor + old_dist_spec (_DistSpec): the distributed spec. of the tensor. + + Returns: + torch.Tensor: a replicated tensor. + """ + assert old_dist_spec.placement.value == "s", f"The old_dist_spec of DistSpecManager._gather must be SHARD!" + is_cpu_tensor = False + if tensor.device.type == "cpu": + # pytorch lower than 1.11 dose not support gather a cpu tensor. + # Therefore, we transfer tensor to GPU before gather. + saved_dev = tensor.device + tensor.data = tensor.data.cuda() + is_cpu_tensor = True + + buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())] + assert tensor.device.type == "cuda" + dist.all_gather(buffer, tensor, group=pg.tp_process_group()) + for i in range(len(old_dist_spec.dims) - 1, -1, -1): + new_buffer = [] + dim = old_dist_spec.dims[i] + num_parts = old_dist_spec.num_partitions[i] + for start in range(0, len(buffer), num_parts): + new_buffer.append(torch.cat(buffer[start : start + num_parts], dim)) + buffer = new_buffer + assert len(buffer) == 1 + + if is_cpu_tensor: + buffer[0].data = buffer[0].data.to(saved_dev) + return buffer[0] + + @staticmethod + def _all_to_all( + tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup + ) -> torch.Tensor: + world_size = pg.tp_world_size() + if world_size == 1: + return tensor + + assert tensor.device.type == "cuda", ( + "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " + f"collective function, however, we got {tensor.device.type} device" + ) + + gather_dim = old_dist_spec.dims[0] + scatter_dim = dist_spec.dims[0] + shapes = list(tensor.shape) + scattered_dim_size = shapes[scatter_dim] // world_size + gathered_dim_size = shapes[gather_dim] * world_size + shapes[scatter_dim] = scattered_dim_size + + scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)] + gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) + + output_ = torch.cat(gather_list, dim=gather_dim).contiguous() + assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size + return output_ + + @staticmethod + def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + return tensor + + @staticmethod + def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg) + + @staticmethod + def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + return DistSpecManager._gather(tensor, old_dist_spec, pg) + + @staticmethod + def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: + DistSpecManager._sanity_check(old_dist_spec, dist_spec) + if old_dist_spec == dist_spec: + return tensor + if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1: + # use all-to-all to save memory + return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec, pg) + tensor = DistSpecManager._gather(tensor, old_dist_spec, pg) + return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg) + + @staticmethod + def handle_trans_spec( + tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup + ) -> torch.Tensor: + assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec" + assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec" + + trans_func_key = (old_dist_spec.placement, dist_spec.placement) + trans_funcs = { + (DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r, + (DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s, + (DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r, + (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s, + } + + forward_trans_handle = trans_funcs[trans_func_key] + if not DistSpecManager._use_autograd_function: + return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg) + + backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)] + + return TransformDistSpec.apply( + tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, backward_trans_handle + ) + + @staticmethod + @contextmanager + def no_grad(): + try: + DistSpecManager._use_autograd_function = False + yield + finally: + DistSpecManager._use_autograd_function = True diff --git a/colossalai/tensor/distspec.py b/colossalai/legacy/tensor/distspec.py similarity index 90% rename from colossalai/tensor/distspec.py rename to colossalai/legacy/tensor/distspec.py index 3a09f1426e3140f1b2857e0c02409b16e0fba041..efef9904ec10cfc9593aa3694cbef106267a67f2 100644 --- a/colossalai/tensor/distspec.py +++ b/colossalai/legacy/tensor/distspec.py @@ -1,12 +1,12 @@ from enum import Enum from typing import List -__all__ = ['ReplicaSpec', 'ShardSpec'] +__all__ = ["ReplicaSpec", "ShardSpec"] class DistPlacementPattern(Enum): - REPLICATE = 'r' - SHARD = 's' + REPLICATE = "r" + SHARD = "s" class _DistSpec: @@ -25,7 +25,6 @@ class _DistSpec: """ def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info): - self.placement = dist_placement_pattern for k, v in meta_info.items(): setattr(self, k, v) @@ -34,15 +33,15 @@ class _DistSpec: if dir(self) != dir(other): return False for attr in dir(self): - if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr): + if not attr.startswith("__") and getattr(self, attr) != getattr(other, attr): return False return True def __repr__(self) -> str: attr_list = [] for attr in dir(self): - if not attr.startswith('__'): - attr_list.append(f'{attr}={str(getattr(self, attr))}') + if not attr.startswith("__"): + attr_list.append(f"{attr}={str(getattr(self, attr))}") attr_str = ", ".join(attr_list) return "DistSpec(" + attr_str + ")" diff --git a/colossalai/tensor/op_wrapper.py b/colossalai/legacy/tensor/op_wrapper.py similarity index 97% rename from colossalai/tensor/op_wrapper.py rename to colossalai/legacy/tensor/op_wrapper.py index 1c00066f74655ef997a4e53e4fc1ce97d33f6434..63ebaa264279a71670f187df2b9adb542b3ef69a 100644 --- a/colossalai/tensor/op_wrapper.py +++ b/colossalai/legacy/tensor/op_wrapper.py @@ -1,8 +1,5 @@ -from typing import ( - Callable, - Dict, -) import functools +from typing import Callable, Dict # Custom sharded ops _COLOSSAL_OPS: Dict[str, Callable] = {} diff --git a/colossalai/tensor/process_group.py b/colossalai/legacy/tensor/process_group.py similarity index 85% rename from colossalai/tensor/process_group.py rename to colossalai/legacy/tensor/process_group.py index f108bdc247f5d84e0aa78240b52c5b9b1f17ed06..ec6043163336252ae2d7f51fe8ccfb7467be47aa 100644 --- a/colossalai/tensor/process_group.py +++ b/colossalai/legacy/tensor/process_group.py @@ -7,13 +7,12 @@ from colossalai.logging import get_dist_logger class PyTorchProcessGroupDict(metaclass=SingletonMeta): - def __init__(self): # distributed settings # use this dict to record all Pytorch ProcessGroups self.dict = {} # set a distributed logger - self.logger = get_dist_logger('ProcessGroup') + self.logger = get_dist_logger("ProcessGroup") def log_pg_init(self, rank_list: List[int], backend: str): str_list = ["Pytorch ProcessGroup Init:"] @@ -21,9 +20,8 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta): str_list.append(f"ranks: {rank_list}") self.logger.info("\n\t".join(str_list), ranks=[0]) - def get(self, rank_list: List[int], backend: str = 'nccl'): - """Reuse Pytorch ProcessGroup when such a group is initialized - """ + def get(self, rank_list: List[int], backend: str = "nccl"): + """Reuse Pytorch ProcessGroup when such a group is initialized""" # we need to convert the passed list to a tuple # since List is unhashable processgroup_key = (backend, tuple(rank_list)) @@ -51,11 +49,13 @@ class ProcessGroup: dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks). """ - def __init__(self, - rank: Optional[int] = None, - ranks: Optional[List[int]] = None, - tp_degree: Optional[int] = None, - dp_degree: Optional[int] = None) -> None: + def __init__( + self, + rank: Optional[int] = None, + ranks: Optional[List[int]] = None, + tp_degree: Optional[int] = None, + dp_degree: Optional[int] = None, + ) -> None: if not torch.distributed.is_initialized(): self.is_init = False return @@ -64,13 +64,13 @@ class ProcessGroup: self._rank = torch.distributed.get_rank() if rank is not None: - assert self._rank == rank # make sure that the global rank is correct + assert self._rank == rank # make sure that the global rank is correct if ranks is None: self._rank_list = list(range(torch.distributed.get_world_size())) else: self._rank_list = ranks - self._rank_list.sort() # ensure that the list is in order + self._rank_list.sort() # ensure that the list is in order self._world_size = len(self._rank_list) @@ -79,31 +79,36 @@ class ProcessGroup: self._tp_degree = 1 elif dp_degree and not tp_degree: self._dp_degree = dp_degree - assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None" + assert ( + self._world_size % self._dp_degree == 0 + ), f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None" self._tp_degree = self._world_size // dp_degree elif not dp_degree and tp_degree: self._tp_degree = tp_degree - assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None" + assert ( + self._world_size % self._tp_degree == 0 + ), f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None" self._dp_degree = self._world_size // tp_degree else: self._dp_degree = dp_degree self._tp_degree = tp_degree - assert self._dp_degree * self._tp_degree == self._world_size, \ - f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \ + assert self._dp_degree * self._tp_degree == self._world_size, ( + f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" f"and TP degree {self._tp_degree}" + ) self._tp_rank_list = None self._dp_rank_list = None for i in range(self._dp_degree): i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)] - PYTORCHPGDICT_.get(i_tp_list, 'nccl') + PYTORCHPGDICT_.get(i_tp_list, "nccl") if self._rank in i_tp_list: self._tp_rank_list = i_tp_list for j in range(self._tp_degree): j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)] - PYTORCHPGDICT_.get(j_dp_list, 'nccl') + PYTORCHPGDICT_.get(j_dp_list, "nccl") if self._rank in j_dp_list: self._dp_rank_list = j_dp_list @@ -119,18 +124,18 @@ class ProcessGroup: for i in range(self._dp_degree): i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)] - PYTORCHPGDICT_.get(i_tp_list, 'gloo') + PYTORCHPGDICT_.get(i_tp_list, "gloo") for j in range(self._tp_degree): j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)] - PYTORCHPGDICT_.get(j_dp_list, 'gloo') + PYTORCHPGDICT_.get(j_dp_list, "gloo") self._has_cpu_groups = True @property def has_cpu_groups(self) -> bool: """has_cpu_groups - If cpu groups have been initailized. + If cpu groups have been initialized. Returns: bool: cpu process groups have been initialized or not. @@ -145,7 +150,7 @@ class ProcessGroup: else: return "ProcessGroup not initialized" - def __eq__(self, obj: 'ProcessGroup') -> bool: + def __eq__(self, obj: "ProcessGroup") -> bool: if not isinstance(obj, ProcessGroup): return False if self._rank != obj._rank: @@ -260,7 +265,7 @@ class ProcessGroup: Returns: `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group. """ - return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl') + return PYTORCHPGDICT_.get(self._dp_rank_list, "nccl") def tp_process_group(self): """tp_process_group @@ -270,7 +275,7 @@ class ProcessGroup: Returns: `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group. """ - return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl') + return PYTORCHPGDICT_.get(self._tp_rank_list, "nccl") def cpu_dp_process_group(self): """cpu_dp_process_group @@ -283,7 +288,7 @@ class ProcessGroup: `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group. """ assert self._has_cpu_groups - return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo') + return PYTORCHPGDICT_.get(self._dp_rank_list, "gloo") def cpu_tp_process_group(self): """cpu_tp_process_group @@ -296,7 +301,7 @@ class ProcessGroup: `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group. """ assert self._has_cpu_groups - return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo') + return PYTORCHPGDICT_.get(self._tp_rank_list, "gloo") def get_ranks_in_dp(self) -> List[int]: """get_ranks_in_dp diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/legacy/tensor/tensor_spec.py similarity index 76% rename from colossalai/tensor/tensor_spec.py rename to colossalai/legacy/tensor/tensor_spec.py index 580df9f8f31023f0623cadc08fe67d49b309819f..5bdd384e5e15b8db00e697178cb6551f8786e594 100644 --- a/colossalai/tensor/tensor_spec.py +++ b/colossalai/legacy/tensor/tensor_spec.py @@ -1,20 +1,21 @@ from dataclasses import dataclass from typing import Optional -from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec -from colossalai.tensor.process_group import ProcessGroup +from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec +from colossalai.legacy.tensor.process_group import ProcessGroup from .compute_spec import ComputeSpec @dataclass class ColoTensorSpec: - """ ColoTensorSpec + """ColoTensorSpec A data class for specifications of the `ColoTensor`. It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`. The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`. """ + pg: ProcessGroup dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE) compute_attr: Optional[ComputeSpec] = None diff --git a/colossalai/legacy/trainer/__init__.py b/colossalai/legacy/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4fddc7c1c9f5963b800cdb03566524854128e72 --- /dev/null +++ b/colossalai/legacy/trainer/__init__.py @@ -0,0 +1,3 @@ +from ._trainer import Trainer + +__all__ = ["Trainer"] diff --git a/colossalai/trainer/_trainer.py b/colossalai/legacy/trainer/_trainer.py similarity index 96% rename from colossalai/trainer/_trainer.py rename to colossalai/legacy/trainer/_trainer.py index 60bbc4eeee32adbd46472106fff4a07c653f7cb6..46e708622237effe57c7b7e2999f514cf393164b 100644 --- a/colossalai/trainer/_trainer.py +++ b/colossalai/legacy/trainer/_trainer.py @@ -1,14 +1,14 @@ -from typing import Union, List, Any +from typing import Any, List, Union import torch from torch.utils.data import DataLoader from tqdm import tqdm -from colossalai.engine import Engine +from colossalai.legacy.engine import Engine +from colossalai.legacy.trainer.hooks import BaseHook +from colossalai.legacy.utils import is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0 from colossalai.logging import DistributedLogger from colossalai.utils import MultiTimer -from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage -from colossalai.trainer.hooks import BaseHook class Trainer: @@ -31,9 +31,9 @@ class Trainer: >>> # Initialize your engine, train_dataloader, test_dataloader, lr_scheduler >>> engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion) >>> # Beginning training progress - >>> timier = ... + >>> timer = ... >>> logger = ... - >>> trainer = Trainer(engine=engine, logger=logger, timer=timier) + >>> trainer = Trainer(engine=engine, logger=logger, timer=timer) >>> # add hooks you would like to use here. >>> hook_list = [] >>> trainer.fit( @@ -56,7 +56,7 @@ class Trainer: timer: MultiTimer = None, logger: DistributedLogger = None, ): - # training-ralated params + # training-related params self._engine = engine self._max_epochs = 0 self._cur_epoch = 0 @@ -118,7 +118,7 @@ class Trainer: self._cur_step = epoch * self._steps_per_epoch def _call_timer(self, action: str, item: str, *args, **kwargs) -> None: - """Call timer funciton with a given timer name. + """Call timer function with a given timer name. Args: action (str): Function to be called on timer. @@ -151,7 +151,7 @@ class Trainer: @staticmethod def _should_display_progress(display_progress: bool): """Only display progress on DP rank 0, TP rank 0 and PP last rank""" - return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()) + return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() def _train_epoch( self, @@ -293,8 +293,7 @@ class Trainer: assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}" for hook in hooks: - assert isinstance(hook, BaseHook), \ - f'expected the hook to be of type BaseHook, but got {type(hook)}' + assert isinstance(hook, BaseHook), f"expected the hook to be of type BaseHook, but got {type(hook)}" else: hooks = [] self.hooks = hooks diff --git a/colossalai/legacy/trainer/hooks/__init__.py b/colossalai/legacy/trainer/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..290aeb64a04dc1350f40f31323a6193bf2bb3ef6 --- /dev/null +++ b/colossalai/legacy/trainer/hooks/__init__.py @@ -0,0 +1,26 @@ +from ._base_hook import BaseHook +from ._checkpoint_hook import SaveCheckpointHook +from ._log_hook import ( + LogMemoryByEpochHook, + LogMetricByEpochHook, + LogMetricByStepHook, + LogTimingByEpochHook, + TensorboardHook, +) +from ._lr_scheduler_hook import LRSchedulerHook +from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook + +__all__ = [ + "BaseHook", + "MetricHook", + "LossHook", + "AccuracyHook", + "LogMetricByEpochHook", + "TensorboardHook", + "LogTimingByEpochHook", + "LogMemoryByEpochHook", + "LRSchedulerHook", + "ThroughputHook", + "LogMetricByStepHook", + "SaveCheckpointHook", +] diff --git a/colossalai/legacy/trainer/hooks/_base_hook.py b/colossalai/legacy/trainer/hooks/_base_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..fc883134203fe9c31b5ba034217c8bfe233b26df --- /dev/null +++ b/colossalai/legacy/trainer/hooks/_base_hook.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC + +from torch import Tensor + + +class BaseHook(ABC): + """This class allows users to add desired actions in specific time points + during training or evaluation. + + :param priority: Priority in the printing, hooks with small priority will be printed in front + :type priority: int + """ + + def __init__(self, priority: int) -> None: + self.priority = priority + + def after_hook_is_attached(self, trainer): + """Actions after hooks are attached to trainer.""" + + def before_train(self, trainer): + """Actions before training.""" + + def after_train(self, trainer): + """Actions after training.""" + + def before_train_iter(self, trainer): + """Actions before running a training iteration.""" + + def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): + """Actions after running a training iteration. + + Args: + trainer (:class:`Trainer`): Trainer which is using this hook. + output (:class:`torch.Tensor`): Output of the model. + label (:class:`torch.Tensor`): Labels of the input data. + loss (:class:`torch.Tensor`): Loss between the output and input data. + """ + + def before_train_epoch(self, trainer): + """Actions before starting a training epoch.""" + + def after_train_epoch(self, trainer): + """Actions after finishing a training epoch.""" + + def before_test(self, trainer): + """Actions before evaluation.""" + + def after_test(self, trainer): + """Actions after evaluation.""" + + def before_test_epoch(self, trainer): + """Actions before starting a testing epoch.""" + + def after_test_epoch(self, trainer): + """Actions after finishing a testing epoch.""" + + def before_test_iter(self, trainer): + """Actions before running a testing iteration.""" + + def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): + """Actions after running a testing iteration. + + Args: + trainer (:class:`Trainer`): Trainer which is using this hook + output (:class:`torch.Tensor`): Output of the model + label (:class:`torch.Tensor`): Labels of the input data + loss (:class:`torch.Tensor`): Loss between the output and input data + """ + + def init_runner_states(self, trainer, key, val): + """Initializes trainer's state. + + Args: + trainer (:class:`Trainer`): Trainer which is using this hook + key: Key of state to be reset + val: Value of state to be reset + """ + if key not in trainer.states: + trainer.states[key] = val diff --git a/colossalai/legacy/trainer/hooks/_checkpoint_hook.py b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..50c80759867e7a3d8f5f4866c3fb75748677379d --- /dev/null +++ b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import torch + +from colossalai.legacy.registry import HOOKS +from colossalai.legacy.trainer.hooks import BaseHook +from colossalai.legacy.utils.checkpointing import save_checkpoint +from colossalai.logging import get_dist_logger + +from ._lr_scheduler_hook import LRSchedulerHook + + +@HOOKS.register_module +class SaveCheckpointHook(BaseHook): + """Saves the model by interval in training process. + + Args: + interval (int, optional): Number of epochs between saving the checkpoint, defaults to 1. + if save_by_iter is True, this arg refers to the number of iters between saving. + checkpoint_dir (str, optional): File name to save the checkpoint, defaults to None. + model (torch.nn.Module, Optional): The model to save, defaults to None. When not passing, + 'trainer.engine.model' will be used. We encourage you to pass the model in it to avoid some + unexpected bugs, especially when using **DDP**. + save_by_iter (bool, optional): Whether saving the checkpoint by iter, default to False. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__( + self, + interval: int = 1, + checkpoint_dir: str = None, + model: torch.nn.Module = None, + save_by_iter: bool = False, + priority: int = 10, + ): + super().__init__(priority=priority) + self.interval = interval + self.checkpoint_dir = checkpoint_dir + self.model = model + self.save_by_iter = save_by_iter + self.logger = get_dist_logger() + + # get lr scheduler from the LRSchedulerHook before train + self._lr_scheduler = None + + def after_hook_is_attached(self, trainer): + # get lr scheduler if exists + for hook in trainer.hooks: + if isinstance(hook, LRSchedulerHook): + self._lr_scheduler = hook.lr_scheduler + break + self.model = self.model if self.model is not None else trainer.engine.model + + def after_train_iter(self, trainer, output, label, loss): + """Saves the model after a training iter.""" + # save by interval + if self.save_by_iter and trainer.cur_step % self.interval == 0: + save_checkpoint( + self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, self._lr_scheduler + ) + self.logger.info( + f"checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}", ranks=[0] + ) + else: + pass + + def after_train_epoch(self, trainer): + """Saves the model after a training epoch.""" + # save by interval + if trainer.cur_epoch % self.interval == 0: + save_checkpoint( + self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, self._lr_scheduler + ) + self.logger.info(f"checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}", ranks=[0]) diff --git a/colossalai/legacy/trainer/hooks/_commons_.py b/colossalai/legacy/trainer/hooks/_commons_.py new file mode 100644 index 0000000000000000000000000000000000000000..18da38298704d871750187ed6b6a7035272fb639 --- /dev/null +++ b/colossalai/legacy/trainer/hooks/_commons_.py @@ -0,0 +1,9 @@ +import torch + + +def _format_number(val, prec=5): + if isinstance(val, float): + return f"{val:.{prec}g}" + elif torch.is_tensor(val) and torch.is_floating_point(val): + return f"{val.item():.{prec}g}" + return val diff --git a/colossalai/legacy/trainer/hooks/_log_hook.py b/colossalai/legacy/trainer/hooks/_log_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..c1cf0ca5228bd1c54556e97dfa0a88decf23b044 --- /dev/null +++ b/colossalai/legacy/trainer/hooks/_log_hook.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +import os.path as osp +from typing import List + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.registry import HOOKS +from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric +from colossalai.legacy.utils import is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage +from colossalai.logging import DistributedLogger +from colossalai.utils import MultiTimer + +from ._base_hook import BaseHook +from ._commons_ import _format_number + + +class LogByEpochHook(BaseHook): + """Hook to log by epoch. + + Args: + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 1. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, logger, interval: int = 1, priority: int = 1): + super().__init__(priority) + self.logger = logger + self._interval = interval + + def _is_epoch_to_log(self, trainer): + return trainer.cur_epoch % self._interval == 0 + + +@HOOKS.register_module +class LogMetricByStepHook(BaseHook): + """Hook to log metric by step. + + Args: + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, priority: int = 10): + super().__init__(priority) + + def after_train_iter(self, trainer, *args): + trainer.states["step_metrics"] = dict() + for metric_name, metric_calculator in trainer.states["metrics"]["train"].items(): + if isinstance(metric_calculator, ThroughputMetric): + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_info() + else: + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_value() + + def after_test_iter(self, trainer, *args): + trainer.states["step_metrics"] = dict() + for metric_name, metric_calculator in trainer.states["metrics"]["test"].items(): + if isinstance(metric_calculator, ThroughputMetric): + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_info() + else: + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_value() + + +@HOOKS.register_module +class LogMetricByEpochHook(LogByEpochHook): + """Specialized hook to record the metric to log. + + Args: + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__(self, logger, interval: int = 1, priority: int = 10) -> None: + super().__init__(logger, interval, priority) + self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() + + def _get_str(self, trainer, mode): + msg = [] + for metric_name, metric_calculator in trainer.states["metrics"][mode].items(): + msg.append(f"{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}") + msg = " | ".join(msg) + return msg + + def after_train_epoch(self, trainer): + if self._is_epoch_to_log(trainer): + msg = self._get_str(trainer=trainer, mode="train") + + if self._is_rank_to_log: + self.logger.info(f"[Epoch {trainer.cur_epoch} / Train]: {msg}") + # f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') + + def after_test_epoch(self, trainer): + if self._is_epoch_to_log(trainer): + msg = self._get_str(trainer=trainer, mode="test") + if self._is_rank_to_log: + self.logger.info(f"[Epoch {trainer.cur_epoch} / Test]: {msg}") + # f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') + + +@HOOKS.register_module +class TensorboardHook(BaseHook): + """Specialized hook to record the metric to Tensorboard. + + Args: + log_dir (str): Directory of log. + ranks (list): Ranks of processors. + parallel_mode (:class:`colossalai.legacy.context.parallel_mode.ParallelMode`, optional): Parallel mode used in trainer, + defaults to colossalai.legacy.context.parallel_mode.ParallelMode.GLOBAL. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + """ + + def __init__( + self, + log_dir: str, + ranks: List = None, + parallel_mode: ParallelMode = ParallelMode.GLOBAL, + priority: int = 10, + ) -> None: + super().__init__(priority=priority) + from torch.utils.tensorboard import SummaryWriter + + # create log dir + if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: + os.makedirs(log_dir, exist_ok=True) + + # determine the ranks to generate tensorboard logs + self._is_valid_rank_to_log = False + if not gpc.is_initialized(parallel_mode): + self._is_valid_rank_to_log = True + else: + local_rank = gpc.get_local_rank(parallel_mode) + + if ranks is None or local_rank in ranks: + self._is_valid_rank_to_log = True + + # check for + if ( + gpc.is_initialized(ParallelMode.PIPELINE) + and not gpc.is_last_rank(ParallelMode.PIPELINE) + and self._is_valid_rank_to_log + ): + raise ValueError("Tensorboard hook can only log on the last rank of pipeline process group") + + if self._is_valid_rank_to_log: + # create workspace on only one rank + if gpc.is_initialized(parallel_mode): + rank = gpc.get_local_rank(parallel_mode) + else: + rank = 0 + + # create workspace + log_dir = osp.join(log_dir, f"{parallel_mode}_rank_{rank}") + os.makedirs(log_dir, exist_ok=True) + + self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f"_rank_{rank}") + + def _log_by_iter(self, trainer, mode: str): + for metric_name, metric_calculator in trainer.states["metrics"][mode].items(): + if metric_calculator.epoch_only: + continue + val = metric_calculator.get_last_step_value() + + if self._is_valid_rank_to_log: + self.writer.add_scalar(f"{metric_name}/{mode}", val, trainer.cur_step) + + def _log_by_epoch(self, trainer, mode: str): + for metric_name, metric_calculator in trainer.states["metrics"][mode].items(): + if metric_calculator.epoch_only: + val = metric_calculator.get_accumulated_value() + if self._is_valid_rank_to_log: + self.writer.add_scalar(f"{metric_name}/{mode}", val, trainer.cur_step) + + def after_test_iter(self, trainer, *args): + self._log_by_iter(trainer, mode="test") + + def after_test_epoch(self, trainer): + self._log_by_epoch(trainer, mode="test") + + def after_train_iter(self, trainer, *args): + self._log_by_iter(trainer, mode="train") + + def after_train_epoch(self, trainer): + self._log_by_epoch(trainer, mode="train") + + +@HOOKS.register_module +class LogTimingByEpochHook(LogByEpochHook): + """Specialized hook to write timing record to log. + + Args: + timer (:class:`colossalai.utils.MultiTimer`): Timer for the hook. + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 10. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + log_eval (bool, optional): Whether writes in evaluation, defaults to True. + ignore_num_train_steps (int, optional): Number of training steps to ignore, defaults to 0. + """ + + def __init__( + self, + timer: MultiTimer, + logger: DistributedLogger, + interval: int = 1, + priority: int = 10, + log_eval: bool = True, + ignore_num_train_steps: int = 0, + ) -> None: + super().__init__(logger=logger, interval=interval, priority=priority) + self._timer = timer + self._log_eval = log_eval + self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() + + # extra handling to avoid the unstable readings of the first + # few training steps to affect the history mean time + self._ignore_num_train_steps = ignore_num_train_steps + self._is_train_step_history_trimmed = False + + def _get_message(self, mode): + msg = [] + for timer_name, timer in self._timer: + if timer_name.startswith(mode): + last_elapsed_time = timer.get_elapsed_time() + if timer.has_history: + if timer_name == "Train-step" and not self._is_train_step_history_trimmed: + timer._history = timer._history[self._ignore_num_train_steps :] + self._is_train_step_history_trimmed = True + history_mean = timer.get_history_mean() + timer.get_history_sum() + msg.append( + f"{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s" + ) + else: + msg.append(f"{timer_name}: last = {_format_number(last_elapsed_time)} s") + + msg = " | ".join(msg) + return msg + + def after_train_epoch(self, trainer): + """Writes log after finishing a training epoch.""" + if self._is_epoch_to_log(trainer) and self._is_rank_to_log: + msg = self._get_message("Train") + self.logger.info(f"[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}") + + def after_test_epoch(self, trainer): + """Writes log after finishing a testing epoch.""" + if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: + msg = self._get_message("Test") + self.logger.info(f"[Epoch {trainer.cur_epoch} / Test]: {msg}") + + +@HOOKS.register_module +class LogMemoryByEpochHook(LogByEpochHook): + """Specialized Hook to write memory usage record to log. + + Args: + logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. + interval (int, optional): Interval of printing log information, defaults to 1. + priority (int, optional): Priority in the printing, hooks with small priority will be printed in front + defaults to 1. If different hooks share same priority, the order of printing would + depend on the hooks order in the hook list. + log_eval (bool, optional): Whether writes in evaluation, defaults to True. + """ + + def __init__( + self, + logger: DistributedLogger, + interval: int = 1, + priority: int = 10, + log_eval: bool = True, + report_cpu: bool = False, # no reference + ) -> None: + super().__init__(logger=logger, interval=interval, priority=priority) + self._log_eval = log_eval + self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() + + def before_train(self, trainer): + """Resets before training.""" + if self._is_epoch_to_log(trainer) and self._is_rank_to_log: + report_memory_usage("Before-train", self.logger) + + def after_train_epoch(self, trainer): + """Writes log after finishing a training epoch.""" + if self._is_epoch_to_log(trainer) and self._is_rank_to_log: + report_memory_usage(f"[Epoch {trainer.cur_epoch} / Train]", self.logger) + + def after_test(self, trainer): + """Reports after testing.""" + if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: + report_memory_usage(f"[Epoch {trainer.cur_epoch} / Test]", self.logger) diff --git a/colossalai/trainer/hooks/_lr_scheduler_hook.py b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py similarity index 82% rename from colossalai/trainer/hooks/_lr_scheduler_hook.py rename to colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py index c6da33442dc39b78474fb36c50b3a9bbfc790666..d14db563473ca07e87f37731b234424ba7b6cc4b 100644 --- a/colossalai/trainer/hooks/_lr_scheduler_hook.py +++ b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py @@ -1,6 +1,7 @@ -from colossalai.registry import HOOKS from torch import Tensor +from colossalai.legacy.registry import HOOKS + from ._metric_hook import LearningRateMetric, MetricHook @@ -33,15 +34,16 @@ class LRSchedulerHook(MetricHook): def after_hook_is_attached(self, trainer): self._check_metric_states_initialization(trainer) - trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch, - initial_lr=self.lr_scheduler.get_last_lr()[0]) + trainer.states["metrics"]["train"]["LR"] = LearningRateMetric( + epoch_only=self.by_epoch, initial_lr=self.lr_scheduler.get_last_lr()[0] + ) def after_train_epoch(self, trainer): if self.by_epoch: self.lr_scheduler.step() - trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0]) + trainer.states["metrics"]["train"]["LR"].update(self.lr_scheduler.get_last_lr()[0]) def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): if not self.by_epoch: self.lr_scheduler.step() - trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0]) + trainer.states["metrics"]["train"]["LR"].update(self.lr_scheduler.get_last_lr()[0]) diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py similarity index 87% rename from colossalai/trainer/hooks/_metric_hook.py rename to colossalai/legacy/trainer/hooks/_metric_hook.py index 526d6c746ec6511c97b283ac1074340daabb2516..35a7f0a156abfac64cf8f371ee5dfe6934af0fbf 100644 --- a/colossalai/trainer/hooks/_metric_hook.py +++ b/colossalai/legacy/trainer/hooks/_metric_hook.py @@ -6,11 +6,13 @@ from typing import Callable import torch import torch.distributed as dist -from colossalai.communication import all_reduce -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.registry import HOOKS -from colossalai.utils import get_current_device, is_no_pp_or_last_stage + +from colossalai.legacy.communication import all_reduce +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.registry import HOOKS +from colossalai.legacy.utils import is_no_pp_or_last_stage +from colossalai.utils import get_current_device from ._base_hook import BaseHook from ._commons_ import _format_number @@ -19,8 +21,8 @@ from ._commons_ import _format_number class Metric(ABC): """A basic class of metric collectors. It collects a specific metric during training or evaluation and would always be used with - :class:`MetricHook` to help it update its states and show the - metric. So please use corresponding hook class to make the metric + :class:`MetricHook` to help it update its states and show the + metric. So please use corresponding hook class to make the metric collector works. Args: @@ -33,8 +35,7 @@ class Metric(ABC): @property def epoch_only(self): - """Returns :attr:`epoch_only`. - """ + """Returns :attr:`epoch_only`.""" return self._epoch_only @abstractmethod @@ -42,20 +43,16 @@ class Metric(ABC): """Resets the metric to it's initial state. By default, this is called at the start of each epoch. """ - pass @abstractmethod def update(self, *args, **kwargs) -> None: """Updates the metric's state using the passed batch output. By default, this is called once for each batch. """ - pass @abstractmethod def get_last_step_value(self) -> float: - """Returns the metric value in the last iteration. - """ - pass + """Returns the metric value in the last iteration.""" @abstractmethod def get_accumulated_value(self): @@ -65,7 +62,6 @@ class Metric(ABC): :return: the actual quantity of interest :rtype: Any """ - pass @staticmethod @abstractmethod @@ -75,7 +71,6 @@ class Metric(ABC): :return: The result of comparison :rtype: bool """ - pass class LossMetric(Metric): @@ -92,8 +87,7 @@ class LossMetric(Metric): self.count = 0 def reset(self) -> None: - """Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero. - """ + """Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero.""" self.last_step_loss.zero_() self.accum_loss.zero_() self.count = 0 @@ -112,8 +106,7 @@ class LossMetric(Metric): self.count += 1 def get_accumulated_value(self): - """Returns accumulated loss. - """ + """Returns accumulated loss.""" if gpc.is_initialized(ParallelMode.DATA): dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA)) self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA)) @@ -122,8 +115,7 @@ class LossMetric(Metric): return self.accum_loss.item() def get_last_step_value(self) -> float: - """Returns :attr:`last_step_loss`. - """ + """Returns :attr:`last_step_loss`.""" return self.last_step_loss.cpu().item() @staticmethod @@ -139,7 +131,7 @@ class LearningRateMetric(Metric): initial_lr (float, optional): Initial learning rate, defaults to 0.0. """ - def __init__(self, epoch_only: bool, initial_lr: float = 0.): + def __init__(self, epoch_only: bool, initial_lr: float = 0.0): super().__init__(epoch_only=epoch_only) self.lr = initial_lr @@ -220,9 +212,9 @@ class AccuracyMetric(Metric): class MetricHook(BaseHook): - """Specialized hook classes for :class:`Metric`. - Some help metric collectors initialize, reset and - update their states. Others are used to display and + """Specialized hook classes for :class:`Metric`. + Some help metric collectors initialize, reset and + update their states. Others are used to display and record the metric. Args: @@ -239,8 +231,8 @@ class MetricHook(BaseHook): self._is_stage_to_compute = is_no_pp_or_last_stage() def _check_metric_states_initialization(self, trainer): - if 'metrics' not in trainer.states: - self.init_runner_states(trainer, 'metrics', dict(train={}, test={})) + if "metrics" not in trainer.states: + self.init_runner_states(trainer, "metrics", dict(train={}, test={})) @HOOKS.register_module @@ -264,8 +256,8 @@ class LossHook(MetricHook): self.test_loss = LossMetric(epoch_only=True) # register the metric calculator - trainer.states['metrics']['train']['Loss'] = self.train_loss - trainer.states['metrics']['test']['Loss'] = self.test_loss + trainer.states["metrics"]["train"]["Loss"] = self.train_loss + trainer.states["metrics"]["test"]["Loss"] = self.test_loss def before_train_epoch(self, trainer): if self._is_stage_to_compute: @@ -305,7 +297,7 @@ class AccuracyHook(MetricHook): self.metric = AccuracyMetric(epoch_only=True, accuracy_func=self.accuracy_func) # register the metric - trainer.states['metrics']['test']['Accuracy'] = self.metric + trainer.states["metrics"]["test"]["Accuracy"] = self.metric def before_test(self, trainer): if self._is_stage_to_compute: @@ -354,8 +346,9 @@ class ThroughputMetric(Metric): if self._use_local: self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) else: - self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ - gpc.get_world_size(ParallelMode.DATA) + self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / gpc.get_world_size( + ParallelMode.DATA + ) self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item()) @@ -365,8 +358,9 @@ class ThroughputMetric(Metric): if self._use_local: self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) else: - self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ - gpc.get_world_size(ParallelMode.DATA) + self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / gpc.get_world_size( + ParallelMode.DATA + ) self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item()) @@ -377,8 +371,9 @@ class ThroughputMetric(Metric): return f"{sample_per_sec} sample_per_sec" def get_accumulated_value(self) -> float: - self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \ - gpc.get_world_size(ParallelMode.DATA) + self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / gpc.get_world_size( + ParallelMode.DATA + ) self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA) return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item() @@ -409,14 +404,16 @@ class ThroughputHook(MetricHook): def after_hook_is_attached(self, trainer): self._check_metric_states_initialization(trainer) if self._is_stage_to_compute: - self.metric = ThroughputMetric(epoch_only=True, - ignored_steps=self.ignored_steps, - tflop_per_step=self._tflop_per_step, - use_local=self._use_local) + self.metric = ThroughputMetric( + epoch_only=True, + ignored_steps=self.ignored_steps, + tflop_per_step=self._tflop_per_step, + use_local=self._use_local, + ) # register the metric - trainer.states['metrics']['train']['Throughput'] = self.metric - trainer.states['metrics']['test']['Throughput'] = self.metric + trainer.states["metrics"]["train"]["Throughput"] = self.metric + trainer.states["metrics"]["test"]["Throughput"] = self.metric def before_train_epoch(self, trainer): if self._is_stage_to_compute: @@ -424,8 +421,9 @@ class ThroughputHook(MetricHook): def after_train_iter(self, trainer, *args): if self._is_stage_to_compute: - self.metric.update(trainer.engine.schedule.batch_size, - trainer._timer.get_timer('Train-step').get_elapsed_time()) + self.metric.update( + trainer.engine.schedule.batch_size, trainer._timer.get_timer("Train-step").get_elapsed_time() + ) def before_test(self, trainer): if self._is_stage_to_compute: @@ -433,5 +431,6 @@ class ThroughputHook(MetricHook): def after_test_iter(self, trainer, *args): if self._is_stage_to_compute: - self.metric.update(trainer.engine.schedule.batch_size, - trainer._timer.get_timer('Test-step').get_elapsed_time()) + self.metric.update( + trainer.engine.schedule.batch_size, trainer._timer.get_timer("Test-step").get_elapsed_time() + ) diff --git a/colossalai/legacy/utils/__init__.py b/colossalai/legacy/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86984edeec655dd940869fd65b712bcc605160ff --- /dev/null +++ b/colossalai/legacy/utils/__init__.py @@ -0,0 +1,53 @@ +from .checkpointing import load_checkpoint, save_checkpoint +from .common import ( + clip_grad_norm_fp32, + copy_tensor_parallel_attributes, + count_zeros_fp32, + is_dp_rank_0, + is_model_parallel_parameter, + is_no_pp_or_last_stage, + is_tp_rank_0, + is_using_ddp, + is_using_pp, + is_using_sequence, + param_is_not_tensor_parallel_duplicate, + print_rank_0, + switch_virtual_pipeline_parallel_rank, + sync_model_param, +) +from .data_sampler import DataParallelSampler, get_dataloader +from .memory import ( + colo_device_memory_capacity, + colo_device_memory_used, + colo_get_cpu_memory_capacity, + colo_set_cpu_memory_capacity, + colo_set_process_memory_fraction, + report_memory_usage, +) + +__all__ = [ + "DataParallelSampler", + "get_dataloader", + "save_checkpoint", + "load_checkpoint", + "colo_device_memory_capacity", + "colo_device_memory_used", + "colo_get_cpu_memory_capacity", + "colo_set_cpu_memory_capacity", + "colo_set_process_memory_fraction", + "report_memory_usage", + "clip_grad_norm_fp32", + "copy_tensor_parallel_attributes", + "count_zeros_fp32", + "is_dp_rank_0", + "is_model_parallel_parameter", + "is_no_pp_or_last_stage", + "is_tp_rank_0", + "is_using_ddp", + "is_using_pp", + "is_using_sequence", + "param_is_not_tensor_parallel_duplicate", + "print_rank_0", + "switch_virtual_pipeline_parallel_rank", + "sync_model_param", +] diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py similarity index 86% rename from colossalai/utils/activation_checkpoint.py rename to colossalai/legacy/utils/activation_checkpoint.py index fa9ed827a8a7fa649d5689d15a16050c9bc877b6..387e1c54ec877f9d33b2a5642c2fb8059cbec8cd 100644 --- a/colossalai/utils/activation_checkpoint.py +++ b/colossalai/legacy/utils/activation_checkpoint.py @@ -1,13 +1,13 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import weakref + import torch from torch.utils.checkpoint import check_backward_validity, detach_variable -from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states -from .cuda import get_current_device - -import weakref +from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states +from colossalai.utils import get_current_device def copy_to_device(obj, device): @@ -28,7 +28,6 @@ def copy_to_device(obj, device): class CheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, activation_offload=False, *args): check_backward_validity(args) @@ -42,7 +41,7 @@ class CheckpointFunction(torch.autograd.Function): ctx.fwd_seed_states = get_states(copy=True) ctx.fwd_current_mode = get_current_mode() - if hasattr(torch, 'is_autocast_enabled'): + if hasattr(torch, "is_autocast_enabled"): ctx.had_autocast_in_fwd = torch.is_autocast_enabled() else: ctx.had_autocast_in_fwd = False @@ -62,7 +61,7 @@ class CheckpointFunction(torch.autograd.Function): for i, arg in enumerate(args): if torch.is_tensor(arg): if activation_offload: - tensor_inputs.append(copy_to_device(arg, 'cpu')) + tensor_inputs.append(copy_to_device(arg, "cpu")) else: tensor_inputs.append(arg) ctx.tensor_indices.append(i) @@ -79,8 +78,10 @@ class CheckpointFunction(torch.autograd.Function): @staticmethod def backward(ctx, *args): if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is " - "passed to .backward(). Please use .backward() and do not pass its `inputs` argument.") + raise RuntimeError( + "Checkpointing is not compatible with .grad() or when an `inputs` parameter is " + "passed to .backward(). Please use .backward() and do not pass its `inputs` argument." + ) # Copy the list to avoid modifying original list. inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices @@ -131,8 +132,7 @@ class CheckpointFunction(torch.autograd.Function): outputs_with_grad.append(outputs[i]) args_with_grad.append(args[i]) if len(outputs_with_grad) == 0: - raise RuntimeError("none of output has requires_grad=True," - " this checkpoint() is not necessary") + raise RuntimeError("none of output has requires_grad=True," " this checkpoint() is not necessary") torch.autograd.backward(outputs_with_grad, args_with_grad) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) return (None, None) + grads @@ -143,7 +143,7 @@ def checkpoint(function, activation_offload, *args, use_reentrant: bool = True): Args: function: Describe the forward pass function. It should know how to handle the input tuples. - activation_offload: The variable to check whether we should offload activation to cpu + activation_offload: The variable to check whether we should offload activation to cpu args (list): Tuple containing the parameters of the function use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there might be more flexibility for user to define there checkpoint function @@ -169,7 +169,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args): fwd_current_mode = get_current_mode() # check if use autocast - if hasattr(torch, 'is_autocast_enabled'): + if hasattr(torch, "is_autocast_enabled"): has_autocast_in_fwd = torch.is_autocast_enabled() else: has_autocast_in_fwd = False @@ -179,7 +179,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args): weak_holder_list = [] # class for weakref.ref - class Holder(): + class Holder: pass # return a Holder object for later unpack process @@ -226,19 +226,20 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args): # rerun forward, the inner_pack will store all the activations in storage if has_autocast_in_fwd: - with torch.enable_grad(), \ - torch.cuda.amp.autocast(), \ - torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks( + inner_pack, inner_unpack + ): _unused = function(*args) else: - with torch.enable_grad(), \ - torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): _unused = function(*args) if x not in storage: - raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint" - " recomputation being triggered in between, this is not currently supported. Please" - " open an issue with details on your use case so that we can prioritize adding this.") + raise RuntimeError( + "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint" + " recomputation being triggered in between, this is not currently supported. Please" + " open an issue with details on your use case so that we can prioritize adding this." + ) return storage[x] diff --git a/colossalai/legacy/utils/checkpoint/__init__.py b/colossalai/legacy/utils/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..35ce19ea1c690d7d4b8702af213447f3f41232f8 --- /dev/null +++ b/colossalai/legacy/utils/checkpoint/__init__.py @@ -0,0 +1,3 @@ +from .module_checkpoint import load_checkpoint, save_checkpoint + +__all__ = ["save_checkpoint", "load_checkpoint"] diff --git a/colossalai/legacy/utils/checkpoint/module_checkpoint.py b/colossalai/legacy/utils/checkpoint/module_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..1d691e5c8f976289d491a762f00ae55ce825ff38 --- /dev/null +++ b/colossalai/legacy/utils/checkpoint/module_checkpoint.py @@ -0,0 +1,144 @@ +from typing import Dict, Optional + +import torch +import torch.distributed as dist + +from colossalai.interface import OptimizerWrapper +from colossalai.tensor import ColoTensor + +from .utils import gather_tensor, scatter_tensor + + +def save_checkpoint( + path: str, + epoch: int, + model: torch.nn.Module, + optimizer: Optional[OptimizerWrapper] = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + *args, + **kwargs, +): + """save_checkpoint + save a model, whose parameters are `ColoTensor`s. + Args: + path (str): directory to save the checkpoint files. + epoch (int): the number of epoch + model (torch.nn.Module): a torch module initialized by ColoInitContext + optimizer (OptimizerWrapper, optional): optimizers. Defaults to None. + lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. + """ + rank = dist.get_rank() + model_state = model.state_dict() + # save the dist context about the tensors in a new dict, while still maintain the original dict. + for k, v in model_state.items(): + if isinstance(v, ColoTensor): + gather_tensor(v) # gather shared tensors to rank0 + # don't recover tensors in rank0, since the dict is only a copy of model + + if rank == 0: + # sanity check + for k, v in model_state.items(): + if isinstance(v, ColoTensor): + assert v.save_ready + assert v.is_replicate() + delattr(v, "save_ready") + # model saving + save_state = {"epoch": epoch, "model": model_state} + torch.save(save_state, path + "/epoch_{}_model.pth".format(epoch), *args, **kwargs) + + # delete old dicts + del model_state + # synchronize all the processes + dist.barrier() + + if optimizer is not None: + mapping = dict() + optim_state = optimizer.state_dict() + for k, v in optim_state["state"].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + mapping[(k, n)] = t.dist_spec + gather_tensor(t) + + if rank == 0: + save_state = {"epoch": epoch, "optim": optim_state} + torch.save(save_state, path + "/epoch_{}_optim.pth".format(epoch), *args, **kwargs) + # recover colo tensors in rank0 + for k, v in optimizer.state_dict()["state"].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + assert hasattr(t, "save_ready") + t.set_dist_spec(mapping[(k, n)]) + delattr(t, "save_ready") + + del optim_state + del mapping + dist.barrier() + + +def load_checkpoint( + path: str, + epoch: int, + model: torch.nn.Module, + optimizer: Optional[OptimizerWrapper] = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + torch_load_kwargs: Optional[Dict] = None, + load_state_dict_kwargs: Optional[Dict] = None, +): + """load_checkpoint + load a model, whose parameters are `ColoTensor`s. + Args: + path (str): directory to save the checkpoint files. + epoch (int): the number of epoch + model (torch.nn.Module): a torch module initialized by ColoInitContext + optimizer (OptimizerWrapper, optional): optimizers. Defaults to None. + lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. + torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function + load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function + """ + # initialize the default parameters + if not torch_load_kwargs: + torch_load_kwargs = dict() + if not load_state_dict_kwargs: + load_state_dict_kwargs = dict() + + rank = dist.get_rank() + mapping = dict() + for n, p in model.named_parameters(): + if isinstance(p, ColoTensor): + mapping[n] = p.dist_spec + gather_tensor(p) + + if rank == 0: + load_state = torch.load(path + "/epoch_{}_model.pth".format(epoch), **torch_load_kwargs) + model.load_state_dict(load_state["model"], **load_state_dict_kwargs) + dist.barrier() + + # scatter loaded parameters + for n, p in model.named_parameters(): + if isinstance(p, ColoTensor): + scatter_tensor(p, mapping[n]) + if rank == 0: + assert hasattr(p, "save_ready") + delattr(p, "save_ready") + del mapping + + if optimizer is not None: + mapping = dict() + for k, v in optimizer.state_dict()["state"].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + mapping[(k, n)] = t.dist_spec + gather_tensor(t) + + if rank == 0: + colo_checkpoint = torch.load(path + "/epoch_{}_optim.pth".format(epoch), **torch_load_kwargs) + optimizer.load_state_dict(colo_checkpoint["optim"], **load_state_dict_kwargs) + dist.barrier() + + for k, v in optimizer.state_dict()["state"].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + scatter_tensor(t, mapping[(k, n)]) + + del mapping diff --git a/colossalai/legacy/utils/checkpoint/utils.py b/colossalai/legacy/utils/checkpoint/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c56848cf06c431af32b7371898a4b76fb61de3c4 --- /dev/null +++ b/colossalai/legacy/utils/checkpoint/utils.py @@ -0,0 +1,64 @@ +import torch +import torch.distributed as dist + +from colossalai.legacy.tensor import ColoTensorSpec +from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec +from colossalai.tensor import ColoTensor + + +def robust_broadcast(tensor): + with torch.no_grad(): + is_cpu_ten = tensor.device.type == "cpu" + if is_cpu_ten: + b_data = tensor.cuda() + else: + b_data = tensor + + dist.broadcast(b_data, 0) + + if is_cpu_ten: + tensor.copy_(b_data) + + +def gather_tensor(colo_tensor: ColoTensor) -> None: + """Make colo_tensor replicated when the rank is 0""" + if not colo_tensor.is_replicate(): + pg = colo_tensor.get_process_group() + # for the group which contains rank 0 + if pg.dp_local_rank() == 0: + old_dist_spec = colo_tensor.dist_spec + colo_tensor.to_replicate_() + if dist.get_rank() != 0: + colo_tensor.set_dist_spec(old_dist_spec) + + # synchronize all processes for unexpected problems + dist.barrier() + + if dist.get_rank() == 0: + setattr(colo_tensor, "save_ready", True) # set saving signature + + +def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: + """Reversal operation of `gather_tensor`.""" + if dist_spec.placement == DistPlacementPattern.REPLICATE: + robust_broadcast(colo_tensor.data) + else: + global_size = colo_tensor.size_global() + + if dist.get_rank() == 0: + entire_data = colo_tensor.data + else: + entire_data = torch.empty(global_size, device=colo_tensor.device) + robust_broadcast(entire_data) + + if dist.get_rank() == 0: + colo_tensor.set_dist_spec(dist_spec) + else: + rep_tensor = ColoTensor( + entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec) + ) + rep_tensor.set_dist_spec(dist_spec) + with torch.no_grad(): + colo_tensor.data.copy_(rep_tensor.data) + # synchronize all processes for unexpected problems + dist.barrier() diff --git a/colossalai/utils/checkpointing.py b/colossalai/legacy/utils/checkpointing.py similarity index 84% rename from colossalai/utils/checkpointing.py rename to colossalai/legacy/utils/checkpointing.py index d1c6b6370ede4420191eba75e3ab7816f99ed499..c068faafbf447e56abf9d240f684e0221eb64a55 100644 --- a/colossalai/utils/checkpointing.py +++ b/colossalai/legacy/utils/checkpointing.py @@ -3,13 +3,15 @@ from itertools import chain import torch import torch.distributed as dist -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.constants import IS_TENSOR_PARALLEL + +from colossalai.legacy.constants import IS_TENSOR_PARALLEL +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc + try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" from .common import is_using_pp @@ -23,10 +25,9 @@ def broadcast_state_dict(state_dict, parallel_mode): return state_dict[0] -def partition_tensor_parallel_state_dict(state_dict: OrderedDict, - parallel_mode: ParallelMode, - dims: dict = dict(), - partition_states: dict = dict()): +def partition_tensor_parallel_state_dict( + state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict() +): src_rank = gpc.get_ranks_in_group(parallel_mode)[0] depth = gpc.get_world_size(parallel_mode) group = gpc.get_cpu_group(parallel_mode) @@ -63,11 +64,11 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict, def gather_tensor_parallel_state_dict( - state_dict: OrderedDict, - parallel_mode: ParallelMode, - dims: dict = dict(), - partition_states: dict = dict(), - keep_vars: bool = False, + state_dict: OrderedDict, + parallel_mode: ParallelMode, + dims: dict = dict(), + partition_states: dict = dict(), + keep_vars: bool = False, ): dst_rank = gpc.get_ranks_in_group(parallel_mode)[0] depth = gpc.get_world_size(parallel_mode) @@ -136,8 +137,11 @@ def partition_pipeline_parallel_state_dict(model, state_dict): def gather_pipeline_parallel_state_dict(state_dict): - gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))] - if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None) + gathered_states = ( + [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))] + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 + else None + ) dist.gather_object( state_dict, gathered_states, @@ -145,18 +149,23 @@ def gather_pipeline_parallel_state_dict(state_dict): group=gpc.get_cpu_group(ParallelMode.PIPELINE), ) - state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states)) - if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict()) + state_dict = ( + OrderedDict(chain.from_iterable(state.items() for state in gathered_states)) + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 + else OrderedDict() + ) return state_dict -def save_checkpoint(file, - epoch: int, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer = None, - lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, - **kwargs): +def save_checkpoint( + file, + epoch: int, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + **kwargs, +): """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler etc. into a checkpoint dictionary. @@ -194,8 +203,11 @@ def broadcast_model(model: torch.nn.Module): src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0] for p in model.parameters(): if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0: - group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group( - ParallelMode.TENSOR) + group = ( + gpc.get_group(ParallelMode.TENSOR) + if p.device.type == "cuda" + else gpc.get_cpu_group(ParallelMode.TENSOR) + ) dist.broadcast(p, src_rank, group=group) @@ -224,8 +236,9 @@ def load_checkpoint( Raises: RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated """ - state_dict = (torch.load(file, map_location=torch.device("cpu")) - if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None) + state_dict = ( + torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None + ) # model states model_state = state_dict.pop("model") if state_dict is not None else dict() @@ -244,8 +257,11 @@ def load_checkpoint( dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL)) if gpc.get_global_rank() == 0: all_error_msgs = list(chain.from_iterable(all_error_msgs)) - raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format( - model.__class__.__name__, "\n\t".join(all_error_msgs))) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + model.__class__.__name__, "\n\t".join(all_error_msgs) + ) + ) else: raise e diff --git a/colossalai/legacy/utils/common.py b/colossalai/legacy/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..671bcc3d6ad7069e0833298aa11ab5dfd64462bf --- /dev/null +++ b/colossalai/legacy/utils/common.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +from collections import defaultdict +from contextlib import contextmanager +from typing import Dict, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import inf +from torch.nn.parameter import Parameter + +from colossalai.legacy.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.tensor import ProcessGroup +from colossalai.tensor import ColoParameter +from colossalai.utils.multi_tensor_apply import multi_tensor_applier + +try: + from colossalai._C import fused_optim +except: + fused_optim = None + + +def print_rank_0(msg: str, logger=None): + """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu. + + Args: + msg (str): A string message to output. + logger (:class:`colossalai.logging.DistributedLogger`, optional): + The logger to record the message, defaults to None. + """ + if gpc.get_global_rank() == 0: + if logger is None: + print(msg, flush=True) + else: + logger.info(msg) + + +def sync_model_param(model, parallel_mode): + r"""Make sure data parameters are consistent during Data Parallel Mode. + + Args: + model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. + parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel mode to be checked. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: + for param in model.parameters(): + ranks = gpc.get_ranks_in_group(parallel_mode) + dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) + + +def is_dp_rank_0(): + return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA) + + +def is_tp_rank_0(): + return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR) + + +def is_no_pp_or_last_stage(): + return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) + + +def is_using_ddp(): + return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1 + + +def is_using_pp(): + return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1 + + +def is_using_sequence(): + return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1 + + +class model_branch_context(object): + def __enter__(self): + self.env_status = env.save() + + def __exit__(self, *exc_info): + env.load(**self.env_status) + + +def is_model_parallel_parameter(p): + return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) + + +def _calc_l2_norm(grads): + # we should not + global fused_optim + + if fused_optim is None: + from colossalai.kernel.op_builder import FusedOptimBuilder + + fused_optim = FusedOptimBuilder().load() + + norm = 0.0 + if len(grads) > 0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + norm, _ = multi_tensor_applier( + fused_optim.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm + ) + return norm + + +def _calc_lp(grads, norm_type): + norm = 0.0 + for grad in grads: + grad_norm = torch.norm(grad, norm_type) + norm += grad_norm**norm_type + return norm + + +def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + if torch.is_tensor(norm) and norm.device.type != "cuda": + norm = norm.to(torch.cuda.current_device()) + return norm + + +def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor: + if isinstance(norm, float): + norm = torch.Tensor([norm]) + if move_to_cuda: + norm = norm.to(torch.cuda.current_device()) + return norm + + +# ======== Gradient Clipping ========= + + +def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float: + if len(params) == 0: + return 0.0 + grads = [p.grad for p in params] + use_cuda_kernel = grads[0].device.type == "cuda" + if norm_type == inf: + local_lp = max([g.abs().max() for g in grads]) + elif norm_type == 2.0 and use_cuda_kernel: + local_lp = _calc_l2_norm(grads) ** norm_type + else: + local_lp = _calc_lp(grads, norm_type) + if isinstance(local_lp, torch.Tensor): + return local_lp.item() + return local_lp + + +def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float: + if len(params) == 0: + return 0.0 + buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list) + for p in params: + if p.is_replicate(): + buckets[None].append(p) + else: + buckets[p.get_process_group().tp_process_group()].append(p) + total_lp = 0.0 + for group, bucket in buckets.items(): + local_lp = _compute_local_lp(bucket, norm_type) + if group is not None: + local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device()) + if norm_type == inf: + dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group) + else: + dist.all_reduce(local_lp_tensor, group=group) + local_lp = local_lp_tensor.item() + if norm_type == inf: + total_lp = max(total_lp, local_lp) + else: + total_lp += local_lp + return total_lp + + +def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float: + if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: + total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device()) + if norm_type == inf: + dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE)) + else: + dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE)) + total_lp = total_lp_tensor.item() + return total_lp + + +def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grad_dtype = None + cpu_grad_params: List[ColoParameter] = [] + cuda_grad_params: List[ColoParameter] = [] + for p in parameters: + if p.grad is None: + continue + assert isinstance(p, ColoParameter) + if grad_dtype is None: + grad_dtype = p.grad.dtype + assert p.grad.dtype == grad_dtype, f"Expected all grads are {grad_dtype}, got {p.grad.dtype}" + if p.grad.device.type == "cuda": + cuda_grad_params.append(p) + else: + cpu_grad_params.append(p) + norm_type = float(norm_type) + cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type) + cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type) + if norm_type == inf: + total_lp = max(cpu_lp, cuda_lp) + else: + total_lp = cpu_lp + cuda_lp + return _compute_pp_grad_lp(total_lp, norm_type) + + +def compute_grad_norm(parameters, norm_type: float = 2.0) -> float: + norm_type = float(norm_type) + total_norm = _compute_grad_lp(parameters, norm_type) + if norm_type != inf: + total_norm = total_norm ** (1 / norm_type) + return total_norm + + +def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None: + clip_coef = max_norm / (total_norm + 1e-6) + if clip_coef < 1.0: + cuda_grads: List[torch.Tensor] = [] + cpu_grads: List[torch.Tensor] = [] + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + for p in parameters: + if p.grad is None: + continue + if p.grad.device.type == "cuda": + cuda_grads.append(p.grad.detach()) + else: + cpu_grads.append(p.grad.detach()) + if len(cuda_grads) > 0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + multi_tensor_applier( + fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef + ) + for g in cpu_grads: + g.mul_(clip_coef) + + +def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float: + total_norm = compute_grad_norm(parameters, norm_type) + _clip_grad_norm(parameters, max_norm, total_norm) + return total_norm + + +def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): + """Clips gradient norm of an iterable of parameters whose gradients are in fp32. + + This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and + added functionality to handle model parallel parameters. + + Note: + the gradients are modified in place. + + Args: + parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`): + An iterable of Tensors or a single Tensor that will have gradients normalized. + max_norm (Union[float, int]): Max norm of the gradients. + norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm. + + Returns: + float: Total norm of the parameters. + """ + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + params: List[Parameter] = [] + has_zero_shared_param: bool = False + for param in parameters: + if param.grad is not None: + # Make sure the grads are in fp32 + assert ( + param.grad.dtype == torch.float + ), f"expected gradient to be dtype torch.float, but got {param.grad.type()}" + if hasattr(param, "colo_attr") and param.colo_attr.sharded_data_tensor.is_sharded: + has_zero_shared_param = True + params.append(param) + + if len(params) == 0: + enable_cuda_kernels = False + else: + enable_cuda_kernels = params[0].grad.device.type == "cuda" + # Norm parameters. + max_norm = float(max_norm) + norm_type = float(norm_type) + + # Parameters can be on CPU or CUDA + # If parameters are on CPU, disable CUDA kernels + + # Calculate norm. + if norm_type == inf: + total_norm = max(p.grad.data.abs().max() for p in params) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + # Take max across all model-parallel GPUs. + if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: + dist.all_reduce( + total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL), async_op=False + ) + if has_zero_shared_param: + dist.all_reduce( + total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.DATA), async_op=False + ) + total_norm = total_norm_cuda[0].item() + else: + tensor_parallel_grads = [] + no_tensor_parallel_grads = [] + zero_sharded_grads = [] + for p in params: + if is_model_parallel_parameter(p): + reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type) + tensor_parallel_grads.append(p.grad.data / reductor) + elif hasattr(p, "colo_attr") and p.colo_attr.sharded_data_tensor.is_sharded: + zero_sharded_grads.append(p.grad.data) + else: + no_tensor_parallel_grads.append(p.grad.data) + + if norm_type == 2.0 and enable_cuda_kernels: + tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads) ** norm_type + no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads) ** norm_type + zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type + else: + tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) + no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type) + zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type) + # If norm is type of float, then we convert them into torch.Tensor. + tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels) + no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels) + zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels) + # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors + if not enable_cuda_kernels: + tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm) + no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm) + zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm) + + # Sum across all model-parallel GPUs. + if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: + dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) + # Sum across all zero sharded GPUs + if len(zero_sharded_grads) > 0: + dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA)) + no_tensor_parallel_norm += zero_sharded_norm + total_norm = tensor_parallel_norm + no_tensor_parallel_norm + if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE)) + total_norm = total_norm ** (1.0 / norm_type) + if torch.is_tensor(total_norm): + total_norm = total_norm.item() + + # Scale. + clip_coeff = max_norm / (total_norm + 1.0e-6) + if clip_coeff < 1.0: + if enable_cuda_kernels: + grads = [p.grad.detach() for p in params] + dummy_overflow_buf = torch.cuda.IntTensor([0]) + multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) + else: + for p in params: + p.grad.detach().mul_(clip_coeff) + return total_norm + + +def count_zeros_fp32(parameters): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + total_num_zeros = 0.0 + for param in parameters: + grad_not_none = param.grad is not None + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if grad_not_none and is_not_tp_duplicate: + grad = param.grad.detach() + num_zeros = grad.numel() - torch.count_nonzero(grad) + total_num_zeros = num_zeros + total_num_zeros + + total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda() + + # Sum across all model-parallel GPUs. + ops = [] + ops.append( + dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True) + ) + if gpc.is_initialized(ParallelMode.PIPELINE): + ops.append( + dist.all_reduce( + total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE), async_op=True + ) + ) + + for req in ops: + req.wait() + total_num_zeros = total_num_zeros.item() + + return total_num_zeros + + +def copy_tensor_parallel_attributes(src_tensor, dst_tensor): + for attr in TENSOR_PARALLEL_ATTRIBUTES: + if hasattr(src_tensor, attr): + val = getattr(src_tensor, attr) + setattr(dst_tensor, attr, val) + + +def param_is_not_tensor_parallel_duplicate(param): + return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or ( + gpc.get_local_rank(ParallelMode.TENSOR) == 0 + ) + + +@contextmanager +def switch_virtual_pipeline_parallel_rank(rank): + prev_rank = gpc.virtual_pipeline_parallel_rank + try: + gpc.set_virtual_pipeline_parallel_rank(rank) + yield + finally: + gpc.set_virtual_pipeline_parallel_rank(prev_rank) diff --git a/colossalai/legacy/utils/data_sampler/__init__.py b/colossalai/legacy/utils/data_sampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..677d767667f2eb37080d2b7003cdc483b9a37c21 --- /dev/null +++ b/colossalai/legacy/utils/data_sampler/__init__.py @@ -0,0 +1,4 @@ +from .base_sampler import BaseSampler +from .data_parallel_sampler import DataParallelSampler, get_dataloader + +__all__ = ["BaseSampler", "DataParallelSampler", "get_dataloader"] diff --git a/colossalai/utils/data_sampler/base_sampler.py b/colossalai/legacy/utils/data_sampler/base_sampler.py similarity index 99% rename from colossalai/utils/data_sampler/base_sampler.py rename to colossalai/legacy/utils/data_sampler/base_sampler.py index 89f3bca5b1b51925ef7b32e4a08f1df301776fcb..c6b916fc48702afe425d1cc3eb85bc0fee0b0cd4 100644 --- a/colossalai/utils/data_sampler/base_sampler.py +++ b/colossalai/legacy/utils/data_sampler/base_sampler.py @@ -5,7 +5,6 @@ from abc import ABC, abstractmethod class BaseSampler(ABC): - def __init__(self, dataset, batch_size): self.dataset = dataset self.batch_size = batch_size diff --git a/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py b/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..41d0861e2249ccb7c8ed6e52d56d58c2571d9f64 --- /dev/null +++ b/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +# adapted from torch.utils.data.DistributedSampler + +import math +import random +from typing import Iterator, TypeVar + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset, Sampler + +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc + +T_co = TypeVar("T_co", covariant=True) + + +class DataParallelSampler(Sampler): + """A data sampler for distributed data parallelism. + + Args: + dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling. + shuffle (bool, optional): Whether to shuffle data, defaults to False. + seed (int, optional): The random seed used for sampling, defaults to 0. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + """ + + def __init__(self, dataset: Dataset, shuffle: bool = False, seed: int = 0, drop_last: bool = False) -> None: + self.dataset = dataset + self.num_replicas = gpc.get_world_size(ParallelMode.DATA) + self.rank = gpc.get_local_rank(ParallelMode.DATA) + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + # type: ignore[arg-type] + if self.drop_last and len(self.dataset) % self.num_replicas != 0: + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + # `type:ignore` is required because Dataset cannot provide a default __len__ + # see NOTE in pytorch/torch/utils/data/sampler.py + (len(self.dataset) - self.num_replicas) + / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + # type: ignore[arg-type] + indices = torch.randperm(len(self.dataset), generator=g).tolist() + + # update for next epoch so that there is no need to call + # set_epoch manually + self.epoch += 1 + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch + + +def get_dataloader( + dataset, shuffle=False, seed=1024, add_sampler=True, drop_last=False, pin_memory=False, num_workers=0, **kwargs +): + r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not) + + Note: + When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data + on the 1st stage and label on the last stage. + + Args: + dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + + if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1: + sampler = DataParallelSampler(dataset, shuffle=shuffle) + else: + sampler = None + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + if sampler is None: + return DataLoader( + dataset, + worker_init_fn=seed_worker, + shuffle=shuffle, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + else: + return DataLoader( + dataset, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) diff --git a/colossalai/utils/memory.py b/colossalai/legacy/utils/memory.py similarity index 86% rename from colossalai/utils/memory.py rename to colossalai/legacy/utils/memory.py index 434e90edd3b98fb7f69c502d2d5ebf21e127d0bb..2f99a7d2f72e0889cdb18cdff265f49050594902 100644 --- a/colossalai/utils/memory.py +++ b/colossalai/legacy/utils/memory.py @@ -1,15 +1,15 @@ -import torch import gc -import psutil from collections import namedtuple -from colossalai.context.parallel_mode import ParallelMode -from colossalai.utils import get_current_device -from colossalai.core import global_context as gpc -from colossalai.context.parallel_mode import ParallelMode -from colossalai.logging import get_dist_logger +import psutil +import torch +import torch.distributed as dist from packaging import version +from colossalai.legacy.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.utils import get_current_device + _GLOBAL_CUDA_MEM_FRACTION = 1.0 _GLOBAL_CPU_MEM_CAPACITY = -1 @@ -68,7 +68,7 @@ def report_memory_usage(message, logger=None, report_cpu=False): Raises: EnvironmentError: Raise error if no distributed environment has been initialized. """ - if not gpc.is_initialized(ParallelMode.GLOBAL): + if not dist.is_initialized(): raise EnvironmentError("No distributed environment is initialized") gpu_allocated = _bytes_to_MB(torch.cuda.memory_allocated()) @@ -76,8 +76,10 @@ def report_memory_usage(message, logger=None, report_cpu=False): gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved()) gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved()) - full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \ + full_log = ( + f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " + f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB" + ) if report_cpu: # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports @@ -91,7 +93,7 @@ def report_memory_usage(message, logger=None, report_cpu=False): logger.info(full_log) # get the peak memory to report correct data, so reset the counter for the next call - if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ torch.cuda.reset_peak_memory_stats() @@ -106,10 +108,10 @@ def colo_device_memory_capacity(device: torch.device) -> int: int: size in byte """ assert isinstance(device, torch.device) - if device.type == 'cpu': + if device.type == "cpu": # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node - if device.type == 'cuda': + if device.type == "cuda": return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION @@ -123,31 +125,31 @@ def colo_device_memory_used(device: torch.device) -> int: Returns: int: memory size in bytes """ - if device.type == 'cpu': + if device.type == "cpu": mem_info = _get_cpu_memory_info() # In the context of 1-CPU-N-GPU, the memory usage of the current process is 1/N CPU memory used. # Each process consumes the same amount of memory. ret = mem_info.used / gpc.num_processes_on_current_node return ret - elif device.type == 'cuda': + elif device.type == "cuda": ret: int = torch.cuda.memory_allocated(device) # get the peak memory to report correct data, so reset the counter for the next call - if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ torch.cuda.reset_peak_memory_stats(device) return ret def colo_set_process_memory_fraction(ratio: float) -> None: - """colo_set_process_memory_fraction + """colo_set_process_memory_fraction set how much cuda memory used on the gpu belonging to the current process. Args: ratio (float): a ratio between 0. ~ 1. """ - if version.parse(torch.__version__) < version.parse('1.8'): - logger = get_dist_logger('colo_set_process_memory_fraction') - logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8') + if version.parse(torch.__version__) < version.parse("1.8"): + logger = get_dist_logger("colo_set_process_memory_fraction") + logger.warning("colo_set_process_memory_fraction failed because torch version is less than 1.8") return global _GLOBAL_CUDA_MEM_FRACTION _GLOBAL_CUDA_MEM_FRACTION = ratio diff --git a/colossalai/utils/profiler/__init__.py b/colossalai/legacy/utils/profiler/__init__.py similarity index 100% rename from colossalai/utils/profiler/__init__.py rename to colossalai/legacy/utils/profiler/__init__.py diff --git a/colossalai/utils/profiler/extention.py b/colossalai/legacy/utils/profiler/extention.py similarity index 99% rename from colossalai/utils/profiler/extention.py rename to colossalai/legacy/utils/profiler/extention.py index 6726a683cc05ebb1ac5370d8c17750cd869d9ec2..c07c6046bb1cdfe6a2a8fc9838e4fcb30ae64859 100644 --- a/colossalai/utils/profiler/extention.py +++ b/colossalai/legacy/utils/profiler/extention.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod class ProfilerExtension(ABC): - @abstractmethod def prepare_trace(self): pass diff --git a/colossalai/legacy/utils/profiler/legacy/__init__.py b/colossalai/legacy/utils/profiler/legacy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..88b4201d8bf3fcabf78c7615d0477323ae26bd71 --- /dev/null +++ b/colossalai/legacy/utils/profiler/legacy/__init__.py @@ -0,0 +1,6 @@ +from .comm_profiler import CommProfiler +from .mem_profiler import MemProfiler +from .pcie_profiler import PcieProfiler +from .prof_utils import BaseProfiler, ProfilerContext + +__all__ = ["BaseProfiler", "CommProfiler", "PcieProfiler", "MemProfiler", "ProfilerContext"] diff --git a/colossalai/utils/profiler/legacy/comm_profiler.py b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py similarity index 76% rename from colossalai/utils/profiler/legacy/comm_profiler.py rename to colossalai/legacy/utils/profiler/legacy/comm_profiler.py index a4f5729c97ec4b4365e40b69c9a0fda7de5055d9..ad54b989f4122885dd75036f1fc9622028fbf68f 100644 --- a/colossalai/utils/profiler/legacy/comm_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py @@ -1,308 +1,318 @@ -import inspect -from pathlib import Path -from functools import partial -import torch -from torch.autograd.profiler import profile -import torch.distributed as dist -from torch.distributed import ReduceOp -from colossalai.utils import get_current_device -from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth -from typing import List, Optional - - -def _get_code_location(depth: int): - ret = [] - length = min(len(inspect.stack()), depth + 1) - for i in range(3, length): - upper_frame = inspect.stack()[i] - function_name = inspect.stack()[i - 1].function - ret.append(upper_frame.filename) - ret.append('(') - ret.append(str(upper_frame.lineno)) - ret.append('): ') - ret.append(function_name) - if i != length - 1: - ret.append('\n') - - return ''.join(ret) - - -torch_all_reduce = dist.all_reduce -torch_all_gather = dist.all_gather -torch_reduce_scatter = dist.reduce_scatter -torch_broadcast = dist.broadcast -torch_reduce = dist.reduce - - -class CommEvent(object): - """Communication Event. Used for communication time and communication - volume recording. - """ - - def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0): - self.self_count = count - self.self_comm_vol = comm_vol - self.self_cuda_time = cuda_time - - def add(self, rhs): - self.self_count += rhs.self_count - self.self_comm_vol += rhs.self_comm_vol - self.self_cuda_time += rhs.self_cuda_time - - -class CommProfiler(BaseProfiler): - """Communication profiler. Records all communication events. - """ - - def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0): - super().__init__(profiler_name="Collective_Communication", priority=0) - self.depth = 3 + depth - self.total_count = total_count - self.total_comm_vol = total_comm_vol - self.total_cuda_time = total_cuda_time - - self.ops_record = dict() - self.profiler = None - self.pending_op = None - self.pending_metadata = None - self.warn_flag = False - - def reset(self): - self.total_count = 0 - self.total_comm_vol = 0 - self.total_cuda_time = 0 - - self.ops_record = dict() - self.profiler = None - self.pending_op = None - self.pending_metadata = None - self.warn_flag = False - - def enable(self): - dist.all_reduce = partial(all_reduce, profiler=self) - dist.all_gather = partial(all_gather, profiler=self) - dist.reduce_scatter = partial(reduce_scatter, profiler=self) - dist.broadcast = partial(broadcast, profiler=self) - dist.reduce = partial(reduce, profiler=self) - - def disable(self): - dist.all_reduce = torch_all_reduce - dist.all_gather = torch_all_gather - dist.reduce_scatter = torch_reduce_scatter - dist.broadcast = torch_broadcast - dist.reduce = torch_reduce - - def to_tensorboard(self, writer): - writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n")) - - def to_file(self, filename: Path): - with open(filename, "w") as f: - f.write(self.result_str()) - - def show(self): - print(self.result_str()) - - def result_str(self, sep: str = "\n"): - res = [] - - def append(s: str = None): - if s is not None: - res.append(s) - res.append(sep) - - if self.warn_flag: - append("Warnning: there exists multiple communication operations in the same time. As a result, " - "the profiling result is not accurate.") - - if self.total_cuda_time == 0: - return "No collective communication has been called yet!" - - append("Collective communication profiling result:") - append("total cuda time: {}".format(_format_time(self.total_cuda_time))) - append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time))) - append("total number of calls: {}".format(self.total_count)) - append("All events:") - - seperation = '-' * 74 - row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2 - - append(seperation) - append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls')) - append(seperation) - - show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) - for location, event in show_list: - append(location) - append( - row_format.format('', _format_time(event.self_cuda_time), - '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0), - _format_memory(event.self_comm_vol), - _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count)) - append() - - return ''.join(res) - - @property - def has_aync_op(self): - return self.pending_op is not None - - def activate_profiler(self, kn: str, vol: float): - self.pending_metadata = (kn, _get_code_location(self.depth), vol) - self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True) - self.profiler.__enter__() - - def close_profiler(self, group=None): - assert self.profiler is not None, "There is no running dist op" - kernel_name, code_location, vol = self.pending_metadata - self.profiler.__exit__(None, None, None) - - if self.profiler.enabled and dist.get_world_size(group) > 1: - assert_flag = 0 - current_comm_event = None - events = self.profiler.function_events - for event in events: - if kernel_name in event.name: - assert assert_flag == 0, "Multiple dist ops has been called " - current_comm_event = CommEvent(1, vol, event.self_cuda_time_total) - assert_flag += 1 - - assert current_comm_event is not None, "dist op has not been found" - - buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) - torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) - current_comm_event.self_cuda_time = buffer.item() - - self.total_count += current_comm_event.self_count - self.total_comm_vol += current_comm_event.self_comm_vol - self.total_cuda_time += current_comm_event.self_cuda_time - if code_location in self.ops_record: - self.ops_record[code_location].add(current_comm_event) - else: - self.ops_record[code_location] = current_comm_event - - self.profiler = None - self.pending_op = None - self.pending_metadata = None - - def wait_async_op(self): - if self.pending_op is not None: - op = self.pending_op - op.wait() - self.close_profiler() - - -class CommHandler(object): - """Communication handler. A dummy handler to wait aync operations. - """ - - def __init__(self, profiler: CommProfiler): - super().__init__() - self.prof = profiler - - def wait(self): - self.prof.wait_async_op() - - -def async_check(profiler: CommProfiler): - if profiler.pending_op is not None: - profiler.warn_flag = True - profiler.wait_async_op() - - -def all_reduce(tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_size = dist.get_world_size(group) - correction = 2 * (comm_size - 1) / comm_size - comm_vol = correction * tensor.element_size() * tensor.numel() - profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol) - profiler.pending_op = torch_all_reduce(tensor, op, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def reduce_scatter(output: torch.Tensor, - input_list: List[torch.Tensor], - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_size = dist.get_world_size(group) - correction = (comm_size - 1) / comm_size - comm_vol = 0 - for tensor in input_list: - comm_vol += tensor.element_size() * tensor.numel() - comm_vol *= correction - profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol) - profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def all_gather(tensor_list: List[torch.Tensor], - tensor: torch.Tensor, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_size = dist.get_world_size(group) - correction = (comm_size - 1) / comm_size - comm_vol = 0 - for ten in tensor_list: - comm_vol += ten.element_size() * ten.numel() - comm_vol *= correction - profiler.activate_profiler("ncclKernel_AllGather_", comm_vol) - profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def broadcast(tensor: torch.Tensor, - src: int, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_vol = 1.0 * tensor.element_size() * tensor.numel() - profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol) - profiler.pending_op = torch_broadcast(tensor, src, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def reduce(tensor: torch.Tensor, - dst: int, - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_vol = 1.0 * tensor.element_size() * tensor.numel() - profiler.activate_profiler("ncclKernel_Reduce_", comm_vol) - profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) +import inspect +from functools import partial +from pathlib import Path +from typing import List, Optional + +import torch +import torch.distributed as dist +from torch.autograd.profiler import profile +from torch.distributed import ReduceOp + +from colossalai.utils import get_current_device + +from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time + + +def _get_code_location(depth: int): + ret = [] + length = min(len(inspect.stack()), depth + 1) + for i in range(3, length): + upper_frame = inspect.stack()[i] + function_name = inspect.stack()[i - 1].function + ret.append(upper_frame.filename) + ret.append("(") + ret.append(str(upper_frame.lineno)) + ret.append("): ") + ret.append(function_name) + if i != length - 1: + ret.append("\n") + + return "".join(ret) + + +torch_all_reduce = dist.all_reduce +torch_all_gather = dist.all_gather +torch_reduce_scatter = dist.reduce_scatter +torch_broadcast = dist.broadcast +torch_reduce = dist.reduce + + +class CommEvent(object): + """Communication Event. Used for communication time and communication + volume recording. + """ + + def __init__(self, count: int = 0, comm_vol: float = 0.0, cuda_time: int = 0): + self.self_count = count + self.self_comm_vol = comm_vol + self.self_cuda_time = cuda_time + + def add(self, rhs): + self.self_count += rhs.self_count + self.self_comm_vol += rhs.self_comm_vol + self.self_cuda_time += rhs.self_cuda_time + + +class CommProfiler(BaseProfiler): + """Communication profiler. Records all communication events.""" + + def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0): + super().__init__(profiler_name="Collective_Communication", priority=0) + self.depth = 3 + depth + self.total_count = total_count + self.total_comm_vol = total_comm_vol + self.total_cuda_time = total_cuda_time + + self.ops_record = dict() + self.profiler = None + self.pending_op = None + self.pending_metadata = None + self.warn_flag = False + + def reset(self): + self.total_count = 0 + self.total_comm_vol = 0 + self.total_cuda_time = 0 + + self.ops_record = dict() + self.profiler = None + self.pending_op = None + self.pending_metadata = None + self.warn_flag = False + + def enable(self): + dist.all_reduce = partial(all_reduce, profiler=self) + dist.all_gather = partial(all_gather, profiler=self) + dist.reduce_scatter = partial(reduce_scatter, profiler=self) + dist.broadcast = partial(broadcast, profiler=self) + dist.reduce = partial(reduce, profiler=self) + + def disable(self): + dist.all_reduce = torch_all_reduce + dist.all_gather = torch_all_gather + dist.reduce_scatter = torch_reduce_scatter + dist.broadcast = torch_broadcast + dist.reduce = torch_reduce + + def to_tensorboard(self, writer): + writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n")) + + def to_file(self, filename: Path): + with open(filename, "w") as f: + f.write(self.result_str()) + + def show(self): + print(self.result_str()) + + def result_str(self, sep: str = "\n"): + res = [] + + def append(s: str = None): + if s is not None: + res.append(s) + res.append(sep) + + if self.warn_flag: + append( + "Warning: there exists multiple communication operations in the same time. As a result, " + "the profiling result is not accurate." + ) + + if self.total_cuda_time == 0: + return "No collective communication has been called yet!" + + append("Collective communication profiling result:") + append("total cuda time: {}".format(_format_time(self.total_cuda_time))) + append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time))) + append("total number of calls: {}".format(self.total_count)) + append("All events:") + + separation = "-" * 74 + row_format = "{:^10}" + "{:^12}" * 2 + "{:^16}" + "{:^12}" * 2 + + append(separation) + append(row_format.format("Location", "GPU time", "Percentage", "Comm volume", "Bandwidth", "Num of calls")) + append(separation) + + show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) + for location, event in show_list: + append(location) + append( + row_format.format( + "", + _format_time(event.self_cuda_time), + "{:.1f}%".format(event.self_cuda_time / self.total_cuda_time * 100.0), + _format_memory(event.self_comm_vol), + _format_bandwidth(event.self_comm_vol, event.self_cuda_time), + event.self_count, + ) + ) + append() + + return "".join(res) + + @property + def has_aync_op(self): + return self.pending_op is not None + + def activate_profiler(self, kn: str, vol: float): + self.pending_metadata = (kn, _get_code_location(self.depth), vol) + self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True) + self.profiler.__enter__() + + def close_profiler(self, group=None): + assert self.profiler is not None, "There is no running dist op" + kernel_name, code_location, vol = self.pending_metadata + self.profiler.__exit__(None, None, None) + + if self.profiler.enabled and dist.get_world_size(group) > 1: + assert_flag = 0 + current_comm_event = None + events = self.profiler.function_events + for event in events: + if kernel_name in event.name: + assert assert_flag == 0, "Multiple dist ops has been called " + current_comm_event = CommEvent(1, vol, event.self_cuda_time_total) + assert_flag += 1 + + assert current_comm_event is not None, "dist op has not been found" + + buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) + torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) + current_comm_event.self_cuda_time = buffer.item() + + self.total_count += current_comm_event.self_count + self.total_comm_vol += current_comm_event.self_comm_vol + self.total_cuda_time += current_comm_event.self_cuda_time + if code_location in self.ops_record: + self.ops_record[code_location].add(current_comm_event) + else: + self.ops_record[code_location] = current_comm_event + + self.profiler = None + self.pending_op = None + self.pending_metadata = None + + def wait_async_op(self): + if self.pending_op is not None: + op = self.pending_op + op.wait() + self.close_profiler() + + +class CommHandler(object): + """Communication handler. A dummy handler to wait aync operations.""" + + def __init__(self, profiler: CommProfiler): + super().__init__() + self.prof = profiler + + def wait(self): + self.prof.wait_async_op() + + +def async_check(profiler: CommProfiler): + if profiler.pending_op is not None: + profiler.warn_flag = True + profiler.wait_async_op() + + +def all_reduce( + tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, group=None, async_op: bool = False, profiler: CommProfiler = None +) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = 2 * (comm_size - 1) / comm_size + comm_vol = correction * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol) + profiler.pending_op = torch_all_reduce(tensor, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def reduce_scatter( + output: torch.Tensor, + input_list: List[torch.Tensor], + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None, +) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = (comm_size - 1) / comm_size + comm_vol = 0 + for tensor in input_list: + comm_vol += tensor.element_size() * tensor.numel() + comm_vol *= correction + profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol) + profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def all_gather( + tensor_list: List[torch.Tensor], + tensor: torch.Tensor, + group=None, + async_op: bool = False, + profiler: CommProfiler = None, +) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = (comm_size - 1) / comm_size + comm_vol = 0 + for ten in tensor_list: + comm_vol += ten.element_size() * ten.numel() + comm_vol *= correction + profiler.activate_profiler("ncclKernel_AllGather_", comm_vol) + profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def broadcast( + tensor: torch.Tensor, src: int, group=None, async_op: bool = False, profiler: CommProfiler = None +) -> Optional[CommHandler]: + async_check(profiler) + + comm_vol = 1.0 * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol) + profiler.pending_op = torch_broadcast(tensor, src, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def reduce( + tensor: torch.Tensor, + dst: int, + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None, +) -> Optional[CommHandler]: + async_check(profiler) + + comm_vol = 1.0 * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_Reduce_", comm_vol) + profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) diff --git a/colossalai/utils/profiler/legacy/pcie_profiler.py b/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py similarity index 76% rename from colossalai/utils/profiler/legacy/pcie_profiler.py rename to colossalai/legacy/utils/profiler/legacy/pcie_profiler.py index 526222941ef979c1ff805349f40d51e9a6fdd569..10a3f8dfc43b5e3c2ab778ebf0d7b7c59df26847 100644 --- a/colossalai/utils/profiler/legacy/pcie_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py @@ -1,148 +1,153 @@ -from pathlib import Path -from torch.autograd.profiler import profile -from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth -from typing import List - - -def _get_size(dtype: str): - if dtype == "fp16": - return 2 - elif dtype == "fp32": - return 4 - else: - raise NotImplementedError - - -def _get_numel(my_list: List[int]) -> int: - from functools import reduce - from operator import mul - return reduce(mul, my_list) - - -def _reduce_location(locations: List[str]) -> str: - ret = [] - for lo in locations: - ret.append(lo) - ret.append("\n") - ret = ret[:-1] - return ''.join(ret) - - -class PcieEvent(object): - """Pcie Event. - """ - - def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0): - self.count = count - self.pcie_vol = pcie_vol - self.cuda_time = cuda_time - - def add(self, rhs): - self.count += rhs.count - self.pcie_vol += rhs.pcie_vol - self.cuda_time += rhs.cuda_time - - -class PcieProfiler(BaseProfiler): - """Pcie profiler. Records all data transmission between CPU and GPU. - - TODO: Merge pcie profiler into communication profiler - """ - - def __init__(self, dtype: str = "fp32", depth: int = 1): - super().__init__(profiler_name="Pcie", priority=10) - self.depth = depth - self.data_size = _get_size(dtype) - self.h2d_count = 0 - self.h2d_time = 0 - self.d2h_count = 0 - self.d2h_time = 0 - - self.ops_record = dict() - self.profiler = None - - def reset(self): - self.h2d_count = 0 - self.h2d_time = 0 - self.d2h_count = 0 - self.d2h_time = 0 - - self.ops_record = dict() - self.profiler = None - - def enable(self): - self.profiler = profile(enabled=True, - use_cuda=True, - use_cpu=True, - use_kineto=True, - record_shapes=True, - with_stack=True) - self.profiler.__enter__() - - def disable(self): - self.profiler.__exit__(None, None, None) - - if self.profiler.enabled: - events = self.profiler.function_events - for event in events: - if event.name == "aten::copy_": - t_shape = event.input_shapes[0] - if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0: - continue - current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total) - code_location = _reduce_location(event.stack[:self.depth]) - if code_location in self.ops_record: - self.ops_record[code_location].add(current_comm_event) - else: - self.ops_record[code_location] = current_comm_event - elif 'Memcpy HtoD' in event.name: - self.h2d_count += 1 - self.h2d_time += event.cuda_time_total - elif 'Memcpy DtoH' in event.name: - self.d2h_count += 1 - self.d2h_time += event.cuda_time_total - - self.profiler = None - - def to_tensorboard(self, writer): - writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n")) - - def to_file(self, filename: Path): - with open(filename, "w") as f: - f.write(self.result_str()) - - def show(self): - print(self.result_str()) - - def result_str(self, sep: str = "\n"): - res = [] - - def append(s: str = None): - if s is not None: - res.append(s) - res.append(sep) - - append("Pcie profiling result:") - append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time))) - append("number of transmission (CPU -> GPU): {}".format(self.h2d_count)) - append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time))) - append("number of transmission (GPU -> CPU): {}".format(self.d2h_count)) - - append("Possible data transmission events in PCIE:") - - seperation = '-' * 62 - row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2 - - append(seperation) - append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls')) - append(seperation) - - show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time) - for location, event in show_list: - append(location) - append( - row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol), - _format_bandwidth(event.pcie_vol, event.cuda_time), event.count)) - append() - - return ''.join(res) +from pathlib import Path +from typing import List + +from torch.autograd.profiler import profile + +from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time + + +def _get_size(dtype: str): + if dtype == "fp16": + return 2 + elif dtype == "fp32": + return 4 + else: + raise NotImplementedError + + +def _get_numel(my_list: List[int]) -> int: + from functools import reduce + from operator import mul + + return reduce(mul, my_list) + + +def _reduce_location(locations: List[str]) -> str: + ret = [] + for lo in locations: + ret.append(lo) + ret.append("\n") + ret = ret[:-1] + return "".join(ret) + + +class PcieEvent(object): + """Pcie Event.""" + + def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0): + self.count = count + self.pcie_vol = pcie_vol + self.cuda_time = cuda_time + + def add(self, rhs): + self.count += rhs.count + self.pcie_vol += rhs.pcie_vol + self.cuda_time += rhs.cuda_time + + +class PcieProfiler(BaseProfiler): + """Pcie profiler. Records all data transmission between CPU and GPU. + + TODO: Merge pcie profiler into communication profiler + """ + + def __init__(self, dtype: str = "fp32", depth: int = 1): + super().__init__(profiler_name="Pcie", priority=10) + self.depth = depth + self.data_size = _get_size(dtype) + self.h2d_count = 0 + self.h2d_time = 0 + self.d2h_count = 0 + self.d2h_time = 0 + + self.ops_record = dict() + self.profiler = None + + def reset(self): + self.h2d_count = 0 + self.h2d_time = 0 + self.d2h_count = 0 + self.d2h_time = 0 + + self.ops_record = dict() + self.profiler = None + + def enable(self): + self.profiler = profile( + enabled=True, use_cuda=True, use_cpu=True, use_kineto=True, record_shapes=True, with_stack=True + ) + self.profiler.__enter__() + + def disable(self): + self.profiler.__exit__(None, None, None) + + if self.profiler.enabled: + events = self.profiler.function_events + for event in events: + if event.name == "aten::copy_": + t_shape = event.input_shapes[0] + if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0: + continue + current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total) + code_location = _reduce_location(event.stack[: self.depth]) + if code_location in self.ops_record: + self.ops_record[code_location].add(current_comm_event) + else: + self.ops_record[code_location] = current_comm_event + elif "Memcpy HtoD" in event.name: + self.h2d_count += 1 + self.h2d_time += event.cuda_time_total + elif "Memcpy DtoH" in event.name: + self.d2h_count += 1 + self.d2h_time += event.cuda_time_total + + self.profiler = None + + def to_tensorboard(self, writer): + writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n")) + + def to_file(self, filename: Path): + with open(filename, "w") as f: + f.write(self.result_str()) + + def show(self): + print(self.result_str()) + + def result_str(self, sep: str = "\n"): + res = [] + + def append(s: str = None): + if s is not None: + res.append(s) + res.append(sep) + + append("Pcie profiling result:") + append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time))) + append("number of transmission (CPU -> GPU): {}".format(self.h2d_count)) + append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time))) + append("number of transmission (GPU -> CPU): {}".format(self.d2h_count)) + + append("Possible data transmission events in PCIE:") + + separation = "-" * 62 + row_format = "{:^10}" + "{:^12}" + "{:^16}" + "{:^12}" * 2 + + append(separation) + append(row_format.format("Location", "GPU time", "Trans volume", "Bandwidth", "Num of calls")) + append(separation) + + show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time) + for location, event in show_list: + append(location) + append( + row_format.format( + "", + _format_time(event.cuda_time), + _format_memory(event.pcie_vol), + _format_bandwidth(event.pcie_vol, event.cuda_time), + event.count, + ) + ) + append() + + return "".join(res) diff --git a/colossalai/legacy/utils/profiler/legacy/prof_utils.py b/colossalai/legacy/utils/profiler/legacy/prof_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..95eecf0715e75e1a6412979685f5061640c4e7b8 --- /dev/null +++ b/colossalai/legacy/utils/profiler/legacy/prof_utils.py @@ -0,0 +1,132 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List, Union + +from colossalai.legacy.core import global_context as gpc + + +# copied from high version pytorch to support low version +def _format_time(time_us): + """Defines how to format time in FunctionEvent""" + US_IN_SECOND = 1000.0 * 1000.0 + US_IN_MS = 1000.0 + if time_us >= US_IN_SECOND: + return "{:.3f}s".format(time_us / US_IN_SECOND) + if time_us >= US_IN_MS: + return "{:.3f}ms".format(time_us / US_IN_MS) + return "{:.3f}us".format(time_us) + + +# copied from high version pytorch to support low version +def _format_memory(nbytes): + """Returns a formatted memory size string""" + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + if abs(nbytes) >= GB: + return "{:.2f} GB".format(nbytes * 1.0 / GB) + elif abs(nbytes) >= MB: + return "{:.2f} MB".format(nbytes * 1.0 / MB) + elif abs(nbytes) >= KB: + return "{:.2f} KB".format(nbytes * 1.0 / KB) + else: + return str(nbytes) + " B" + + +def _format_bandwidth(volume: float or int, time_us: int): + sec_div_mb = (1000.0 / 1024.0) ** 2 + mb_per_sec = volume / time_us * sec_div_mb + + if mb_per_sec >= 1024.0: + return "{:.3f} GB/s".format(mb_per_sec / 1024.0) + else: + return "{:.3f} MB/s".format(mb_per_sec) + + +class BaseProfiler(ABC): + def __init__(self, profiler_name: str, priority: int): + self.name = profiler_name + self.priority = priority + + @abstractmethod + def enable(self): + pass + + @abstractmethod + def disable(self): + pass + + @abstractmethod + def to_tensorboard(self, writer): + pass + + @abstractmethod + def to_file(self, filename: Path): + pass + + @abstractmethod + def show(self): + pass + + +class ProfilerContext(object): + """Profiler context manager + + Usage:: + + world_size = 4 + inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device()) + outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device()) + outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0)) + + cc_prof = CommProfiler() + + with ProfilerContext([cc_prof]) as prof: + op = dist.all_reduce(inputs, async_op=True) + dist.all_gather(outputs_list, inputs) + op.wait() + dist.reduce_scatter(inputs, outputs_list) + dist.broadcast(inputs, 0) + dist.reduce(inputs, 0) + + prof.show() + """ + + def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True): + self.enable = enable + self.profilers = sorted(profilers, key=lambda prof: prof.priority) + + def __enter__(self): + if self.enable: + for prof in self.profilers: + prof.enable() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enable: + for prof in self.profilers: + prof.disable() + + def to_tensorboard(self, writer): + from torch.utils.tensorboard import SummaryWriter + + assert isinstance( + writer, SummaryWriter + ), f"torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}." + + for prof in self.profilers: + prof.to_tensorboard(writer) + + def to_file(self, log_dir: Union[str, Path]): + if isinstance(log_dir, str): + log_dir = Path(log_dir) + + if not log_dir.exists(): + log_dir.mkdir(parents=True, exist_ok=True) + for prof in self.profilers: + log_file = log_dir.joinpath(f"{prof.name}_rank_{gpc.get_global_rank()}.log") + prof.to_file(log_file) + + def show(self): + for prof in self.profilers: + prof.show() diff --git a/colossalai/utils/profiler/profiler.py b/colossalai/legacy/utils/profiler/profiler.py similarity index 79% rename from colossalai/utils/profiler/profiler.py rename to colossalai/legacy/utils/profiler/profiler.py index 8f43a0b96de0a8deb4818644bb48df7ddaad896c..b7a75f25d951c38fc638e8679ee98becef629da8 100644 --- a/colossalai/utils/profiler/profiler.py +++ b/colossalai/legacy/utils/profiler/profiler.py @@ -1,16 +1,16 @@ -import os -from typing import List -from colossalai.engine import Engine -from torch.profiler import profile as torch_profile -from torch.profiler.profiler import ProfilerAction -from typing import Any, Callable, Iterable, Optional -from torch.autograd import ProfilerActivity +import gzip import json import os import tempfile -import gzip -from colossalai.utils.profiler.extention import ProfilerExtension -from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention +from typing import Any, Callable, Iterable, List, Optional + +from torch.autograd import ProfilerActivity +from torch.profiler import profile as torch_profile +from torch.profiler.profiler import ProfilerAction + +from colossalai.legacy.engine import Engine +from colossalai.legacy.utils.profiler.extention import ProfilerExtension +from colossalai.legacy.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention from colossalai.logging import get_dist_logger @@ -120,26 +120,30 @@ class profile(torch_profile): p.step() """ - def __init__(self, - *, - activities: Optional[Iterable[ProfilerActivity]] = None, - schedule: Optional[Callable[[int], ProfilerAction]] = None, - on_trace_ready: Optional[Callable[..., Any]] = None, - engine: Optional[Engine] = None, - record_shapes: bool = False, - profile_memory: bool = False, - with_stack: bool = False, - with_flops: bool = False, - with_modules: bool = False, - profile_stateful_tensor_memory: bool = False) -> None: - super().__init__(activities=activities, - schedule=schedule, - on_trace_ready=on_trace_ready, - record_shapes=record_shapes, - profile_memory=profile_memory, - with_stack=with_stack, - with_flops=with_flops, - with_modules=with_modules) + def __init__( + self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + schedule: Optional[Callable[[int], ProfilerAction]] = None, + on_trace_ready: Optional[Callable[..., Any]] = None, + engine: Optional[Engine] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + profile_stateful_tensor_memory: bool = False, + ) -> None: + super().__init__( + activities=activities, + schedule=schedule, + on_trace_ready=on_trace_ready, + record_shapes=record_shapes, + profile_memory=profile_memory, + with_stack=with_stack, + with_flops=with_flops, + with_modules=with_modules, + ) self._logger = get_dist_logger() self.extentions: List[ProfilerExtension] = [] if profile_stateful_tensor_memory: @@ -149,9 +153,9 @@ class profile(torch_profile): self.extentions.append(StatefulTensorMemoryProfilerExtention(engine)) def prepare_trace(self) -> None: - if hasattr(super(), 'prepare_trace'): + if hasattr(super(), "prepare_trace"): super().prepare_trace() - elif hasattr(super(), '_start_warmup'): + elif hasattr(super(), "_start_warmup"): super()._start_warmup() for ext in self.extentions: ext.prepare_trace() @@ -160,9 +164,9 @@ class profile(torch_profile): self.prepare_trace() def start_trace(self): - if hasattr(super(), '_start_trace'): + if hasattr(super(), "_start_trace"): super()._start_trace() - elif hasattr(super(), 'start_trace'): + elif hasattr(super(), "start_trace"): super().start_trace() for ext in self.extentions: ext.start_trace() @@ -171,9 +175,9 @@ class profile(torch_profile): self.start_trace() def stop_trace(self): - if hasattr(super(), '_stop_trace'): + if hasattr(super(), "_stop_trace"): super()._stop_trace() - elif hasattr(super(), 'stop_trace'): + elif hasattr(super(), "stop_trace"): super().stop_trace() for ext in self.extentions: ext.stop_trace() @@ -186,15 +190,15 @@ class profile(torch_profile): Exports the collected trace in Chrome JSON format. """ assert self.profiler - fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False) + fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False) fp.close() retvalue = self.profiler.export_chrome_trace(fp.name) with open(fp.name) as fin: trace = json.load(fin) for ext in self.extentions: trace = ext.extend_chrome_trace(trace) - open_func = gzip.open if path.endswith('.gz') else open - with open_func(path, 'wt') as fout: + open_func = gzip.open if path.endswith(".gz") else open + with open_func(path, "wt") as fout: json.dump(trace, fout) os.remove(fp.name) diff --git a/colossalai/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py similarity index 82% rename from colossalai/utils/profiler/stateful_tensor_mem_extention.py rename to colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py index 127055c8c1efa5e193ee04cbf6467df35a5972a7..9247a9b80772d5b9e94d1c08fdb8771555ec8145 100644 --- a/colossalai/utils/profiler/stateful_tensor_mem_extention.py +++ b/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py @@ -1,13 +1,15 @@ import os import threading import time -import torch from enum import Enum from typing import List -from colossalai.gemini.stateful_tensor import StatefulTensor + +import torch + from colossalai.gemini.ophooks import BaseOpHook -from colossalai.engine import Engine -from colossalai.utils.profiler.extention import ProfilerExtension +from colossalai.gemini.stateful_tensor import StatefulTensor +from colossalai.legacy.engine import Engine +from colossalai.legacy.utils.profiler.extention import ProfilerExtension class DeviceType(Enum): @@ -20,11 +22,11 @@ def get_timestamp_us(): def generic_instant_event(name, pid, tid, timestamp, args): - return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args} + return {"ph": "i", "s": "t", "name": name, "pid": pid, "tid": tid, "ts": timestamp, "args": args} class StatefulTensorMemoryEvent: - EVENT_NAME = '[statefulTensorMemory]' + EVENT_NAME = "[statefulTensorMemory]" def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None: self.pid = os.getpid() @@ -35,22 +37,23 @@ class StatefulTensorMemoryEvent: self.bytes = bytes_ def state_dict(self): - return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, { - 'Device Type': self.device_type.value, - 'Device Id': self.device_id, - 'Bytes': self.bytes - }) + return generic_instant_event( + StatefulTensorMemoryEvent.EVENT_NAME, + self.pid, + self.tid, + self.timestamp, + {"Device Type": self.device_type.value, "Device Id": self.device_id, "Bytes": self.bytes}, + ) class StatefulTensorMemoryTracer: - def __init__(self) -> None: self.events: List[StatefulTensorMemoryEvent] = [] self._tracing = False def sample(self): - cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] - cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] + cuda_mem = StatefulTensor.GST_MGR.total_mem["cuda"] + cpu_mem = StatefulTensor.GST_MGR.total_mem["cpu"] timestamp = get_timestamp_us() if self._tracing: self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem)) @@ -68,7 +71,6 @@ class StatefulTensorMemoryTracer: class StatefulTensorMemoryTracerHook(BaseOpHook): - def __init__(self, tracer: StatefulTensorMemoryTracer): super().__init__() self.tracer = tracer @@ -102,7 +104,6 @@ class StatefulTensorMemoryTracerHook(BaseOpHook): class StatefulTensorMemoryProfilerExtention(ProfilerExtension): - def __init__(self, engine: Engine) -> None: self.engine = engine self.tracer = StatefulTensorMemoryTracer() @@ -129,5 +130,5 @@ class StatefulTensorMemoryProfilerExtention(ProfilerExtension): # self.hook_registered = False def extend_chrome_trace(self, trace: dict) -> dict: - trace['traceEvents'].extend(self.tracer.state_dict()) + trace["traceEvents"].extend(self.tracer.state_dict()) return trace diff --git a/colossalai/legacy/zero/__init__.py b/colossalai/legacy/zero/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..760fd529f3a6cbe38e56ad083cc8ef65e72a49dd --- /dev/null +++ b/colossalai/legacy/zero/__init__.py @@ -0,0 +1,52 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from colossalai.logging import get_dist_logger + +from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator +from .shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from .sharded_model import ShardedModelV2 +from .sharded_optim import ShardedOptimizerV2 + + +def convert_to_zero_v2( + model: nn.Module, optimizer: torch.optim.Optimizer, model_config, optimizer_config +) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: + """ + A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading + + :param model: Your model object + :type model: :class:`torch.nn.Module` + :param optimizer_config: Your optimizer object + :type optimizer_config: :class:`dict` + + :return: (model, optimizer) + :rtype: Tuple + """ + + logger = get_dist_logger("convert_to_zero_v2") + + logger.info(f"optimizer_config is {optimizer_config}", ranks=[0]) + if optimizer_config is None: + optimizer_config = dict() + logger.info(f"model_config is {model_config}", ranks=[0]) + if model_config is None: + model_config = dict() + + zero_model = ShardedModelV2(model, **model_config) + zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config) + return zero_model, zero_optimizer + + +__all__ = [ + "convert_to_zero_v2", + "ShardedModelV2", + "ShardedOptimizerV2", + "ZeroInitContext", + "no_shard_zero_context", + "no_shard_zero_decrator", + "TensorShardStrategy", + "BucketTensorShardStrategy", +] diff --git a/colossalai/legacy/zero/gemini/__init__.py b/colossalai/legacy/zero/gemini/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b272980d34d81bac0c0352e629bf870f211734d2 --- /dev/null +++ b/colossalai/legacy/zero/gemini/__init__.py @@ -0,0 +1,14 @@ +from .ophooks import BaseOpHook, register_ophooks_recursively +from .stateful_tensor import StatefulTensor +from .stateful_tensor_mgr import StatefulTensorMgr +from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy + +__all__ = [ + "StatefulTensorMgr", + "StatefulTensor", + "CPUTensorPlacementPolicy", + "CUDATensorPlacementPolicy", + "AutoTensorPlacementPolicy", + "register_ophooks_recursively", + "BaseOpHook", +] diff --git a/colossalai/legacy/zero/gemini/gemini_context.py b/colossalai/legacy/zero/gemini/gemini_context.py new file mode 100644 index 0000000000000000000000000000000000000000..9e82d948fba74881c85e4e99afaaed1696261dfc --- /dev/null +++ b/colossalai/legacy/zero/gemini/gemini_context.py @@ -0,0 +1,51 @@ +from enum import EnumMeta + + +class GeminiMemoryManager(object): + def __init__(self, states_cls: EnumMeta): + super().__init__() + self.states_cls = states_cls + self._cnter = 0 # the counter of instances + + self.total_mem = dict() + self.state_mem = dict() + self.state_mem["cpu"] = dict() + self.state_mem["cuda"] = dict() + + self.reset() + + @property + def total_number(self): + return self._cnter + + def reset(self): + self._cnter = 0 # the counter of instances + + self.total_mem["cpu"] = 0 # memory occupation of instances in cpu + self.total_mem["cuda"] = 0 # memory of occupation of instances in cuda + + # memory conditions for all states + for state in self.states_cls: + self.state_mem["cpu"][state] = 0 + self.state_mem["cuda"][state] = 0 + + def register_new_instance(self): + self._cnter += 1 + + def delete_instance(self): + self._cnter -= 1 + + def print_info(self): + print( + f"Total number: {self.total_number}", + f"Total CPU memory occupation: {self.total_mem['cpu']}", + f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", + sep="\n", + ) + + for state in self.states_cls: + print( + f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}", + f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", + sep="\n", + ) diff --git a/colossalai/zero/legacy/gemini/ophooks/__init__.py b/colossalai/legacy/zero/gemini/ophooks/__init__.py similarity index 100% rename from colossalai/zero/legacy/gemini/ophooks/__init__.py rename to colossalai/legacy/zero/gemini/ophooks/__init__.py diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py b/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py similarity index 87% rename from colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py rename to colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py index 8f8fec64924ea325f88cf282cdf98dba8cf731f1..4129b14bcae9d7d96317f77c03825798f16ce52d 100644 --- a/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py +++ b/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py @@ -1,6 +1,6 @@ import torch -from colossalai.registry import OPHOOKS +from colossalai.legacy.registry import OPHOOKS from . import BaseOpHook @@ -22,7 +22,7 @@ class ShardGradMemTracerHook(BaseOpHook): def pre_bwd_exec(self, module: torch.nn.Module, input, output): for param in module.parameters(): - assert hasattr(param, '_sharded_grad') + assert hasattr(param, "_sharded_grad") param._sharded_grad.setup() def post_bwd_exec(self, module: torch.nn.Module, input): diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py b/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py similarity index 83% rename from colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py rename to colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py index a2a62fb9788a9de973ae28e535b81eb542248c3a..e0c83eec0445b41df1531c33d7f4d3301476eea8 100644 --- a/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py +++ b/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py @@ -1,6 +1,6 @@ import torch -from colossalai.registry import OPHOOKS +from colossalai.legacy.registry import OPHOOKS from . import BaseOpHook @@ -19,25 +19,25 @@ class ShardParamHook(BaseOpHook): def pre_fwd_exec(self, module: torch.nn.Module, *args): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.gather() param.data = param.ca_attr.payload() def post_fwd_exec(self, module: torch.nn.Module, *args): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.shard() param.data = param.ca_attr.payload() def pre_bwd_exec(self, module: torch.nn.Module, input, output): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.gather() param.data = param.ca_attr.payload() def post_bwd_exec(self, module: torch.nn.Module, input): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.shard() param.data = param.ca_attr.payload() diff --git a/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py similarity index 93% rename from colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py rename to colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py index f40d6ced1ee09a4aa2edca2dd5302111647937ce..57076063cb3f8ece04dccbe7a9dd3579679f0b85 100644 --- a/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py +++ b/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py @@ -5,9 +5,9 @@ from typing import List import torch +from colossalai.legacy.zero.gemini.tensor_utils import alloc_storage, free_storage from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor -from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage class TrainingPhase(Enum): @@ -15,8 +15,7 @@ class TrainingPhase(Enum): BACKWARD = 1 -class GradMemStats(): - +class GradMemStats: def __init__(self) -> None: self.unreleased_grad_flag = {} self.unreleased_grad_volume = 0 @@ -26,8 +25,7 @@ class GradMemStats(): self.unreleased_grad_volume = 0 -class GradMemTracerHook(): - +class GradMemTracerHook: def __init__(self, grad_stats: GradMemStats): self.grad_hook_list = [] self._grad_stats = grad_stats @@ -50,7 +48,6 @@ class GradMemTracerHook(): class ParamMemTracerHook(ColoParamOpHook): - def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None: super().__init__() self._training_phase = TrainingPhase.FORWARD @@ -79,10 +76,9 @@ class ParamMemTracerHook(ColoParamOpHook): if cur_dev == "cpu": if p.grad is not None and p.grad.device.type == "cpu": raise NotImplementedError("Only run in forward propagation") - p.data = torch.empty(p.data.shape, - device="cuda", - dtype=p.data.dtype, - requires_grad=p.data.requires_grad) + p.data = torch.empty( + p.data.shape, device="cuda", dtype=p.data.dtype, requires_grad=p.data.requires_grad + ) elif cur_dev == "cuda": alloc_storage(p.data) diff --git a/colossalai/legacy/zero/gemini/ophooks/utils.py b/colossalai/legacy/zero/gemini/ophooks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..057906156d8d68e95370cb82a352a4762bb22861 --- /dev/null +++ b/colossalai/legacy/zero/gemini/ophooks/utils.py @@ -0,0 +1,137 @@ +# this code is inspired by the DeepSpeed library and implemented with our own design from scratch +from abc import ABC, abstractmethod +from typing import Callable, List, Optional + +import torch + + +class BaseOpHook(ABC): + """This class allows users to add customized operations + before and after the execution of a PyTorch submodule""" + + def __init__(self): + pass + + @abstractmethod + def pre_fwd_exec(self, module: torch.nn.Module, *args): + pass + + @abstractmethod + def post_fwd_exec(self, module: torch.nn.Module, *args): + pass + + @abstractmethod + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + pass + + @abstractmethod + def post_bwd_exec(self, module: torch.nn.Module, input): + pass + + @abstractmethod + def post_iter(self): + pass + + +# apply torch.autograd.Function that calls a backward_function to tensors in output +def _apply_to_tensors_only(module, functional, backward_function, outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_to_tensors_only(module, functional, backward_function, output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + return functional.apply(module, backward_function, outputs) + else: + return outputs + + +class PreBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, outputs): + ctx.module = module + ctx.pre_backward_function = pre_backward_function + module.applied_pre_backward = False + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +class PostBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, output): + ctx.module = module + output = output.detach() + ctx.pre_backward_function = pre_backward_function + return output + + @staticmethod + def backward(ctx, *args): + """ + Args: + activation_grad of the next layer. + Returns: + grad of the input activation. + """ + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +def register_ophooks_recursively( + module: torch.nn.Module, ophook_list: List[BaseOpHook], name: str = "", filter_fn: Optional[Callable] = None +): + r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD.""" + assert isinstance(module, torch.nn.Module) + assert isinstance(ophook_list, (list, tuple)) + assert len(ophook_list) > 0, "expected at least 1 hook in the argument ophook_list but found 0" + for hook in ophook_list: + assert isinstance(hook, BaseOpHook) + + # Add hooks for submodules + for child_name, child in module.named_children(): + register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn) + + # Early return on modules with no parameters. + if len(list(module.parameters(recurse=False))) == 0: + return + + # return from filtered module + if filter_fn is not None and filter_fn(module): + return + + def _pre_forward_module_hook(submodule, *args): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.pre_fwd_exec(submodule, *args) + + def _post_forward_module_hook(submodule, *args): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.post_fwd_exec(submodule, *args) + + def _pre_backward_module_hook(submodule, inputs, output): + def _run_before_backward_function(submodule): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.pre_bwd_exec(submodule, inputs, output) + + return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output) + + def _post_backward_module_hook(submodule, inputs): + def _run_after_backward_function(submodule): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.post_bwd_exec(submodule, inputs) + + return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs) + + module.register_forward_pre_hook(_pre_forward_module_hook) + module.register_forward_hook(_post_forward_module_hook) + + module.register_forward_hook(_pre_backward_module_hook) + module.register_forward_pre_hook(_post_backward_module_hook) diff --git a/colossalai/zero/legacy/gemini/paramhooks/__init__.py b/colossalai/legacy/zero/gemini/paramhooks/__init__.py similarity index 100% rename from colossalai/zero/legacy/gemini/paramhooks/__init__.py rename to colossalai/legacy/zero/gemini/paramhooks/__init__.py diff --git a/colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py b/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py similarity index 83% rename from colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py rename to colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py index 84f32be358e3b844ae7db394022d5b9a077352a7..91c7bdc2961b468fa7615ccee6c1d5952ff3981d 100644 --- a/colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py +++ b/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py @@ -5,7 +5,6 @@ import torch class BaseParamHookMgr(object): - def __init__(self, param_list: List[torch.nn.Parameter]) -> None: r""" register backward hook on every parameters of module @@ -23,9 +22,9 @@ class BaseParamHookMgr(object): ``` """ if not torch.is_grad_enabled(): - return # don't register grad hooks if grad isn't enabled + return # don't register grad hooks if grad isn't enabled for p in self._param_list: - if p.requires_grad and not hasattr(p, '_base_param_hook'): + if p.requires_grad and not hasattr(p, "_base_param_hook"): handle = p.register_hook(functools.partial(hook_call, p)) p._base_param_hook = handle @@ -35,5 +34,5 @@ class BaseParamHookMgr(object): """ for p in self._param_list: - if p.requires_grad and hasattr(p, '_base_param_hook'): + if p.requires_grad and hasattr(p, "_base_param_hook"): p._base_param_hook.remove() diff --git a/colossalai/zero/legacy/gemini/stateful_tensor.py b/colossalai/legacy/zero/gemini/stateful_tensor.py similarity index 96% rename from colossalai/zero/legacy/gemini/stateful_tensor.py rename to colossalai/legacy/zero/gemini/stateful_tensor.py index 1619ae40798d17e1fbff351e5678017dd54cf049..668d344132d0796c9d224c0fe4e3d7bd9f80a482 100644 --- a/colossalai/zero/legacy/gemini/stateful_tensor.py +++ b/colossalai/legacy/zero/gemini/stateful_tensor.py @@ -25,13 +25,14 @@ class StatefulTensor(object): https://arxiv.org/abs/2108.05818 """ + # Global Stateful Tensor Manager GST_MGR = GeminiMemoryManager(TensorState) def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None: self._state = state self._payload = None - self._payload_size = 0 # byte size of current payload + self._payload_size = 0 # byte size of current payload StatefulTensor.GST_MGR.register_new_instance() @@ -47,7 +48,7 @@ class StatefulTensor(object): def data_ptr(self): if self._payload is None: - return 0 # if a tensor has no storage, 0 should be returned + return 0 # if a tensor has no storage, 0 should be returned return self._payload.data_ptr() def set_null(self) -> None: @@ -80,7 +81,7 @@ class StatefulTensor(object): assert self.state is not TensorState.FREE, "Can't move free stateful tensor" if not isinstance(device, torch.device): - to_device = torch.device('cuda', device) + to_device = torch.device("cuda", device) else: to_device = device @@ -97,7 +98,6 @@ class StatefulTensor(object): self._payload.view(-1).copy_(tensor.view(-1)) def payload_reset(self, tensor) -> None: - assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead" if self.payload is not None: @@ -168,8 +168,7 @@ class StatefulTensor(object): self._payload_size = 0 def __trans_state_update(self, from_state: TensorState, to_state: TensorState): - """Update global manager when changing the state of a tensor - """ + """Update global manager when changing the state of a tensor""" manager = StatefulTensor.GST_MGR size = self.payload_size device_type = self.device.type @@ -189,8 +188,7 @@ class StatefulTensor(object): manager.total_mem[device_type] -= size def __trans_device_update(self, from_type: str, to_type: str): - """Update global manager when changing the device of a tensor - """ + """Update global manager when changing the device of a tensor""" manager = StatefulTensor.GST_MGR size = self.payload_size state = self.state diff --git a/colossalai/zero/legacy/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py similarity index 80% rename from colossalai/zero/legacy/gemini/stateful_tensor_mgr.py rename to colossalai/legacy/zero/gemini/stateful_tensor_mgr.py index 4f9ea7c6d5202238def5fc5be6d8a38724932860..19f77d4305afe996d68b9c4e0141db7d72779663 100644 --- a/colossalai/zero/legacy/gemini/stateful_tensor_mgr.py +++ b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py @@ -3,14 +3,11 @@ import types from time import time from typing import List -import torch - -from colossalai.logging import get_dist_logger from colossalai.utils.cuda import get_current_device from .stateful_tensor import StatefulTensor, TensorState from .tensor_placement_policy import TensorPlacementPolicy -from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from .tensor_utils import colo_model_data_tensor_move_inline class StatefulTensorMgr(object): @@ -44,8 +41,7 @@ class StatefulTensorMgr(object): pass def finish_iter(self): - """This function must be called when each iteration finishes - """ + """This function must be called when each iteration finishes""" self._warmup = False self._compute_idx = -1 self._cpu_gpu_move_volume = 0 @@ -53,19 +49,21 @@ class StatefulTensorMgr(object): self._evict_time = 0 def adjust_layout(self) -> None: - """ Adjust the layout of stateful tensor according to the information provided + """Adjust the layout of stateful tensor according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. """ # find stateful tensor in state COMPUTE - cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE] + cuda_demand = StatefulTensor.GST_MGR.state_mem["cpu"][TensorState.COMPUTE] start = time() move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup) self._layout_time += time() - start - vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list, - cuda_demand=cuda_demand, - warmup=self._warmup, - compute_list=self._compute_list, - compute_idx=self._compute_idx) + vol, evict_time = self._tensor_placement_policy.evict_tensors( + hold_cuda_tensor_list, + cuda_demand=cuda_demand, + warmup=self._warmup, + compute_list=self._compute_list, + compute_idx=self._compute_idx, + ) self._cpu_gpu_move_volume += vol self._evict_time += evict_time # move COMPUTE tensors to CUDA @@ -92,10 +90,10 @@ class StatefulTensorMgr(object): if tensor.state == TensorState.FREE: continue - if tensor.device.type == 'cuda': + if tensor.device.type == "cuda": if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]: hold_cuda_tensor_list.append(tensor) - elif tensor.device.type == 'cpu': + elif tensor.device.type == "cpu": if tensor.state == TensorState.COMPUTE: move_to_cuda_tensor_list.append(tensor) else: diff --git a/colossalai/zero/legacy/gemini/tensor_placement_policy.py b/colossalai/legacy/zero/gemini/tensor_placement_policy.py similarity index 85% rename from colossalai/zero/legacy/gemini/tensor_placement_policy.py rename to colossalai/legacy/zero/gemini/tensor_placement_policy.py index 165ae51fee60e2292a909ef4ffe8646fdae67353..3aca80cfe56af342c867dfb1250673526121e15d 100644 --- a/colossalai/zero/legacy/gemini/tensor_placement_policy.py +++ b/colossalai/legacy/zero/gemini/tensor_placement_policy.py @@ -5,16 +5,15 @@ from typing import List, Optional, Type import torch +from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.utils import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.memory_tracer import MemStatsCollector from .stateful_tensor import StatefulTensor -from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from .tensor_utils import colo_model_data_tensor_move_inline class TensorPlacementPolicy(ABC): - def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None: self.device: Optional[torch.device] = device self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector @@ -25,9 +24,8 @@ class TensorPlacementPolicy(ABC): class CPUTensorPlacementPolicy(TensorPlacementPolicy): - def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: - super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector) + super().__init__(torch.device("cpu"), mem_stats_collector=mem_stats_collector) def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: volume = 0 @@ -38,9 +36,8 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy): class CUDATensorPlacementPolicy(TensorPlacementPolicy): - def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: - assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' + assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available" super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector) def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: @@ -48,7 +45,6 @@ class CUDATensorPlacementPolicy(TensorPlacementPolicy): class AutoTensorPlacementPolicy(TensorPlacementPolicy): - def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: super().__init__(None, mem_stats_collector=mem_stats_collector) # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase @@ -56,13 +52,15 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy): self._warmup_non_model_data_ratio: float = 0.8 self._steady_cuda_cap_ratio: float = 0.9 - def evict_tensors(self, - hold_cuda_tensor_list: List[StatefulTensor], - cuda_demand: int = 0, - warmup: bool = True, - compute_list: List[StatefulTensor] = [], - compute_idx: int = 0, - **kwargs) -> int: + def evict_tensors( + self, + hold_cuda_tensor_list: List[StatefulTensor], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: List[StatefulTensor] = [], + compute_idx: int = 0, + **kwargs, + ) -> int: """ Evict tensors from CUDA device. @@ -81,13 +79,13 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy): """ start = time() cuda_capacity = colo_device_memory_capacity(get_current_device()) - used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda'] + used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio else: # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. - max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda') + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") cuda_capacity *= self._steady_cuda_cap_ratio total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data @@ -99,15 +97,16 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy): to_free_cuda_model_data = cuda_demand - avail_cuda_model_data to_free_tensor_list = hold_cuda_tensor_list if not warmup: - to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx, - tuple(compute_list)) + to_free_tensor_list = self._sort_hold_cuda_tensors( + tuple(hold_cuda_tensor_list), compute_idx, tuple(compute_list) + ) # print(self._sort_hold_cuda_tensors.cache_info()) end = time() for t in to_free_tensor_list: if freed_cuda_model_data >= to_free_cuda_model_data: break freed_cuda_model_data += t.payload_size - colo_model_data_tensor_move_inline(t, torch.device('cpu')) + colo_model_data_tensor_move_inline(t, torch.device("cpu")) if freed_cuda_model_data < to_free_cuda_model_data: raise RuntimeError( f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" @@ -126,14 +125,13 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy): class TensorPlacementPolicyFactory: - @staticmethod def create(policy_name: str) -> Type[TensorPlacementPolicy]: - if policy_name == 'cpu': + if policy_name == "cpu": return CPUTensorPlacementPolicy - elif policy_name == 'cuda': + elif policy_name == "cuda": return CUDATensorPlacementPolicy - elif policy_name == 'auto': + elif policy_name == "auto": return AutoTensorPlacementPolicy else: raise TypeError(f"Unknown tensor placement policy {policy_name}") diff --git a/colossalai/zero/legacy/gemini/tensor_utils.py b/colossalai/legacy/zero/gemini/tensor_utils.py similarity index 78% rename from colossalai/zero/legacy/gemini/tensor_utils.py rename to colossalai/legacy/zero/gemini/tensor_utils.py index b7f23e0253fdd1123306a32801d1cbc32884bc73..6e51dee6ef945daae570a853cb7d817f2012c975 100644 --- a/colossalai/zero/legacy/gemini/tensor_utils.py +++ b/colossalai/legacy/zero/gemini/tensor_utils.py @@ -30,16 +30,17 @@ def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[ cuda_use, cpu_use = 0, 0 mem_use = t.storage().size() * t.element_size() - if t.device.type == 'cuda': + if t.device.type == "cuda": cuda_use += mem_use - elif t.device.type == 'cpu': + elif t.device.type == "cpu": cpu_use += mem_use return cuda_use, cpu_use -def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor, - torch.Tensor]) -> None: +def colo_model_data_tensor_move( + src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor, torch.Tensor] +) -> None: """ A colossal API for model data tensor move. The src and target tensors could be resident on both CPU and GPU. @@ -71,23 +72,24 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_ src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype) -def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device, - int]) -> None: +def colo_model_data_tensor_move_inline( + t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device, int] +) -> None: """ move a tensor to the target_device Args: t (Union[StatefulTensor, torch.Tensor]): the tensor be moved - target_device: a traget device, if type is int, it the index of cuda card. + target_device: a target device, if type is int, it the index of cuda card. """ if not isinstance(target_device, torch.device): - target_device = torch.device(f'cuda:{target_device}') + target_device = torch.device(f"cuda:{target_device}") if isinstance(t, torch.Tensor): t.data = t.data.to(target_device) elif isinstance(t, StatefulTensor): t.move_to(target_device) else: - raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}') + raise TypeError(f"colo_model_data_tensor_move_inline dose not accept type {type(t)}") def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None: @@ -100,9 +102,9 @@ def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None: if isinstance(t, torch.Tensor): t.data = t.data.cpu() elif isinstance(t, StatefulTensor): - t.move_to(torch.device('cpu')) + t.move_to(torch.device("cpu")) else: - raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}') + raise TypeError(f"colo_model_data_move_to_cpu dose not accept type {type(t)}") def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor: diff --git a/colossalai/legacy/zero/init_ctx/__init__.py b/colossalai/legacy/zero/init_ctx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..28ce72a18b31f347bb37267e5ec3d19f09f5977e --- /dev/null +++ b/colossalai/legacy/zero/init_ctx/__init__.py @@ -0,0 +1,3 @@ +from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator + +__all__ = ["ZeroInitContext", "no_shard_zero_context", "no_shard_zero_decrator"] diff --git a/colossalai/zero/legacy/init_ctx/init_context.py b/colossalai/legacy/zero/init_ctx/init_context.py similarity index 79% rename from colossalai/zero/legacy/init_ctx/init_context.py rename to colossalai/legacy/zero/init_ctx/init_context.py index a921ca0aa83a5d588c5339b799f5f3ee05386feb..6c5a8122ef80328f9a89f9f62371f3ccd1288a6e 100644 --- a/colossalai/zero/legacy/init_ctx/init_context.py +++ b/colossalai/legacy/zero/init_ctx/init_context.py @@ -8,15 +8,15 @@ import torch import torch.distributed as dist import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode from colossalai.context.singleton_meta import SingletonMeta -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.zero.shard_utils import BaseShardStrategy +from colossalai.legacy.zero.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16 +from colossalai.legacy.zero.sharded_model.sharded_model_v2 import ShardedModelV2 +from colossalai.legacy.zero.sharded_param import ShardedParamV2 from colossalai.logging import get_dist_logger from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses -from colossalai.zero.legacy.shard_utils import BaseShardStrategy -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 -from colossalai.zero.legacy.sharded_param import ShardedParamV2 @dataclass @@ -39,14 +39,14 @@ class ZeroContextConfig: assert self.is_replicated, "Non-replicated parameters can't be sharded." if self.is_replicated and not self.shard_param: - assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda." + assert self.target_device.type == "cuda", "Replicated no-shard parameters should be located in cuda." class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): """A context to initialize model. 1. Convert the model to fp16. - 2. The paramaters of the module are adapted to type ShardedParameter. + 2. The parameters of the module are adapted to type ShardedParameter. 3. Shard the param and grad according to flags. Args: @@ -55,22 +55,26 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): seed (int, optional): Random seed for weight initialization shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16. + bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False. model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int). """ - def __init__(self, - target_device: torch.device, - shard_strategy: BaseShardStrategy, - seed: int = 2**10 - 1, - shard_param: bool = False, - default_dtype: Optional[torch.dtype] = None, - model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)): - + def __init__( + self, + target_device: torch.device, + shard_strategy: BaseShardStrategy, + seed: int = 2**10 - 1, + shard_param: bool = False, + default_dtype: Optional[torch.dtype] = None, + bf16: bool = False, + model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long), + ): super().__init__(default_dtype=default_dtype) self.shard_strategy = shard_strategy self.param_list = [] self.model_numel_tensor = model_numel_tensor self.seed = seed + self.bf16 = bf16 self.dp_process_group = gpc.get_group(ParallelMode.DATA) self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param) @@ -100,7 +104,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): assert isinstance(tensor, nn.Parameter), "Sharded tensor initialization is only allowed for parameters" # get correct shape of input tensor - if not hasattr(tensor, 'colo_attr') or not tensor.colo_attr.param_is_sharded: + if not hasattr(tensor, "colo_attr") or not tensor.colo_attr.param_is_sharded: tensor_shape = tensor.shape else: tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape @@ -134,13 +138,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): self.module_load_from_state_dict = nn.Module._load_from_state_dict shard_strategy = self.shard_strategy if self.config.shard_param else None - nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict, - shard_strategy=shard_strategy) + nn.Module._load_from_state_dict = functools.partialmethod( + ShardedModelV2._colo_load_from_state_dict, shard_strategy=shard_strategy + ) self.module_state_dict = nn.Module.state_dict - nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict, - shard_strategy=shard_strategy, - state_dict_func=self.module_state_dict, - process_group=self.dp_process_group) + nn.Module.state_dict = functools.partialmethod( + ShardedModelV2._colo_state_dict, + shard_strategy=shard_strategy, + state_dict_func=self.module_state_dict, + process_group=self.dp_process_group, + ) # reserve rng states self.cpu_rng_state = torch.get_rng_state() @@ -149,16 +156,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): # set new seed for initialization, since we initialize sharded tensor separately # we don't want all processes have the same seed # otherwise all sharded tensors are same after init - offset = self.seed + 1 # we want to have more 1 in binary format seed + offset = self.seed + 1 # we want to have more 1 in binary format seed torch.manual_seed(self.seed + offset * dist.get_rank()) def _post_context_exec(self): - """The callback function when exiting context. - """ + """The callback function when exiting context.""" # broadcast replicated no-shard parameters src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] for param in self.param_list: - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated: dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group) param.colo_attr.set_data_none() @@ -183,13 +189,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): NOTE() The module may be passed to this function multiple times. """ self.top_module = module + half_dtype = torch.float16 if not self.bf16 else torch.bfloat16 def half_fn(t: torch.Tensor): - return t.half() if t.is_floating_point() else t + return t.to(half_dtype) if t.is_floating_point() else t for param in module.parameters(recurse=False): # avoid adapting a param to ShardedParam twice - if hasattr(param, 'colo_attr'): + if hasattr(param, "colo_attr"): continue self.param_numel[param] = param.numel() @@ -212,7 +219,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): if self.shard_param: self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group) - param.data = param.colo_attr.data_payload # set param.data to payload + param.data = param.colo_attr.data_payload # set param.data to payload # mark whether the param is replicated param.colo_attr.is_replicated = self.is_replicated @@ -226,9 +233,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): # We must cast buffers # If we use BN, buffers may be on CPU and Float # We must cast them + cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16 for buffer in module.buffers(recurse=False): buffer.data = buffer.data.to(device=torch.cuda.current_device()) - buffer.data = cast_tensor_to_fp16(buffer.data) + buffer.data = cast_fn(buffer.data) class ZeroContextMgr(metaclass=SingletonMeta): @@ -246,15 +254,13 @@ class ZeroContextMgr(metaclass=SingletonMeta): def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager: - return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()), - is_replicated=is_replicated, - shard_param=False) + return ZeroContextMgr().hijack_context_config( + target_device=torch.device("cuda", torch.cuda.current_device()), is_replicated=is_replicated, shard_param=False + ) def no_shard_zero_decrator(is_replicated: bool = True): - def _wrapper(init_func): - def _no_shard(*args, **kwargs): with no_shard_zero_context(is_replicated): ret = init_func(*args, **kwargs) diff --git a/colossalai/legacy/zero/shard_utils/__init__.py b/colossalai/legacy/zero/shard_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..945c77a412c1efbdcb074464de707ac389c70cb6 --- /dev/null +++ b/colossalai/legacy/zero/shard_utils/__init__.py @@ -0,0 +1,5 @@ +from .base_shard_strategy import BaseShardStrategy +from .bucket_tensor_shard_strategy import BucketTensorShardStrategy +from .tensor_shard_strategy import TensorShardStrategy + +__all__ = ["BaseShardStrategy", "TensorShardStrategy", "BucketTensorShardStrategy"] diff --git a/colossalai/zero/legacy/shard_utils/base_shard_strategy.py b/colossalai/legacy/zero/shard_utils/base_shard_strategy.py similarity index 86% rename from colossalai/zero/legacy/shard_utils/base_shard_strategy.py rename to colossalai/legacy/zero/shard_utils/base_shard_strategy.py index 7ca95109164028f1a1389ae0fa7547883af1e441..13e6f0e482980bdf88bd82314ea57f46c74d1e3e 100644 --- a/colossalai/zero/legacy/shard_utils/base_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/base_shard_strategy.py @@ -3,14 +3,12 @@ from typing import List, Optional import torch.distributed as dist -from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor +from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor class BaseShardStrategy(ABC): - def __init__(self) -> None: - """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs. - """ + """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.""" super().__init__() @abstractmethod diff --git a/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py similarity index 88% rename from colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py rename to colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py index d663104831ce79062fb6083e19401705efa64fd0..b9d3071a877e7bad6882436bef57a3ba3ef7fdec 100644 --- a/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -4,8 +4,8 @@ import torch import torch.distributed as dist from torch._utils import _flatten_dense_tensors as flatten +from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.utils import get_current_device -from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor from .tensor_shard_strategy import TensorShardStrategy @@ -18,7 +18,6 @@ class BucketTensorShardStrategy(TensorShardStrategy): """ def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): - tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded] if len(tensor_list) == 0: return @@ -40,8 +39,8 @@ class BucketTensorShardStrategy(TensorShardStrategy): buffer_list = [buffer.to(target_device) for buffer in buffer_list] offset = 0 for i, t in enumerate(tensor_list): - gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list] - gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape) + gathered_payload = [buffer[offset : offset + tensor_numels[i]] for buffer in buffer_list] + gathered_payload = torch.cat(gathered_payload)[: t.origin_numel].view(t.origin_shape) t.payload_reset(gathered_payload) t.is_sharded = False offset += tensor_numels[i] diff --git a/colossalai/zero/legacy/shard_utils/commons.py b/colossalai/legacy/zero/shard_utils/commons.py similarity index 100% rename from colossalai/zero/legacy/shard_utils/commons.py rename to colossalai/legacy/zero/shard_utils/commons.py diff --git a/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py similarity index 81% rename from colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py rename to colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py index d1df4803b820988d20e06ac37e653cc3bb7759c9..ebaef774bd063787446c87a3e55b0bb9a06fa109 100644 --- a/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py @@ -3,11 +3,11 @@ from typing import List, Optional import torch import torch.distributed as dist +from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline +from colossalai.legacy.zero.shard_utils import BaseShardStrategy +from colossalai.legacy.zero.shard_utils.commons import get_shard +from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.utils import get_current_device -from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline -from colossalai.zero.legacy.shard_utils import BaseShardStrategy -from colossalai.zero.legacy.shard_utils.commons import get_shard -from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor class TensorShardStrategy(BaseShardStrategy): @@ -24,7 +24,7 @@ class TensorShardStrategy(BaseShardStrategy): self._gather_tensor(t, process_group) def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None): - """ Shard tensor among processes. + """Shard tensor among processes. Args: t (ShardedTensor): a tensor to be sharded. @@ -33,9 +33,11 @@ class TensorShardStrategy(BaseShardStrategy): """ if t.is_sharded: return - if t.payload.device.type == 'cuda': - assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\ + if t.payload.device.type == "cuda": + assert t.payload.device == get_current_device(), ( + f"shard tensor on cuda device index {t.payload.device.index}," f" but current cuda device is {get_current_device()}" + ) sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) t.payload_reset(sharded_payload) t.is_sharded = True diff --git a/colossalai/legacy/zero/sharded_model/__init__.py b/colossalai/legacy/zero/sharded_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ecead2f6a657c890a170a1c44394c6f6405d2d3b --- /dev/null +++ b/colossalai/legacy/zero/sharded_model/__init__.py @@ -0,0 +1,3 @@ +from .sharded_model_v2 import ShardedModelV2 + +__all__ = ["ShardedModelV2"] diff --git a/colossalai/legacy/zero/sharded_model/_utils.py b/colossalai/legacy/zero/sharded_model/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1007623185932dded40a9366184aa985f9c131d1 --- /dev/null +++ b/colossalai/legacy/zero/sharded_model/_utils.py @@ -0,0 +1,85 @@ +from typing import Any, Callable, List, Tuple, Union + +import torch +import torch.nn.functional as F + +from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor + + +def get_gradient_predivide_factor(world_size: int) -> float: + factor: int = 1 + while world_size % factor == 0 and world_size / factor > factor: + factor *= 2 + return float(factor) + + +def free_storage(data: torch.Tensor) -> None: + """Free underlying storage of a Tensor.""" + if data.storage().size() > 0: + # Since we're modifying the Tensor's Storage directly, make sure the Tensor + # is the sole occupant of the Storage. + assert data.storage_offset() == 0 + data.storage().resize_(0) + + +@torch.no_grad() +def alloc_storage(data: torch.Tensor, size: torch.Size) -> None: + """Allocate storage for a tensor.""" + if data.storage().size() == size.numel(): # no need to reallocate + return + assert data.storage().size() == 0 + data.storage().resize_(size.numel()) + + +def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, StatefulTensor): + tensor = tensor.payload + if torch.is_floating_point(tensor) and tensor.dtype is torch.float32: + return tensor.half() + return tensor + + +def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor: + if isinstance(tensor, StatefulTensor): + tensor = tensor.payload + + if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16): + return tensor.float() + return tensor + + +def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, StatefulTensor): + tensor = tensor.payload + if torch.is_floating_point(tensor) and tensor.dtype is torch.float32: + return tensor.bfloat16() + return tensor + + +def apply_to_tensors(x: Any, fn: Callable): + if torch.is_tensor(x): + return fn(x) + elif isinstance(x, list): + return [apply_to_tensors(t, fn) for t in x] + elif isinstance(x, tuple): + return tuple(apply_to_tensors(t, fn) for t in x) + elif isinstance(x, dict): + return {key: apply_to_tensors(val, fn) for key, val in x.items()} + else: + return x + + +def cast_float_arguments(fn: Callable, *args: Any, **kwargs: Any) -> Tuple[Any, Any]: + return apply_to_tensors(args, fn), apply_to_tensors(kwargs, fn) + + +def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]: + """Chunk a given Tensor into num_chunks parts and add any necessary padding.""" + chunks = list(torch.flatten(tensor).chunk(num_chunks)) + # torch.chunk may return fewer than num_chunks chunks, pad accordingly. + num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel() + if num_pad_for_partial_chunk > 0: + chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk]) + if len(chunks) < num_chunks: + chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))]) + return chunks diff --git a/colossalai/zero/legacy/sharded_model/reduce_scatter.py b/colossalai/legacy/zero/sharded_model/reduce_scatter.py similarity index 89% rename from colossalai/zero/legacy/sharded_model/reduce_scatter.py rename to colossalai/legacy/zero/sharded_model/reduce_scatter.py index 4fb507382df9eae2d3efa35fdcdcb2704a9256dc..0f11365515d2e9076ced7e5d6e063693cbe26fe6 100644 --- a/colossalai/zero/legacy/sharded_model/reduce_scatter.py +++ b/colossalai/legacy/zero/sharded_model/reduce_scatter.py @@ -20,7 +20,6 @@ else: class Bucket: - def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup): self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device) self.group = group @@ -35,18 +34,18 @@ class Bucket: return # reduce-scatter bucket if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives: - dist._reduce_scatter_base(self.output_shard[:self.offset], - self.buffer[:, :self.offset].contiguous(), - group=self.group) + dist._reduce_scatter_base( + self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group + ) else: - dist.reduce_scatter(self.output_shard[:self.offset], - list(self.buffer[:, :self.offset].unbind(0)), - group=self.group) + dist.reduce_scatter( + self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group + ) # execute post-reduction callbacks for callback_fn in self.callbacks: callback_fn() # reuse input bucket but allocate a fresh output shard - self.buffer[:, :self.offset].zero_() + self.buffer[:, : self.offset].zero_() self.offset = 0 self.callbacks.clear() self.output_shard = torch.zeros_like(self.buffer[0]) @@ -74,12 +73,12 @@ class Bucket: tensor_size = tensor_list[0].numel() stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size) offset = self.offset - self.buffer[:, offset:offset + tensor_size].copy_(stacked_input) + self.buffer[:, offset : offset + tensor_size].copy_(stacked_input) self.offset += tensor_size # callback will be given the reduced result if callback_fn is not None: - result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0]) + result_view = self.output_shard[offset : offset + tensor_size].view_as(tensor_list[0]) self.callbacks.append(functools.partial(callback_fn, result_view)) @@ -142,8 +141,9 @@ class ReduceScatterBucketer: """ world_size = group.size() - assert (len(input_list) == world_size - ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})" + assert ( + len(input_list) == world_size + ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})" first_input = input_list[0] first_input_size = first_input.numel() @@ -183,7 +183,7 @@ class ReduceScatterBucketer: @functools.lru_cache() def _get_shard_size(self, element_size: int, num_shards: int) -> int: - if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. + if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. return 0 MB = 1024 * 1024 bucket_size = self.bucket_size_mb * MB / element_size diff --git a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..85f2ac2159f463b68db210ba12877e80bcda81ca --- /dev/null +++ b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py @@ -0,0 +1,587 @@ +# this code is inspired by the DeepSpeed library and implemented with our own design from scratch +import functools +import itertools +from collections import OrderedDict +from typing import Any, Iterator, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils.memory import colo_device_memory_capacity +from colossalai.legacy.zero.gemini.ophooks import register_ophooks_recursively +from colossalai.legacy.zero.gemini.paramhooks import BaseParamHookMgr +from colossalai.legacy.zero.gemini.stateful_tensor import TensorState +from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.legacy.zero.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory +from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_cpu +from colossalai.legacy.zero.shard_utils import BaseShardStrategy +from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer +from colossalai.logging import get_dist_logger +from colossalai.utils import disposable, get_current_device +from colossalai.zero.gemini.memory_tracer import MemStatsCollector + +from ._utils import ( + cast_float_arguments, + cast_tensor_to_bf16, + cast_tensor_to_fp16, + cast_tensor_to_fp32, + chunk_and_pad, + free_storage, + get_gradient_predivide_factor, +) +from .zero_hook import ZeroHook + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" + + +class ShardedModelV2(nn.Module): + """ + A wrapper for the PyTorch module shards the model parameters among multiple GPU memory. + Only `1/#nproc` of parameters, gradients are stored in local CUDA memory, so forward and backward + passes can be executed with limited CUDA memory budget. + + Note: + You must use ``ShardedModelV2`` with ``ShardedOptimizerV2``. + Note: + Make sure you don't use gradient accumulation and your optimizer can work with fp16 gradient and fp32 parameter, + if you enable ``reuse_fp16_shard``. + + Args: + module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`. + shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior. + process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None. + reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group. + Generally, it should be `None`, and it's the same as `process_group`. Defaults to None. + reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25. + fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False. + tensor_placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'. + If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used. + If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used. + If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. + Note that 'auto' policy can only work well when no other processes use CUDA during your training. + Defaults to 'cuda'. + gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0. + reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad. + Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation. + In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad). + We find that PyTorch's optimizers don't support mixed precision, + so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False. + bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False. + """ + + def __init__( + self, + module: nn.Module, + shard_strategy: BaseShardStrategy, + process_group: Optional[ProcessGroup] = None, + reduce_scatter_process_group: Optional[ProcessGroup] = None, + reduce_scatter_bucket_size_mb: int = 25, + fp32_reduce_scatter: bool = False, + tensor_placement_policy: str = "cuda", + gradient_predivide_factor: Optional[float] = 1.0, + reuse_fp16_shard: bool = False, + bf16: bool = False, + *args, + **kwargs, + ): + assert not isinstance(module, ShardedModelV2), "Nested ShardedModelV2 is not supported." + super().__init__() + self.logger = get_dist_logger() + self.bf16 = bf16 + + # We force users to use ZeroInitContext + for submodule in module.modules(): + sharded_cnt = 0 + unshard_cnt = 0 + for param in submodule.parameters(recurse=False): + assert hasattr(param, "colo_attr"), "You must use ZeroInitContext to init your module first." + if param.colo_attr.param_is_sharded: + sharded_cnt += 1 + else: + unshard_cnt += 1 + assert (not sharded_cnt) or (not unshard_cnt), "nn.Module can not both have shard param and unshard param" + submodule.param_is_sharded = sharded_cnt > 0 + + self.sharded_params = [] + self.unshard_params = [] + for param in module.parameters(): + if param.colo_attr.param_is_sharded: + self.sharded_params.append(param) + else: + self.unshard_params.append(param) + + self.module = module + self.process_group = process_group or gpc.get_group(ParallelMode.DATA) + self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group + self.world_size = dist.get_world_size(self.process_group) + self.rank = dist.get_rank(self.process_group) + self.shard_strategy = shard_strategy + + self._use_memory_tracer = tensor_placement_policy == "auto" + if self._use_memory_tracer: + self._memstats_collector = MemStatsCollector() + self._start_collect_memstats = disposable(self._memstats_collector.start_collection) + self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection) + else: + self._memstats_collector = None + self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create( + tensor_placement_policy + )(mem_stats_collector=self._memstats_collector) + + if "warmup_non_model_data_ratio" in kwargs: + if tensor_placement_policy != "auto": + self.logger.warning("setting warmup_non_model_data_ratio is useless if not use auto placement") + else: + ratio = kwargs["warmup_non_model_data_ratio"] + self._tensor_placement_policy._warmup_non_model_data_ratio = ratio + self.logger.info(f"setting warmup_non_model_data_ratio as {ratio} for auto placement") + + self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy) + param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, "colo_attr")] + self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list) + + # Register hooks + self._ophook_list = [ + ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group) + ] + register_ophooks_recursively(self.module, self._ophook_list) + self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters())) + self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook) + + self.fp32_reduce_scatter = fp32_reduce_scatter + self._cpu_offload: bool = tensor_placement_policy != "cuda" + for param in module.parameters(): + # Init `offload_grad` + param.colo_attr.offload_grad = self._cpu_offload + + # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem + # So we use 1.0 as the default gradient_predivide_factor + # However, if you set gradient_predivide_factor to None, we will set + # gradient_predivide_factor to a value >= 1.0 automatically + self.gradient_predivide_factor: float = ( + gradient_predivide_factor + if gradient_predivide_factor is not None + else get_gradient_predivide_factor(self.world_size) + ) + self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor + + self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() + self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb) + self._require_backward_grad_sync: bool = True + + self._cuda_margin_space = 0 + self.reuse_fp16_shard = reuse_fp16_shard + + # record whether gradients have inf or nan + self.overflow_counter = 0 + + def adjust_stateful_tensor_layout(self) -> None: + self._stateful_tensor_mgr.adjust_layout() + + @property + def use_memory_tracer(self): + return self._use_memory_tracer + + @property + def cuda_margin_space(self): + return self._cuda_margin_space + + @property + def cpu_offload(self): + return self._cpu_offload + + def dump_memory_stats(self, filename: Optional[str] = "dump_mem_stats.log") -> None: + """ + dummy memory tracer collected information to a file. + try: + # forward: model(inputs) + # backward: optimizer.backward() + except Exception as e: + model.dump_memory_stats() + exit(0) + """ + if self._use_memory_tracer: + self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0]) + if gpc.get_global_rank() == 0: + with open(filename, "w+") as f: + f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n") + f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n") + f.write("CUDA model data (GB)\n") + f.write("\n") + f.write("CUDA non model data (GB)\n") + f.write(str(self._memstats_collector._memstats.non_model_data_list("cuda"))) + f.write("CPU non model data (GB)\n") + f.write(str(self._memstats_collector._memstats.non_model_data_list("cpu"))) + f.write("\n") + + def _pre_forward_operations(self, *args): + # the operation will affect the memory tracer behavior in ZeroHook + if self._memstats_collector: + self._start_collect_memstats() + + for p in self.module.parameters(): + if hasattr(p, "colo_attr"): + p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) + + self._stateful_tensor_mgr.start_iter() + + def _post_forward_operations(self): + for p in self.module.parameters(): + if hasattr(p, "colo_attr"): + p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) + + def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: + self._pre_forward_operations(*args) + cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16 + args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs) + outputs = self.module(*args, **kwargs) + self._post_forward_operations() + return outputs + + def backward(self, loss): + loss.backward() + self._post_backward_operations() + for ophook in self._ophook_list: + ophook.post_iter() + + def backward_by_grad(self, tensor, grad): + torch.autograd.backward(tensors=tensor, grad_tensors=grad) + self._post_backward_operations() + for ophook in self._ophook_list: + ophook.post_iter() + + def _update_memstats(self): + if self._memstats_collector: + self._finish_collect_memstats() + # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used. + # the way to calculate margin space is based on the assumption that + # model data is fixed in cuda during training. + # cuda margin space can be used to store OS. + self._cuda_margin_space = ( + colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda + ) + + @torch.no_grad() + def _post_backward_operations(self) -> None: + """ + The method includes operations required to be processed after backward + 1. update memory tracer. + 2. flush the gradient in buckets. Reducing partial gradients in each process. + 3. shard tensors not dealed in the zero hook + 4. move sharded param grad payload to param.grad + """ + # 1. update memory tracer. + self._update_memstats() + + # 2. flush the gradient in buckets. Reducing partial gradients in each process. + if self._require_backward_grad_sync: + # Flush any unreduced buckets in the post_backward stream. + with torch.cuda.stream(self.comm_stream): + self.reducer.flush() + torch.cuda.current_stream().wait_stream(self.comm_stream) + self.reducer.free() + + # 3. shard tensors not dealed in the zero hook + tensor_list = [] + for p in self.sharded_params: + if not p.colo_attr.param_is_sharded: + tensor_list.append(p.colo_attr.sharded_data_tensor) + p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD) + p.colo_attr.set_data_none() + self.shard_strategy.shard(tensor_list, self.process_group) + + # 4. set all parameters' grad to None + for p in self.module.parameters(): + if not p.requires_grad: + continue + # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass. + # NOTE() (no-sync)/sync pass: (not conduct)/conduct gradient all reducing between process group. + # If _require_backward_grad_sync is True, + # p.grad remains the accumulated unsharded gradient from prior no-sync passes. + # We also allows to interleave no-sync pass with sync passes, if desired. + if not self._require_backward_grad_sync: + continue + + p.grad = None + + @torch.no_grad() + def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]: + """ + At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the + full gradient for the local batch. The reduce-scatter op will save + a single shard of the summed gradient across all + GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example:: + + before reduce_scatter: + param.grad (GPU #0): [1, 2, 3, 4] + param.grad (GPU #1): [5, 6, 7, 8] + + after reduce_scatter: + param.grad (GPU #0): [6, 8] # 1+5, 2+6 + param.grad (GPU #1): [10, 12] # 3+7, 4+8 + + The local GPU's ``optim.step`` is responsible for updating a single + shard of params, also corresponding to the current GPU's rank. This + alignment is created by `param.colo_attr.grad`, which ensures that + the local optimizer only sees the relevant parameter shard. + """ + if grad is None: + return + assert not grad.requires_grad, "ShardedModel only works with gradients that don't require gradients" + if not self._require_backward_grad_sync: + return + # used to cheat Pytorch, since we can't return None + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + # As torch didn't allow modifying grad in hook, we make a copy + grad = grad.clone() + if param.colo_attr.is_replicated: + self._reduce_scatter_handler(param, grad) + else: + self._save_grad(param, grad) + return empty_grad + + def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None: + self.comm_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.comm_stream): + if self.fp32_reduce_scatter: + grad.data = grad.data.to(param.dtype) + if self.gradient_predivide_factor > 1.0: + # Average grad by world_size for consistency with PyTorch DDP. + grad.data.div_(self.gradient_predivide_factor) + if self.world_size > 1: + grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size()) + self.reducer.reduce_scatter_async( + grad_chunks, + group=self.reduce_scatter_process_group, + callback_fn=functools.partial(self._reduce_scatter_callback, param), + ) + else: + self._reduce_scatter_callback(param, grad) + torch.cuda.current_stream().wait_stream(self.comm_stream) + + def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None: + assert isinstance( + reduced_grad, torch.Tensor + ), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}" + reduced_grad.data = reduced_grad.data.contiguous().view(-1) + if self.gradient_postdivide_factor > 1: + # Average grad by world_size for consistency with PyTorch DDP. + reduced_grad.data.div_(self.gradient_postdivide_factor) + self._save_grad(param, reduced_grad) + + # FIXME(ver217): refactor the below line when impl eviction policy + def _save_grad(self, param: Parameter, grad: torch.Tensor): + # record whether we have overflow + self.overflow_counter += torch.isinf(grad).any().item() + self.overflow_counter += torch.isnan(grad).any().item() + + # move gradient to cpu + if param.colo_attr.offload_grad: + colo_model_data_move_to_cpu(grad) + + if self.reuse_fp16_shard: + # make parameters point to gradient + + assert ( + param.colo_attr.saved_grad.is_null() + ), "Gradient accumulation is not supported when reuse_fp16_shard=True" + + param.colo_attr.grad_payload_reset(grad.data) + # release the memory of param + # we set a false None for parameter's payload + # so we can get parameter's device and dtype later in optimizer + param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype)) + + if param.colo_attr.is_replicated: + param.colo_attr.sharded_data_tensor.is_sharded = True + else: + fp32_grad = cast_tensor_to_fp32(grad) + + if param.colo_attr.saved_grad.is_null(): + param.colo_attr.grad_payload_reset(fp32_grad) + else: + param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload)) + + # keep saved_grad in HOLD state + param.colo_attr.saved_grad.trans_state(TensorState.HOLD) + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + return self.module.parameters(recurse=recurse) + + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + return self.module.named_parameters(prefix, recurse) + + def state_dict(self, destination=None, prefix="", keep_vars=False) -> "OrderedDict[str, torch.Tensor]": + return self._colo_state_dict( + destination, + prefix, + keep_vars, + shard_strategy=self.shard_strategy, + state_dict_func=nn.Module.state_dict, + module_to_load=self.module, + sharded_params=self.sharded_params, + process_group=self.process_group, + ) + + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True) -> None: + for name, p in self.named_parameters(): + if name in state_dict: + p.colo_attr.data_payload_reset( + state_dict[name].to(dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device) + ) + # Force re-shard + p.colo_attr.sharded_data_tensor.is_sharded = False + self.shard_strategy.shard([p.colo_attr.sharded_data_tensor]) + elif strict: + raise RuntimeError(f"Missing key in state_dict: {name}") + + def _colo_state_dict( + self, + destination=None, + prefix="", + keep_vars=False, + shard_strategy: Optional[BaseShardStrategy] = None, + state_dict_func=None, + module_to_load=None, + sharded_params=[], + process_group=None, + ) -> "OrderedDict[str, torch.Tensor]": + if len(sharded_params) == 0: + for param in self.parameters(): + if param.colo_attr.param_is_sharded: + sharded_params.append(param) + if shard_strategy is not None: + shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group) + for p in sharded_params: + p.data = p.colo_attr.data_payload + module_to_load = module_to_load or self + gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars) + gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()} + if shard_strategy is not None: + shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group) + for p in sharded_params: + p.colo_attr.set_data_none() + return gathered_state_dict + + def _colo_load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, shard_strategy=None + ): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + shard_strategy (Optional[BaseShardStrategy], optional): A shard strategy to manage shard behavior. Defaults to None. + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] + if hasattr(param, "colo_attr"): + param.colo_attr.data_payload_reset( + input_param.to( + dtype=param.colo_attr.data_payload.dtype, device=param.colo_attr.data_payload.device + ) + ) + if shard_strategy is not None: + # Force re-shard + param.colo_attr.sharded_data_tensor.is_sharded = False + shard_strategy.shard([param.colo_attr.sharded_data_tensor]) + else: + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(key, input_param.shape, param.shape) + ) + continue + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", nn.Module.set_extra_state) is not nn.Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix) :] + input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) + + def __getitem__(self, idx: int): + assert isinstance(self.module, nn.ModuleList) + return self.module[idx] + + def __len__(self): + assert isinstance(self.module, nn.ModuleList) + return len(self.module) + + def __iter__(self): + assert isinstance(self.module, nn.ModuleList) + return iter(self.module) diff --git a/colossalai/legacy/zero/sharded_model/utils.py b/colossalai/legacy/zero/sharded_model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cb085f19e6b250642aad7d806dac929181bd78e0 --- /dev/null +++ b/colossalai/legacy/zero/sharded_model/utils.py @@ -0,0 +1,20 @@ +import copy + +import torch + +from colossalai.legacy.zero.sharded_model import ShardedModelV2 + + +def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module): + """ + copy param of the ShardedModelV2 to other_model. + Note the other_model has to be the same as self. + """ + for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()): + assert hasattr(zero_param, "colo_attr") + shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded + if shard_flag: + sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor]) + param.data = copy.deepcopy(zero_param.colo_attr.data_payload) + if shard_flag: + sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor]) diff --git a/colossalai/zero/legacy/sharded_model/zero_hook.py b/colossalai/legacy/zero/sharded_model/zero_hook.py similarity index 82% rename from colossalai/zero/legacy/sharded_model/zero_hook.py rename to colossalai/legacy/zero/sharded_model/zero_hook.py index 50f4bdfc775d77a938d51442e2966b720a519b94..892e9f31ded4472045c649907dc6f27aae6a7e22 100644 --- a/colossalai/zero/legacy/sharded_model/zero_hook.py +++ b/colossalai/legacy/zero/sharded_model/zero_hook.py @@ -3,14 +3,14 @@ from typing import Optional import torch import torch.distributed as dist +from colossalai.legacy.registry import OPHOOKS +from colossalai.legacy.zero.gemini.ophooks import BaseOpHook +from colossalai.legacy.zero.gemini.stateful_tensor import TensorState +from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.logging import get_dist_logger -from colossalai.registry import OPHOOKS from colossalai.utils import get_current_device from colossalai.zero.gemini.memory_tracer import MemStatsCollector -from colossalai.zero.legacy.gemini.ophooks import BaseOpHook -from colossalai.zero.legacy.gemini.stateful_tensor import TensorState -from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr -from colossalai.zero.legacy.shard_utils import BaseShardStrategy @OPHOOKS.register_module @@ -20,11 +20,13 @@ class ZeroHook(BaseOpHook): Warning: this class has been deprecated after version 0.1.12 """ - def __init__(self, - shard_strategy: BaseShardStrategy, - memstarts_collector: Optional[MemStatsCollector] = None, - stateful_tensor_mgr: Optional[StatefulTensorMgr] = None, - process_group: Optional[dist.ProcessGroup] = None): + def __init__( + self, + shard_strategy: BaseShardStrategy, + memstarts_collector: Optional[MemStatsCollector] = None, + stateful_tensor_mgr: Optional[StatefulTensorMgr] = None, + process_group: Optional[dist.ProcessGroup] = None, + ): super().__init__() self.logger = get_dist_logger("ZeROHook") self.shard_strategy = shard_strategy @@ -41,7 +43,7 @@ class ZeroHook(BaseOpHook): if module.param_is_sharded: tensor_list = [] for param in module.parameters(recurse=False): - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") tensor_list.append(param.colo_attr.sharded_data_tensor) self.shard_strategy.gather(tensor_list, self.process_group) @@ -50,7 +52,7 @@ class ZeroHook(BaseOpHook): if module.param_is_sharded: tensor_list = [] for param in module.parameters(recurse=False): - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") tensor_list.append(param.colo_attr.sharded_data_tensor) self.shard_strategy.shard(tensor_list, self.process_group) @@ -74,10 +76,9 @@ class ZeroHook(BaseOpHook): self.gather_parameters(module) for param in module.parameters(recurse=False): param.data = param.colo_attr.data_payload - assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA" + assert param.data.device.type == "cuda", f"PRE FWD param.data must be on CUDA" def post_fwd_exec(self, module: torch.nn.Module, *args): - # change tensor state to HOLD_AFTER_FWD for param in module.parameters(recurse=False): param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD) @@ -93,10 +94,9 @@ class ZeroHook(BaseOpHook): self.gather_parameters(module) for param in module.parameters(recurse=False): param.data = param.colo_attr.data_payload - assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA" + assert param.data.device.type == "cuda", f"PRE BWD param.data must be on CUDA" def post_bwd_exec(self, module: torch.nn.Module, input): - # change tensor state to HOLD_AFTER_BWD for param in module.parameters(recurse=False): param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD) @@ -114,5 +114,6 @@ class ZeroHook(BaseOpHook): if self._stateful_tensor_mgr: self.logger.debug( f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}", - ranks=[0]) + ranks=[0], + ) self._stateful_tensor_mgr.finish_iter() diff --git a/colossalai/legacy/zero/sharded_optim/__init__.py b/colossalai/legacy/zero/sharded_optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..700fb0eb91d31cc867bb0e34c10f3fe59d71f919 --- /dev/null +++ b/colossalai/legacy/zero/sharded_optim/__init__.py @@ -0,0 +1,3 @@ +from .sharded_optim_v2 import ShardedOptimizerV2 + +__all__ = ["ShardedOptimizerV2"] diff --git a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py b/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py similarity index 75% rename from colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py rename to colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py index be60209af434ea80da71f76299217f0c9fd4340e..e73679163fab468b5c462fee67418a79a9bc4735 100644 --- a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py +++ b/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py @@ -1,6 +1,5 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch from enum import Enum -from os import stat from typing import Dict, Optional, Tuple import torch @@ -12,15 +11,15 @@ from torch.nn.parameter import Parameter from torch.optim import Optimizer from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.interface import OptimizerWrapper +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.legacy.zero.gemini.tensor_placement_policy import AutoTensorPlacementPolicy +from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.legacy.zero.sharded_model import ShardedModelV2 +from colossalai.legacy.zero.sharded_model._utils import cast_tensor_to_fp32 from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState -from colossalai.zero.legacy.gemini.tensor_placement_policy import AutoTensorPlacementPolicy -from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp32 class OptimState(Enum): @@ -28,7 +27,7 @@ class OptimState(Enum): UNSCALED = 2 -class ShardedOptimizerV2(ColossalaiOptimizer): +class ShardedOptimizerV2(OptimizerWrapper): """A wrapper for optimizer. ``ShardedOptimizerV2`` and ``ShardedModelV2`` implement Zero Redundancy Optimizer (ZeRO). By default the ZeRO optimizer stage 3 offload Optimizer States on CPU. @@ -74,60 +73,74 @@ class ShardedOptimizerV2(ColossalaiOptimizer): https://arxiv.org/abs/2108.05818 """ - def __init__(self, - sharded_model: ShardedModelV2, - optimizer: Optimizer, - gpu_margin_mem_ratio: float = 0.0, - initial_scale: float = 2**32, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - dp_process_group: Optional[ProcessGroup] = None, - mp_process_group: Optional[ProcessGroup] = None, - verbose: bool = False) -> None: - assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel' - assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.' + def __init__( + self, + sharded_model: ShardedModelV2, + optimizer: Optimizer, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + dp_process_group: Optional[ProcessGroup] = None, + mp_process_group: Optional[ProcessGroup] = None, + verbose: bool = False, + ) -> None: + assert isinstance(sharded_model, ShardedModelV2), "model must be wrapped with ShardedModel" + assert not isinstance(optimizer, ShardedOptimizerV2), "Nested ShardedOptimizerV2 is not supported." super().__init__(optimizer) self.shard_strategy = sharded_model.shard_strategy self.model: ShardedModelV2 = sharded_model + self.bf16 = sharded_model.bf16 self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) - assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' + assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f"gpu_margin_mem_ratio must >=0.0 and <=1.0" # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors, # and it must set `num_fp32_shards_per_param` correctly - self._should_move_fp32_shards_h2d: bool = sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr( - optimizer, 'num_fp32_shards_per_param', 0) >= 2 - self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu') + self._should_move_fp32_shards_h2d: bool = ( + sharded_model.cpu_offload + and self.gpu_margin_mem_ratio > 0.0 + and getattr(optimizer, "num_fp32_shards_per_param", 0) >= 2 + ) + self.device = sharded_model._tensor_placement_policy.device or torch.device("cpu") self.optim_state: OptimState = OptimState.UNSCALED self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA) self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL) # Grad scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.grad_scaler = DynamicGradScaler( + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device()) self._logger = get_dist_logger("ShardedOptimizerV2") self._verbose = verbose + self._grad_prepared: bool = ( + False # this should be set to true when _prepare_grads() and reset to false when backward + ) # Store fp32 param shards self._register_master_weight() - if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy, - AutoTensorPlacementPolicy): - self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"', - ranks=[0]) + if self.gpu_margin_mem_ratio != 0.0 and not isinstance( + sharded_model._tensor_placement_policy, AutoTensorPlacementPolicy + ): + self._logger.warning( + f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"', ranks=[0] + ) if self._verbose: self._logger.debug( - f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0]) + f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0] + ) self._use_memory_tracer = self.model.use_memory_tracer @@ -136,7 +149,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): return self.grad_scaler.scale.item() def get_memory_usage(self) -> Tuple[int, int]: - """ Get the memory usage of the optimizer. Including master_params (param fp32), + """Get the memory usage of the optimizer. Including master_params (param fp32), momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``) Returns: @@ -155,7 +168,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): for _, p_fp32 in self.master_params.items(): update_mem_use(p_fp32) for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: state = self.optim.state[p] for k, v in state.items(): update_mem_use(v) @@ -166,8 +179,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer): self._zero_grad() def backward(self, loss: Tensor) -> None: - loss = self.loss_scale * loss - self.optim_state = OptimState.SCALED + if not self.bf16: + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self._grad_prepared = False self.model.backward(loss) def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: @@ -175,30 +190,32 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - self.optim_state = OptimState.SCALED + if not self.bf16: + self.optim_state = OptimState.SCALED + self._grad_prepared = False self.model.backward_by_grad(tensor, grad) def clip_grad_norm(self, model: nn.Module, max_norm: float): - if self.optim_state == OptimState.SCALED: - self._prepare_grads() + self._prepare_grads() + if not self.bf16 and self.optim_state == OptimState.SCALED: self._unscale_grads() return super().clip_grad_norm(model, max_norm) def step(self, *args, **kwargs): - + self._prepare_grads() # unscale grads if scaled - if self.optim_state == OptimState.SCALED: - self._prepare_grads() + if not self.bf16 and self.optim_state == OptimState.SCALED: self._unscale_grads() self._maybe_move_fp32_shards() - found_inf = self._check_overflow() - self.grad_scaler.update(found_inf) + if not self.bf16: + found_inf = self._check_overflow() + self.grad_scaler.update(found_inf) - if found_inf: - self._logger.warning('found inf during ShardedOptimV2 step') - self._zero_grad(recover_data=True) - return + if found_inf: + self._logger.warning("found inf during ShardedOptimV2 step") + self._zero_grad(recover_data=True) + return self._point_param_fp16_to_master_param() @@ -206,14 +223,16 @@ class ShardedOptimizerV2(ColossalaiOptimizer): gpu_mem, cpu_mem = self.get_memory_usage() self._logger.debug( f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!", - ranks=[0]) + ranks=[0], + ) ret = self.optim.step(*args, **kwargs) if self._verbose: gpu_mem, cpu_mem = self.get_memory_usage() self._logger.debug( f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!", - ranks=[0]) + ranks=[0], + ) self._copy_master_model_to_model_fp16() return ret @@ -233,7 +252,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): def _unscale_grads(self): assert self.optim_state == OptimState.SCALED for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is not None: p.grad.data.div_(self.loss_scale) self.optim_state = OptimState.UNSCALED @@ -253,16 +272,16 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # Which leads to wrong accumulation self.optim.zero_grad(set_to_none=True) for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: # p.colo_attr.sharded_data_tensor stores grad now # we have to recover fp16 param - reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0) + reuse_fp16_shard = p.colo_attr.sharded_data_tensor.payload_size == 0 if recover_data and reuse_fp16_shard: self._copy_master_param_to_param_fp16(p) else: # release saved gradient p.colo_attr.saved_grad.set_null() - self.model.overflow_counter = 0 # set overflow counter to zero + self.model.overflow_counter = 0 # set overflow counter to zero def sync_grad(self): pass @@ -270,8 +289,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer): def _register_master_weight(self): self.master_params: Dict[Parameter, StatefulTensor] = {} for group in self.optim.param_groups: - for p in group['params']: - assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam' + for p in group["params"]: + assert hasattr(p, "colo_attr"), "The parameter must be wrapped with ShardedParam" shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated if shard_flag: # we always shard replicated parameters @@ -289,7 +308,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param fp32_shards_used_cuda_margin_mem = 0 for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: if p.colo_attr.saved_grad.is_null(): continue shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size() @@ -304,8 +323,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer): state[k] = v.cuda() def _prepare_grads(self): + if self._grad_prepared: + return for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: if p.colo_attr.saved_grad.is_null(): continue p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE) @@ -320,12 +341,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer): p.grad = p.colo_attr.grad_payload # Set p.data to empty tensor, in case of memory leaking p.colo_attr.set_data_none() + self._grad_prepared = True def _point_param_fp16_to_master_param(self): # assign master param pointers to p.data. # We will not trigger data copy here. for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: self.master_params[p].trans_state(TensorState.COMPUTE) p.data = self.master_params[p].payload # Now p.data is sharded @@ -336,7 +358,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # TODO() improve efficiency by gathering tensors into a chunk and transferring # a chunk. for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: self._copy_master_param_to_param_fp16(p) def _copy_master_param_to_param_fp16(self, p): @@ -354,15 +376,17 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # in order to use copy, otherwise, the sizes of tensor is not compatible if p.colo_attr.data_payload.numel() != p.data.numel(): p.colo_attr.data_payload_reset( - torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)) + torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device) + ) # TODO() optimize this line CPU (fp32) -> GPU (fp16) - p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach()) + half_dtype = torch.bfloat16 if self.bf16 else torch.float16 + p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach()) p.colo_attr.set_data_none() if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated: # We gather full fp16 param here - p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True + p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group) self.master_params[p].trans_state(TensorState.HOLD) @@ -370,18 +394,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer): def state_dict(self): optim_state_dict = super().state_dict() scaler_state_dict = self.grad_scaler.state_dict() - optim_state_dict['scaler'] = scaler_state_dict + optim_state_dict["scaler"] = scaler_state_dict return optim_state_dict def load_state_dict(self, *args, **kwargs): - if 'scaler' not in args[0]: - self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0]) + if "scaler" not in args[0]: + self._logger.warning("Missing scaler when loading optimizer state dict", ranks=[0]) else: - scaler_state_dict = args[0].pop('scaler') + scaler_state_dict = args[0].pop("scaler") self.grad_scaler.load_state_dict(scaler_state_dict) super().load_state_dict(*args, **kwargs) for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: state = self.optim.state[p] for k, v in state.items(): if isinstance(v, Tensor): diff --git a/colossalai/legacy/zero/sharded_param/__init__.py b/colossalai/legacy/zero/sharded_param/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7afb95391a4d7a4daefe2ad0a206583795ed04f --- /dev/null +++ b/colossalai/legacy/zero/sharded_param/__init__.py @@ -0,0 +1,4 @@ +from .sharded_param import ShardedParamV2 +from .sharded_tensor import ShardedTensor + +__all__ = ["ShardedTensor", "ShardedParamV2"] diff --git a/colossalai/zero/legacy/sharded_param/sharded_param.py b/colossalai/legacy/zero/sharded_param/sharded_param.py similarity index 94% rename from colossalai/zero/legacy/sharded_param/sharded_param.py rename to colossalai/legacy/zero/sharded_param/sharded_param.py index 4bcc4b62104ab23b0f86683bea1bf34294c07a70..22b09d5ff4bbac5ea22a9e77de04ef7607e4a66d 100644 --- a/colossalai/zero/legacy/sharded_param/sharded_param.py +++ b/colossalai/legacy/zero/sharded_param/sharded_param.py @@ -2,8 +2,8 @@ from typing import List, Optional, Tuple import torch -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState -from colossalai.zero.legacy.gemini.tensor_utils import colo_tensor_mem_usage +from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.legacy.zero.gemini.tensor_utils import colo_tensor_mem_usage from .sharded_tensor import ShardedTensor @@ -19,7 +19,6 @@ def get_empty_tensor(device: torch.device, dtype: torch.dtype): class ShardedParamV2(object): - def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None: self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data) self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE) @@ -36,8 +35,7 @@ class ShardedParamV2(object): self.set_data_none() def get_payload_tensors(self) -> List[StatefulTensor]: - """returns stateful tensors kept by this class. - """ + """returns stateful tensors kept by this class.""" return [self._sharded_data_tensor] def set_data_none(self): diff --git a/colossalai/zero/legacy/sharded_param/sharded_tensor.py b/colossalai/legacy/zero/sharded_param/sharded_tensor.py similarity index 94% rename from colossalai/zero/legacy/sharded_param/sharded_tensor.py rename to colossalai/legacy/zero/sharded_param/sharded_tensor.py index af60312600f22554450218ce82d4890e93f926fe..262682d44645692f7757c4ea307651e4e065cac8 100644 --- a/colossalai/zero/legacy/sharded_param/sharded_tensor.py +++ b/colossalai/legacy/zero/sharded_param/sharded_tensor.py @@ -1,10 +1,9 @@ import torch -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState class ShardedTensor(StatefulTensor): - def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None: r""" A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance. diff --git a/colossalai/logging/__init__.py b/colossalai/logging/__init__.py index 97fe4f89ded370c6a71ce277bb33e252912b6f17..521eafa74c30785d836a49d12e6b30bbdb994793 100644 --- a/colossalai/logging/__init__.py +++ b/colossalai/logging/__init__.py @@ -3,23 +3,23 @@ from typing import List, Optional from .logger import DistributedLogger -__all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers'] +__all__ = ["get_dist_logger", "DistributedLogger", "disable_existing_loggers"] -def get_dist_logger(name: str = 'colossalai') -> DistributedLogger: +def get_dist_logger(name: str = "colossalai") -> DistributedLogger: """Get logger instance based on name. The DistributedLogger will create singleton instances, which means that only one logger instance is created per name. Args: name (str): name of the logger, name must be unique - + Returns: :class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance. """ return DistributedLogger.get_instance(name=name) -def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']) -> None: +def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ["colossalai"]) -> None: """Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai". Args: diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py index af7b7de54a8d481312c4934ac1a24a2f1bcaf7e0..eb5f28e2a3cf809e77d3111d914868e6ae2325a1 100644 --- a/colossalai/logging/logger.py +++ b/colossalai/logging/logger.py @@ -6,8 +6,7 @@ import logging from pathlib import Path from typing import List, Union -import colossalai -from colossalai.context.parallel_mode import ParallelMode +import torch.distributed as dist class DistributedLogger: @@ -43,12 +42,14 @@ class DistributedLogger: def __init__(self, name): if name in DistributedLogger.__instances: raise Exception( - 'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger') + "Logger with the same name has been created, you should use colossalai.logging.get_dist_logger" + ) else: handler = None - formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') + formatter = logging.Formatter("colossalai - %(name)s - %(levelname)s: %(message)s") try: from rich.logging import RichHandler + handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True) handler.setFormatter(formatter) except ImportError: @@ -63,6 +64,7 @@ class DistributedLogger: self._logger.propagate = False DistributedLogger.__instances[name] = self + self.rank = dist.get_rank() if dist.is_initialized() else 0 @staticmethod def __get_call_info(): @@ -79,7 +81,7 @@ class DistributedLogger: @staticmethod def _check_valid_logging_level(level: str): - assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level' + assert level in ["INFO", "DEBUG", "WARNING", "ERROR"], "found invalid logging level" def set_level(self, level: str) -> None: """Set the logging level @@ -90,7 +92,7 @@ class DistributedLogger: self._check_valid_logging_level(level) self._logger.setLevel(getattr(logging, level)) - def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None: + def log_to_file(self, path: Union[str, Path], mode: str = "a", level: str = "INFO", suffix: str = None) -> None: """Save the logs to file Args: @@ -99,8 +101,7 @@ class DistributedLogger: level (str): Can only be INFO, DEBUG, WARNING and ERROR. suffix (str): The suffix string of log's name. """ - assert isinstance(path, (str, Path)), \ - f'expected argument path to be type str or Path, but got {type(path)}' + assert isinstance(path, (str, Path)), f"expected argument path to be type str or Path, but got {type(path)}" self._check_valid_logging_level(level) if isinstance(path, str): @@ -109,85 +110,66 @@ class DistributedLogger: # create log directory path.mkdir(parents=True, exist_ok=True) - # set the default file name if path is a directory - if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL): - rank = 0 - else: - rank = colossalai.core.global_context.get_global_rank() - if suffix is not None: - log_file_name = f'rank_{rank}_{suffix}.log' + log_file_name = f"rank_{self.rank}_{suffix}.log" else: - log_file_name = f'rank_{rank}.log' + log_file_name = f"rank_{self.rank}.log" path = path.joinpath(log_file_name) # add file handler file_handler = logging.FileHandler(path, mode) file_handler.setLevel(getattr(logging, level)) - formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') + formatter = logging.Formatter("colossalai - %(name)s - %(levelname)s: %(message)s") file_handler.setFormatter(formatter) self._logger.addHandler(file_handler) - def _log(self, - level, - message: str, - parallel_mode: ParallelMode = ParallelMode.GLOBAL, - ranks: List[int] = None) -> None: + def _log(self, level, message: str, ranks: List[int] = None) -> None: if ranks is None: getattr(self._logger, level)(message) else: - local_rank = colossalai.core.global_context.get_local_rank(parallel_mode) - if local_rank in ranks: + if self.rank in ranks: getattr(self._logger, level)(message) - def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + def info(self, message: str, ranks: List[int] = None) -> None: """Log an info message. Args: message (str): The message to be logged. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): - The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('info', message_prefix, parallel_mode, ranks) - self._log('info', message, parallel_mode, ranks) + self._log("info", message_prefix, ranks) + self._log("info", message, ranks) - def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + def warning(self, message: str, ranks: List[int] = None) -> None: """Log a warning message. Args: message (str): The message to be logged. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): - The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('warning', message_prefix, parallel_mode, ranks) - self._log('warning', message, parallel_mode, ranks) + self._log("warning", message_prefix, ranks) + self._log("warning", message, ranks) - def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + def debug(self, message: str, ranks: List[int] = None) -> None: """Log a debug message. Args: message (str): The message to be logged. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): - The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('debug', message_prefix, parallel_mode, ranks) - self._log('debug', message, parallel_mode, ranks) + self._log("debug", message_prefix, ranks) + self._log("debug", message, ranks) - def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None: + def error(self, message: str, ranks: List[int] = None) -> None: """Log an error message. Args: message (str): The message to be logged. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): - The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('error', message_prefix, parallel_mode, ranks) - self._log('error', message, parallel_mode, ranks) + self._log("error", message_prefix, ranks) + self._log("error", message, ranks) diff --git a/colossalai/nn/__init__.py b/colossalai/nn/__init__.py index 910ad203180c8dd533ebc7732d26d94a20b72929..c6c4d30425562d2b012db8c6afc627fceae115a6 100644 --- a/colossalai/nn/__init__.py +++ b/colossalai/nn/__init__.py @@ -1,6 +1,5 @@ -from ._ops import * +from .init import * from .layer import * from .loss import * from .lr_scheduler import * -from .metric import * from .optimizer import * diff --git a/colossalai/nn/_ops/__init__.py b/colossalai/nn/_ops/__init__.py deleted file mode 100644 index 4991ad9a2217f904287d438ca37c8c4717a40a67..0000000000000000000000000000000000000000 --- a/colossalai/nn/_ops/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .addmm import colo_addmm -from .batch_norm import colo_batch_norm -from .element_wise import * -from .embedding import colo_embedding -from .embedding_bag import colo_embedding_bag -from .layernorm import colo_layernorm -from .linear import colo_linear -from .loss import colo_cross_entropy -from .view import colo_view diff --git a/colossalai/nn/_ops/_utils.py b/colossalai/nn/_ops/_utils.py deleted file mode 100644 index 24877bbb552f9dcd0e5b00baf783eef7f2ccecbb..0000000000000000000000000000000000000000 --- a/colossalai/nn/_ops/_utils.py +++ /dev/null @@ -1,283 +0,0 @@ -from typing import List, Optional, Union - -import torch -import torch.distributed as dist - -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.nn.layer.utils import divide -from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup - -GeneralTensor = Union[ColoTensor, torch.Tensor] -Number = Union[int, float] - - -def convert_to_colo_tensor(tensor: Optional[GeneralTensor], pg: ProcessGroup) -> Optional[ColoTensor]: - if tensor is not None and not isinstance(tensor, ColoTensor): - tensor = ColoTensor.from_torch_tensor(tensor, ColoTensorSpec(pg)) - return tensor - - -def set_parallel_input(input_parallel: bool): - env.parallel_input_1d = input_parallel - - -def get_parallel_input(): - return env.parallel_input_1d - - -def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): - index_f = rank * per_partition_vocab_size - index_l = index_f + per_partition_vocab_size - return index_f, index_l - - -def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): - per_partition_vocab_size = divide(global_vocab_size, world_size) - return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) - - -def _reduce(input_, pg: ProcessGroup): - # skip if only one rank involved - if pg.tp_world_size() == 1: - return input_ - assert input_.device.type == 'cuda' - group = pg.tp_process_group() - dist.all_reduce(input_, group=group) - - return input_ - - -def _split(input_, pg: ProcessGroup, dim=-1): - # skip if only one rank involved - world_size = pg.tp_world_size() - if world_size == 1: - return input_ - - # Split along last dimension. - dim_size = input_.size(dim) - assert dim_size % world_size == 0, \ - f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ - f'cannot split tensor evenly' - - tensor_list = torch.split(input_, dim_size // world_size, dim=dim) - rank = pg.tp_local_rank() - output = tensor_list[rank].contiguous() - - return output - - -def _gather(input_, pg: ProcessGroup, dim=-1): - # skip if only one rank involved - world_size = pg.tp_world_size() - if world_size == 1: - return input_ - - # all gather - rank = pg.tp_local_rank() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - assert input_.device.type == 'cuda' - group = pg.tp_process_group() - torch.distributed.all_gather(tensor_list, input_, group=group) - - # concat - output = torch.cat(tensor_list, dim=dim).contiguous() - - return output - - -class _ReduceGrad(torch.autograd.Function): - """ - Pass the input to the model parallel region. - - Args: - input_: input matrix. - process_group: parallel mode. - """ - - @staticmethod - def symbolic(graph, input_): - return input_ - - @staticmethod - def forward(ctx, input_, process_group): - ctx.mode = process_group - return input_ - - @staticmethod - def backward(ctx, grad_output): - return _reduce(grad_output, ctx.mode), None - - -class _ReduceInput(torch.autograd.Function): - """ - All-reduce the input from the model parallel region. - - Args: - input_: input matrix. - process_group: parallel mode. - """ - - @staticmethod - def symbolic(graph, input_): - return _reduce(input_) - - @staticmethod - def forward(ctx, input_, process_group): - return _reduce(input_, process_group) - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - - -class _SplitForwardGatherBackward(torch.autograd.Function): - """ - Split the input and keep only the corresponding chuck to the rank. - - Args: - input_: input matrix. - process_group: parallel mode. - dim: dimension - """ - - @staticmethod - def symbolic(graph, input_): - return _split(input_) - - @staticmethod - def forward(ctx, input_, process_group, dim): - ctx.mode = process_group - ctx.dim = dim - return _split(input_, process_group, dim) - - @staticmethod - def backward(ctx, grad_output): - return _gather(grad_output, ctx.mode, ctx.dim), None, None - - -class _GatherForwardSplitBackward(torch.autograd.Function): - """Gather the input from model parallel region and concatenate. - - Args: - input_: input matrix. - process_group: parallel mode. - dim: dimension - """ - - @staticmethod - def symbolic(graph, input_): - return _gather(input_) - - @staticmethod - def forward(ctx, input_, process_group, dim): - ctx.mode = process_group - ctx.dim = dim - return _gather(input_, process_group, dim) - - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.mode, ctx.dim), None, None - - -def reduce_grad(input_, process_group): - return _ReduceGrad.apply(input_, process_group) - - -def reduce_input(input_, process_group): - return _ReduceInput.apply(input_, process_group) - - -def split_forward_gather_backward(input_, process_group, dim): - return _SplitForwardGatherBackward.apply(input_, process_group, dim) - - -def gather_forward_split_backward(input_, process_group, dim): - return _GatherForwardSplitBackward.apply(input_, process_group, dim) - - -def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor: - world_size = pg.tp_world_size() - if world_size == 1: - return x - - # TODO: enabling mpi backend to support CPU all_to_all - assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend" - - shapes = list(x.size()) - shapes[scatter_dim] = shapes[scatter_dim] // world_size - - scatter_list = [each.contiguous() for each in torch.tensor_split(x, world_size, scatter_dim)] - gather_list = [torch.empty(*shapes, dtype=x.dtype, device=x.device) for _ in range(world_size)] - torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) - - return torch.cat(gather_list, dim=gather_dim).contiguous() - - -class _DualAllToAll(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, pg, scatter_dim, gather_dim): - ctx.scatter_dim = scatter_dim - ctx.gather_dim = gather_dim - ctx.pg = pg - return _all_to_all(x, pg, scatter_dim, gather_dim) - - @staticmethod - def backward(ctx, grad): - return _all_to_all(grad, ctx.pg, ctx.gather_dim, ctx.scatter_dim), None, None, None - - -def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int): - return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim) - - -### table wise embedding shard - - -def _all_to_all_for_tablewise(x: torch.Tensor, - pg: ProcessGroup, - scatter_strides: List[int], - gather_strides: List[int], - forward=True) -> torch.Tensor: - world_size = pg.tp_world_size() - rank = pg.tp_local_rank() - if world_size == 1: - return x - assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend" - if forward: - scatter_list = list(x.split(scatter_strides, 0)) - gather_list = [ - torch.empty(scatter_strides[rank], gather_strides[i], dtype=x.dtype, device=x.device) - for i in range(world_size) - ] - torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) - return torch.cat(gather_list, 1).contiguous() - else: - # split on dim 1, lose contiguity - scatter_list = [each.contiguous() for each in x.split(scatter_strides, 1)] - gather_list = [ - torch.empty(gather_strides[i], scatter_strides[rank], dtype=x.dtype, device=x.device) - for i in range(world_size) - ] - torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) - return torch.cat(gather_list, 0).contiguous() - - -class _DualAllToAllForTablewise(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, pg, scatter_strides, gather_strides): - ctx.pg = pg - ctx.scatter_strides = scatter_strides - ctx.gather_strides = gather_strides - return _all_to_all_for_tablewise(x, pg, scatter_strides, gather_strides, forward=True) - - @staticmethod - def backward(ctx, grad): - return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, - forward=False), None, None, None - - -def dual_all_to_all_tablewise(x, pg, scatter_strides, gather_strides): - return _DualAllToAllForTablewise.apply(x, pg, scatter_strides, gather_strides) diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py deleted file mode 100644 index 660b48a71d57ee4cc9b1e0cc904f9b7f8d69e244..0000000000000000000000000000000000000000 --- a/colossalai/nn/_ops/addmm.py +++ /dev/null @@ -1,90 +0,0 @@ -import torch - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input - - -def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, - alpha: Number) -> ColoTensor: - # mat1:S[1] x mat2:S[0] = Output:P - # beta * input + alpha * All-Reduce(Output) = res - - mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), mat2.get_process_group()) - - # Output:P - partial_output = torch.mm(mat1, mat2) - # Reduce(Output) - output = reduce_input(partial_output, mat2.get_process_group()) - # input - assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op' - output = beta * input_tensor + alpha * output - output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(input_tensor.get_process_group())) - return output - - -def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, - alpha: Number) -> ColoTensor: - # mat1:B x mat2:S[1] + input:S[1] = Output:S[1] - compute_spec = mat2.compute_spec - mat1 = mat1.redistribute(ReplicaSpec()) - mat1 = reduce_grad(mat1, mat1.get_process_group()) - - output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha) - output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]), - ComputeSpec(ComputePattern.TP1D)) - output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) - - if compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, - alpha: Number) -> ColoTensor: - assert mode in ('row', 'col') - funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol} - return funcs[mode](input_tensor, mat1, mat2, beta, alpha) - - -@colo_op_impl(torch.addmm) -def colo_addmm(input_tensor: GeneralTensor, - mat1: ColoTensor, - mat2: ColoTensor, - beta: Number = 1, - alpha: Number = 1, - **kargs) -> ColoTensor: - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. - This method computes a linear. - """ - # At least one of the tensor should be ColoTensor - assert isinstance(mat2, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, mat2.get_process_group()) - mat1 = convert_to_colo_tensor(mat1, mat2.get_process_group()) - - # Add communication logic before and after linear call. - ret_tensor = None - if not mat2.has_compute_spec(): # No Model Parallel Applied - assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op' - assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op' - ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor, - mat1, - mat2, - beta=beta, - alpha=alpha, - **kargs), - spec=ColoTensorSpec(mat2.get_process_group())) - elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if mat2.is_shard_1drow() and input_tensor.is_replicate(): - mode = 'row' - elif mat2.is_shard_1dcol() and (input_tensor.is_shard_1dcol() or input_tensor.is_shard_1drow()): - mode = 'col' - else: - raise NotImplementedError - ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha) - else: - raise NotImplementedError - - return ret_tensor diff --git a/colossalai/nn/_ops/batch_norm.py b/colossalai/nn/_ops/batch_norm.py deleted file mode 100644 index 54ecc88f420a8d8a2c81aca6ac765e57f5a56cac..0000000000000000000000000000000000000000 --- a/colossalai/nn/_ops/batch_norm.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional - -import torch.nn.functional as F - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -@colo_op_impl(F.batch_norm) -def colo_batch_norm( - input: GeneralTensor, - running_mean: Optional[GeneralTensor], - running_var: Optional[GeneralTensor], - weight: Optional[GeneralTensor] = None, - bias: Optional[GeneralTensor] = None, - training: bool = False, - momentum: float = 0.1, - eps: float = 1e-5, -): - assert isinstance(weight, ColoTensor) - running_mean = running_mean.detach() - running_var = running_var.detach() - - input = convert_to_colo_tensor(input, weight.get_process_group()) - bias = convert_to_colo_tensor(bias, weight.get_process_group()) - input = input.redistribute(ReplicaSpec()) - bias = bias.redistribute(ReplicaSpec()) - - output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps) - output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group())) - return output diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/nn/_ops/element_wise.py deleted file mode 100644 index 2de51e24a6dd1bd45271a0b8d51372ee5209415d..0000000000000000000000000000000000000000 --- a/colossalai/nn/_ops/element_wise.py +++ /dev/null @@ -1,250 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import Tensor - -from colossalai.tensor import ColoTensor, ColoTensorSpec -from colossalai.tensor.op_wrapper import colo_op_impl - -from ._utils import GeneralTensor, convert_to_colo_tensor - - -def register_elementwise_op(op): - - @colo_op_impl(op) - def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs): - """ - Handles ``__torch_function__`` dispatch for the elementwise op such - as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``. - This method computes on either a normal tensor or a sharded tensor. - """ - if 'inplace' in kwargs: - # TODO(jiaruifang) inplace will cause bugs - input_tensor = input_tensor.clone() - return op(input_tensor, *args, **kwargs) - else: - output = op(input_tensor, *args, **kwargs) - # return output - if isinstance(input_tensor, ColoTensor): - if isinstance(output, str): - return output - if not isinstance(output, torch.Tensor): - raise NotImplementedError - return ColoTensor.from_torch_tensor(output, - spec=ColoTensorSpec(input_tensor.get_process_group(), - dist_attr=input_tensor.dist_spec)) - - -# @colo_op_impl(torch.relu_) -# def elementwise_op(input_tensor): -# torch.relu_(input_tensor.data) -# return input_tensor - -# @colo_op_impl(Tensor.add_) -# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs): -# input_tensor = input_tensor.data.add_(*args, **kwargs) -# return input_tensor - -# Tensor op -register_elementwise_op(Tensor.abs) -register_elementwise_op(Tensor.absolute) -register_elementwise_op(Tensor.acos) -register_elementwise_op(Tensor.arccos) -register_elementwise_op(Tensor.angle) -register_elementwise_op(Tensor.asin) -register_elementwise_op(Tensor.arcsin) -register_elementwise_op(Tensor.atan) -register_elementwise_op(Tensor.arctan) -register_elementwise_op(Tensor.all) -register_elementwise_op(Tensor.any) -register_elementwise_op(Tensor.bernoulli) -register_elementwise_op(Tensor.bfloat16) -register_elementwise_op(Tensor.bitwise_not) -register_elementwise_op(Tensor.bool) -register_elementwise_op(Tensor.byte) -register_elementwise_op(Tensor.ceil) -register_elementwise_op(Tensor.char) -register_elementwise_op(Tensor.clamp) -register_elementwise_op(Tensor.clamp_max) -register_elementwise_op(Tensor.clamp_min) -register_elementwise_op(Tensor.clip) -register_elementwise_op(Tensor.clone) -register_elementwise_op(Tensor.contiguous) -register_elementwise_op(Tensor.copysign) -register_elementwise_op(Tensor.cos) -register_elementwise_op(Tensor.cosh) -register_elementwise_op(Tensor.acosh) -register_elementwise_op(Tensor.arccosh) -register_elementwise_op(Tensor.cpu) -register_elementwise_op(Tensor.cuda) -register_elementwise_op(Tensor.deg2rad) -register_elementwise_op(Tensor.detach) -register_elementwise_op(Tensor.digamma) -register_elementwise_op(Tensor.double) -register_elementwise_op(Tensor.erf) -register_elementwise_op(Tensor.erfc) -register_elementwise_op(Tensor.erfinv) -register_elementwise_op(Tensor.exp) -register_elementwise_op(Tensor.expm1) -register_elementwise_op(Tensor.fix) -register_elementwise_op(Tensor.trunc) -register_elementwise_op(Tensor.float) -register_elementwise_op(Tensor.float_power) -register_elementwise_op(Tensor.floor) -register_elementwise_op(Tensor.frac) -register_elementwise_op(Tensor.half) -register_elementwise_op(Tensor.hardshrink) -register_elementwise_op(Tensor.heaviside) -register_elementwise_op(Tensor.i0) -register_elementwise_op(Tensor.int) -register_elementwise_op(Tensor.isfinite) -register_elementwise_op(Tensor.isinf) -register_elementwise_op(Tensor.isposinf) -register_elementwise_op(Tensor.isneginf) -register_elementwise_op(Tensor.isnan) -register_elementwise_op(Tensor.lgamma) -register_elementwise_op(Tensor.log) -register_elementwise_op(Tensor.log10) -register_elementwise_op(Tensor.log1p) -register_elementwise_op(Tensor.log2) -register_elementwise_op(Tensor.logical_not) -register_elementwise_op(Tensor.logit) -register_elementwise_op(Tensor.long) -register_elementwise_op(Tensor.nan_to_num) -register_elementwise_op(Tensor.neg) -register_elementwise_op(Tensor.negative) -register_elementwise_op(Tensor.positive) -register_elementwise_op(Tensor.pow) -register_elementwise_op(Tensor.rad2deg) -register_elementwise_op(Tensor.reciprocal) -register_elementwise_op(Tensor.round) -register_elementwise_op(Tensor.rsqrt) -register_elementwise_op(Tensor.short) -register_elementwise_op(Tensor.sigmoid) -register_elementwise_op(Tensor.sign) -register_elementwise_op(Tensor.signbit) -register_elementwise_op(Tensor.sgn) -register_elementwise_op(Tensor.sin) -register_elementwise_op(Tensor.sinc) -register_elementwise_op(Tensor.sinh) -register_elementwise_op(Tensor.asinh) -register_elementwise_op(Tensor.arcsinh) -register_elementwise_op(Tensor.sqrt) -register_elementwise_op(Tensor.square) -register_elementwise_op(Tensor.to) -register_elementwise_op(Tensor.tan) -register_elementwise_op(Tensor.tanh) -register_elementwise_op(Tensor.atanh) -register_elementwise_op(Tensor.arctanh) -register_elementwise_op(Tensor.type) -register_elementwise_op(Tensor.type_as) - -# torch OP -register_elementwise_op(torch.abs) -register_elementwise_op(torch.absolute) -register_elementwise_op(torch.acos) -register_elementwise_op(torch.arccos) -register_elementwise_op(torch.angle) -register_elementwise_op(torch.asin) -register_elementwise_op(torch.arcsin) -register_elementwise_op(torch.atan) -register_elementwise_op(torch.arctan) -register_elementwise_op(torch.all) -register_elementwise_op(torch.any) -register_elementwise_op(torch.bernoulli) -register_elementwise_op(torch.bitwise_not) -register_elementwise_op(torch.ceil) -register_elementwise_op(torch.clamp) -register_elementwise_op(torch.clamp_max) -register_elementwise_op(torch.clamp_min) -register_elementwise_op(torch.clip) -register_elementwise_op(torch.clone) -register_elementwise_op(torch.copysign) -register_elementwise_op(torch.cos) -register_elementwise_op(torch.cosh) -register_elementwise_op(torch.acosh) -register_elementwise_op(torch.arccosh) -register_elementwise_op(torch.deg2rad) -register_elementwise_op(torch.digamma) -register_elementwise_op(torch.erf) -register_elementwise_op(torch.erfc) -register_elementwise_op(torch.erfinv) -register_elementwise_op(torch.exp) -register_elementwise_op(torch.expm1) -register_elementwise_op(torch.fix) -register_elementwise_op(torch.trunc) -register_elementwise_op(torch.float_power) -register_elementwise_op(torch.floor) -register_elementwise_op(torch.frac) -register_elementwise_op(torch.hardshrink) -register_elementwise_op(torch.heaviside) -register_elementwise_op(torch.i0) -register_elementwise_op(torch.isfinite) -register_elementwise_op(torch.isinf) -register_elementwise_op(torch.isposinf) -register_elementwise_op(torch.isneginf) -register_elementwise_op(torch.isnan) -register_elementwise_op(torch.lgamma) -register_elementwise_op(torch.log) -register_elementwise_op(torch.log10) -register_elementwise_op(torch.log1p) -register_elementwise_op(torch.log2) -register_elementwise_op(torch.logical_not) -register_elementwise_op(torch.logit) -register_elementwise_op(torch.nan_to_num) -register_elementwise_op(torch.neg) -register_elementwise_op(torch.negative) -register_elementwise_op(torch.positive) -register_elementwise_op(torch.pow) -register_elementwise_op(torch.rad2deg) -register_elementwise_op(torch.reciprocal) -register_elementwise_op(torch.round) -register_elementwise_op(torch.rsqrt) -register_elementwise_op(torch.sigmoid) -register_elementwise_op(torch.sign) -register_elementwise_op(torch.signbit) -register_elementwise_op(torch.sgn) -register_elementwise_op(torch.sin) -register_elementwise_op(torch.sinc) -register_elementwise_op(torch.sinh) -register_elementwise_op(torch.asinh) -register_elementwise_op(torch.arcsinh) -register_elementwise_op(torch.sqrt) -register_elementwise_op(torch.square) -register_elementwise_op(torch.tan) -register_elementwise_op(torch.tanh) -register_elementwise_op(torch.atanh) -register_elementwise_op(torch.arctanh) -register_elementwise_op(torch.zeros_like) - -# nn.functional OP -register_elementwise_op(F.threshold) -register_elementwise_op(F.relu) -register_elementwise_op(F.hardtanh) -register_elementwise_op(F.hardswish) -register_elementwise_op(F.relu6) -register_elementwise_op(F.elu) -register_elementwise_op(F.selu) -register_elementwise_op(F.celu) -register_elementwise_op(F.leaky_relu) -register_elementwise_op(F.prelu) -register_elementwise_op(F.rrelu) -register_elementwise_op(F.gelu) -register_elementwise_op(F.logsigmoid) -register_elementwise_op(F.hardshrink) -register_elementwise_op(F.tanhshrink) -register_elementwise_op(F.softsign) -register_elementwise_op(F.softplus) -register_elementwise_op(F.softmin) -register_elementwise_op(F.softmax) -register_elementwise_op(F.softshrink) -register_elementwise_op(F.gumbel_softmax) -register_elementwise_op(F.log_softmax) -register_elementwise_op(F.tanh) -register_elementwise_op(F.sigmoid) -register_elementwise_op(F.hardsigmoid) -register_elementwise_op(F.silu) -register_elementwise_op(F.mish) -# TODO(ver217): dropout handles seed -register_elementwise_op(F.dropout) -register_elementwise_op(F.alpha_dropout) -register_elementwise_op(F.feature_alpha_dropout) diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py deleted file mode 100644 index a045f305b5dc72454043298b2a69f114ad50f1e9..0000000000000000000000000000000000000000 --- a/colossalai/nn/_ops/embedding.py +++ /dev/null @@ -1,140 +0,0 @@ -import torch.nn.functional as F -from typing import Optional -from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, \ - ReplicaSpec -from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input - - -def colo_embedding_1Dcol(input_tensor: ColoTensor, - weight: ColoTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> ColoTensor: - # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) - # Gather splitted lookup table - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - output_parallel = F.embedding(input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]), - ComputeSpec(ComputePattern.TP1D)) - output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) - - compute_spec = weight.compute_spec - - if compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_embedding_1Drow(input_tensor: ColoTensor, - weight: ColoTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> ColoTensor: - # embedding_1Drow splits the weight(lookup table) to the shape, [num_embeddings/P, embedding_dim] - # get the index of current segment and mask other segments with 0 - - # get complete input tensor through all-gather - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - tensor_parallel_rank = weight.get_process_group().tp_local_rank() - num_embeddings_per_partition = weight.size_local(0) - vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition - vocab_end_index = vocab_start_index + num_embeddings_per_partition - - # build the mask. - input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index) - # mask the input. - # TODO(jzy) masked_input may be an activation managed by ColoTensor. - masked_input = input_tensor - vocab_start_index - masked_input[input_mask] = 0 - - partial_output = F.embedding(masked_input, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - - # Mask the output embedding. - partial_output[input_mask, :] = 0. - # Reduce across all the model parallel GPUs. - output = reduce_input(partial_output, weight.get_process_group()) - output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec())) - return output - - -def colo_embedding_1d(mode: str, - input_tensor: ColoTensor, - weight: ColoTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> ColoTensor: - assert mode in ('row', 'col') - funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol} - return funcs[mode](input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - - -@colo_op_impl(F.embedding) -def colo_embedding(input_tensor: GeneralTensor, - weight: GeneralTensor, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False): - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``. - This method looks up an embedding table. - """ - assert isinstance(weight, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) - - if not weight.has_compute_spec(): # No Model Parallel Applied - assert weight.is_replicate(), 'Invalid weight spec for native embedding op' - return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse), - spec=ColoTensorSpec(weight.get_process_group())) - elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if weight.is_shard_1drow(): - mode = 'row' - elif weight.is_shard_1dcol(): - mode = 'col' - else: - raise NotImplementedError - return colo_embedding_1d(mode, - input_tensor, - weight, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - else: - raise NotImplementedError diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/nn/_ops/embedding_bag.py deleted file mode 100644 index 0026f579b6dccc1344ba0f32cf47d4131f30f7e6..0000000000000000000000000000000000000000 --- a/colossalai/nn/_ops/embedding_bag.py +++ /dev/null @@ -1,125 +0,0 @@ -import torch.nn.functional as F -from typing import Optional -from torch import Tensor -from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, \ - ShardSpec, ReplicaSpec -from ._utils import GeneralTensor, convert_to_colo_tensor - - -def colo_embedding_bag_1Dcol(input_tensor: ColoTensor, - weight: ColoTensor, - offsets: Optional[Tensor] = None, - max_norm: Optional[float] = None, - norm_type: float = 2, - scale_grad_by_freq: bool = False, - mode: str = "mean", - sparse: bool = False, - per_sample_weights: Optional[Tensor] = None, - include_last_offset: bool = False, - padding_idx: Optional[int] = None) -> ColoTensor: - # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) - # Gather splitted lookup table - pg = weight.get_process_group() - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - output_parallel = F.embedding_bag(input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx) - output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) - - if weight.compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_embedding_bag_1d(tp_mode: str, - input_tensor: ColoTensor, - weight: ColoTensor, - offsets: Optional[Tensor] = None, - max_norm: Optional[float] = None, - norm_type: float = 2, - scale_grad_by_freq: bool = False, - mode: str = "mean", - sparse: bool = False, - per_sample_weights: Optional[Tensor] = None, - include_last_offset: bool = False, - padding_idx: Optional[int] = None) -> ColoTensor: - assert tp_mode in ('col',) - funcs = {'col': colo_embedding_bag_1Dcol} - return funcs[tp_mode](input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx) - - -@colo_op_impl(F.embedding_bag) -def colo_embedding_bag(input_tensor: GeneralTensor, - weight: GeneralTensor, - offsets: Optional[Tensor] = None, - max_norm: Optional[float] = None, - norm_type: float = 2, - scale_grad_by_freq: bool = False, - mode: str = "mean", - sparse: bool = False, - per_sample_weights: Optional[Tensor] = None, - include_last_offset: bool = False, - padding_idx: Optional[int] = None): - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``. - This method looks up an embedding table. - """ - assert isinstance(weight, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) - - # Handle different parallel actions. - - if not weight.has_compute_spec(): # No Model Parallel Applied - assert weight.is_replicate(), 'Invalid weight spec for native embedding op' - return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx), - spec=ColoTensorSpec(weight.get_process_group())) - elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if weight.is_shard_1dcol(): - tp_mode = 'col' - else: - raise NotImplementedError - return colo_embedding_bag_1d(tp_mode, - input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx) - else: - raise NotImplementedError diff --git a/colossalai/nn/_ops/layernorm.py b/colossalai/nn/_ops/layernorm.py deleted file mode 100644 index 2b761b84e3ee8aa9dcaf7ef1ba054b6857aa4b49..0000000000000000000000000000000000000000 --- a/colossalai/nn/_ops/layernorm.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import List, Optional -import torch.nn.functional as F -from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec -from ._utils import GeneralTensor, convert_to_colo_tensor - - -@colo_op_impl(F.layer_norm) -def colo_layernorm( - input_tensor: GeneralTensor, - normalized_shape: List[int], - weight: Optional[GeneralTensor] = None, - bias: Optional[GeneralTensor] = None, - eps: float = 1e-5, -): - assert isinstance(weight, ColoTensor) - input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) - bias = convert_to_colo_tensor(bias, weight.get_process_group()) - input_tensor = input_tensor.redistribute(ReplicaSpec()) - - output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps) - output = ColoTensor.from_torch_tensor(tensor=output, - spec=ColoTensorSpec(pg=input_tensor.get_process_group(), - dist_attr=input_tensor.dist_spec)) - return output diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py deleted file mode 100644 index 2f2088c61fa842b7eff6eecd348e6ff7d42916cf..0000000000000000000000000000000000000000 --- a/colossalai/nn/_ops/linear.py +++ /dev/null @@ -1,171 +0,0 @@ -from copy import deepcopy -from typing import Optional - -import torch.nn.functional as F - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec -from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor.sharding_spec import ShardingSpec - -from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_grad, reduce_input - - -def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': - # Input:S[1] x Weight:S[0] = Output:P - # All-Reduce(Output) + bias = res - # Input:S[1] - pg = weight.get_process_group() - input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg) - - # Output:P - partial_output = F.linear(input_tensor, weight) - # Reduce(Output) - - output = reduce_input(partial_output, pg) - # Bias - if bias is not None: - assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op' - output = output + bias - - output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec())) - return output - - -def colo_linear_1dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': - # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1] - # All-Gather(Output) - # Input:B - compute_spec = weight.compute_spec - input_tensor = input_tensor.redistribute(ReplicaSpec()) - input_parallel = reduce_grad(input_tensor, weight.get_process_group()) - - output_parallel = F.linear(input_parallel, weight, bias) - output = ColoTensor.from_torch_tensor(output_parallel, - spec=ColoTensorSpec(weight.get_process_group(), - ShardSpec([-1], [weight.get_tp_world_size()]), - ComputeSpec(ComputePattern.TP1D))) - if compute_spec.output_replicate: - return output.to_replicate() - else: - return output - - -def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': - assert mode in ('row', 'col') - funcs = {'row': colo_linear_1drow, 'col': colo_linear_1dcol} - return funcs[mode](input_tensor, weight, bias) - - -# @register_colo_graph(input_pos=[1], param_pos=[2, 3]) -def colo_linear_imp(input_tensor: GeneralTensor, - weight: GeneralTensor, - bias: Optional[GeneralTensor] = None) -> 'ColoTensor': - """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. - This method computes a linear. - """ - assert isinstance(weight, ColoTensor) - pg = weight.get_process_group() - assert pg - input_tensor = convert_to_colo_tensor(input_tensor, pg) - bias = convert_to_colo_tensor(bias, pg) - # input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias))) - - # Add communication logic before and after linear call. - ret_tensor = None - if not weight.has_compute_spec(): # No Model Parallel Applied - assert weight.is_replicate(), 'Invalid weight spec for native Linear op' - assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op' - ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg)) - elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied - if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()): - mode = 'row' - elif weight.is_shard_1drow() and (bias is None or bias.is_shard_1drow() or bias.is_shard_1dcol()): - mode = 'col' - else: - raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight}, bias {bias}") - ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias) - else: - raise NotImplementedError - - return ret_tensor - - -def _new_colo_linear_imp(input_tensor: GeneralTensor, - weight: GeneralTensor, - bias: Optional[GeneralTensor] = None) -> 'ColoTensor': - """ - A tentative function to compute the distributed linear layer with the latest sharding spec. - This function is subject to future change as the current sharding API is not stable. - """ - # get mesh info - input_sharding_seq = input_tensor.sharding_spec.sharding_sequence - weight_sharding_seq = weight.sharding_spec.sharding_sequence - if bias is not None: - bias_sharding_seq = bias.sharding_spec.sharding_sequence - device_mesh = weight.sharding_spec.device_mesh - pg_axis0 = weight.pg_axis0 - pg_axis1 = weight.pg_axis1 - - # the last dim of input should have the same spec as the first dim of weight - # the weight is transposed, so we look at the second dimension - assert input_sharding_seq[-1] == weight_sharding_seq[1] - - if bias is not None: - assert bias_sharding_seq[0] == weight_sharding_seq[0] - - # compute the output sharding sequence - # as weight is transposed, so we look at the first dimension - output_shard_seq = input_sharding_seq[:-1] + weight_sharding_seq[:1] - output_shard_seq = deepcopy(output_shard_seq) - - # TODO: add reduce grad logic - - # handle column and row parallel linear - # by reusing the implementation above - out = F.linear(input_tensor, weight) - - # run all reduce if necessary - last_dim_spec = input_sharding_seq[-1] - if last_dim_spec.is_replica: - pass - elif last_dim_spec.shard_list is not None: - for dim in last_dim_spec.shard_list: - if dim == 0: - reduce_input(out, pg_axis0) - elif dim == 1: - reduce_input(out, pg_axis1) - else: - raise RuntimeError("Found invalid sharding axis {dim}, only 0 or 1 is expected") - # add bias - if bias is not None: - out += bias - - # convert shard seq to partition dict - output_partition_dict = {} - for index, dim_spec in enumerate(output_shard_seq): - if not dim_spec.is_replica: - if index not in output_partition_dict: - output_partition_dict[index] = [] - output_partition_dict[index].extend(dim_spec.shard_list) - - entire_shape = out.shape - output_sharding_spec = ShardingSpec(device_mesh, entire_shape, output_partition_dict) - ret_tensor = ColoTensor.from_torch_tensor(out) - setattr(ret_tensor, 'sharding_spec', output_sharding_spec) - return ret_tensor - - -def _has_sharding_spec(tensor): - """ - A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is - set as the attribute `sharding_spec` on a tensor. - """ - return hasattr(tensor, 'sharding_spec') - - -@colo_op_impl(F.linear) -def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor': - if _has_sharding_spec(weight): - return _new_colo_linear_imp(input, weight, bias) - else: - return colo_linear_imp(input, weight, bias) diff --git a/colossalai/nn/_ops/loss.py b/colossalai/nn/_ops/loss.py deleted file mode 100644 index 1e54f662859ceffa3edb66157c187c6e17658142..0000000000000000000000000000000000000000 --- a/colossalai/nn/_ops/loss.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -import torch.nn.functional as F -from typing import Optional -from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor import ColoTensor, ColoTensorSpec -from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D -from ._utils import GeneralTensor, convert_to_colo_tensor - - -@colo_op_impl(F.cross_entropy) -def colo_cross_entropy(input_tensor: GeneralTensor, - target: GeneralTensor, - weight: Optional[GeneralTensor] = None, - size_average: Optional[bool] = None, - ignore_index: int = -100, - reduce: Optional[bool] = None, - reduction: str = "mean", - label_smoothing: float = 0.0): - assert isinstance(weight, ColoTensor) or isinstance(target, ColoTensor) or isinstance(input_tensor, ColoTensor) - pg = input_tensor.get_process_group() if isinstance(input_tensor, ColoTensor) else isinstance(target, ColoTensor) - weight = convert_to_colo_tensor(weight, pg) - target = convert_to_colo_tensor(target, pg) - input_tensor = convert_to_colo_tensor(input_tensor, pg) - - if input_tensor.is_replicate(): # Input is gathered - assert target.is_replicate() and (weight is None or weight.is_replicate()), \ - "Target tensor and weight tensor both should be complete" - output = F.cross_entropy(input_tensor, - target, - weight=weight, - size_average=size_average, - ignore_index=ignore_index, - reduce=reduce, - reduction=reduction, - label_smoothing=label_smoothing) - return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)) - elif input_tensor.has_compute_spec(): # Single Model Parallel Applied - if input_tensor.is_shard_1dcol(): - assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in" - assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function" - output = VocabParallelCrossEntropyLoss1D()(input_tensor, - target, - process_group=input_tensor.process_group.tp_process_group()) - return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)) - else: - raise NotImplementedError - else: - raise NotImplementedError diff --git a/colossalai/nn/_ops/view.py b/colossalai/nn/_ops/view.py deleted file mode 100644 index 3c0bc52337ce8f053f04fbc2885fa5aba879990a..0000000000000000000000000000000000000000 --- a/colossalai/nn/_ops/view.py +++ /dev/null @@ -1,96 +0,0 @@ -import operator -from functools import reduce -from typing import Optional, Union - -import torch - -from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec -from colossalai.tensor.op_wrapper import colo_op_impl - - -def _all_int(my_iter): - return all(isinstance(i, int) for i in my_iter) - - -def _get_valid_shape(shape): - if isinstance(shape, list): - if _all_int(shape): - return tuple(shape) - else: - raise RuntimeError("expects type(int) but finds an other type") - elif isinstance(shape, tuple): - if _all_int(shape): - return shape - else: - return _get_valid_shape(shape[0]) - else: - raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape))) - - -def _shape_infer(org_sp, tgt_sp): - cnt = 0 - pos = 0 - for idx, dim in enumerate(tgt_sp): - if dim < -1: - raise RuntimeError("invalid shape dimension {}".format(dim)) - elif dim == -1: - cnt += 1 - pos = idx - - if cnt > 1: - raise RuntimeError("only one dimension can be inferred") - - org_prod = reduce(operator.mul, org_sp, 1) - tgt_prod = reduce(operator.mul, tgt_sp, 1) - - if cnt == 0: - if org_prod != tgt_prod: - raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) - else: - return tgt_sp - elif org_prod % tgt_prod != 0: - raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) - - infer_dim = -(org_prod // tgt_prod) - return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:] - - -@colo_op_impl(torch.Tensor.view) -def colo_view(self: ColoTensor, *shape) -> 'ColoTensor': - """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``. - Changes the shape of the current tensor. - """ - assert isinstance(self, ColoTensor) - # apply original `view` function for replicated colo tensors - if self.is_replicate(): - return self.view(*shape) - - cur_sp = self.size() - org_sp = self.size_global() - # parse the passed arguments - tgt_sp = _get_valid_shape(shape) - # get the correct shape from inference - inf_sp = _shape_infer(org_sp, tgt_sp) - - if self.is_shard_1drow() and org_sp[0] == inf_sp[0]: - new_shape = (cur_sp[0],) + tgt_sp[1:] - res = self.view(*new_shape) - elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]: - new_shape = tgt_sp[:-1] + (cur_sp[-1],) - res = self.view(*new_shape) - else: - replicated_t = self.redistribute(dist_spec=ReplicaSpec()) - return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape), - spec=ColoTensorSpec(self.get_process_group())) - - return ColoTensor.from_torch_tensor(tensor=res, - spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec)) - - -@colo_op_impl(torch.Tensor.size) -def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]: - size = self.size_global() - if dim is None: - return size - else: - return size[dim] diff --git a/colossalai/nn/init.py b/colossalai/nn/init.py index 559b7038fc352a0ccbea22403cd4b1284bed42e0..2637aa8eaaeb91ffa0a1c9c42f8536ebb974f96e 100644 --- a/colossalai/nn/init.py +++ b/colossalai/nn/init.py @@ -1,8 +1,8 @@ import math import warnings -from torch import Tensor import torch.nn as nn +from torch import Tensor def zeros_(): @@ -23,7 +23,7 @@ def ones_(): return initializer -def uniform_(a: float = 0., b: float = 1.): +def uniform_(a: float = 0.0, b: float = 1.0): r"""Return the initializer filling the input Tensor with values drawn from the uniform distribution :math:`\mathcal{U}(a, b)`. @@ -38,7 +38,7 @@ def uniform_(a: float = 0., b: float = 1.): return initializer -def normal_(mean: float = 0., std: float = 1.): +def normal_(mean: float = 0.0, std: float = 1.0): r"""Return the initializer filling the input Tensor with values drawn from the normal distribution .. math:: @@ -47,7 +47,7 @@ def normal_(mean: float = 0., std: float = 1.): Args: mean (float): the mean of the normal distribution. Defaults 0.0. std (float): the standard deviation of the normal distribution. Defaults 1.0. - """ + """ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return nn.init.normal_(tensor, mean, std) @@ -55,7 +55,7 @@ def normal_(mean: float = 0., std: float = 1.): return initializer -def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.): +def trunc_normal_(mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0): r"""Return the initializer filling the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` @@ -76,7 +76,7 @@ def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = return initializer -def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'): +def kaiming_uniform_(a=0, mode="fan_in", nonlinearity="leaky_relu"): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification` - He, K. et al. (2015), using a @@ -104,23 +104,23 @@ def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'): warnings.warn("Initializing zero-element tensors is a no-op") return tensor - if mode == 'fan_in': - assert fan_in is not None, 'Fan_in is not provided.' + if mode == "fan_in": + assert fan_in is not None, "Fan_in is not provided." fan = fan_in - elif mode == 'fan_out': - assert fan_out is not None, 'Fan_out is not provided.' + elif mode == "fan_out": + assert fan_out is not None, "Fan_out is not provided." fan = fan_out else: - raise ValueError(f'Invalid initialization mode \'{mode}\'') + raise ValueError(f"Invalid initialization mode '{mode}'") std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan) - bound = math.sqrt(3.) * std + bound = math.sqrt(3.0) * std return nn.init.uniform_(tensor, -bound, bound) return initializer -def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'): +def kaiming_normal_(a=0, mode="fan_in", nonlinearity="leaky_relu"): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification` - He, K. et al. (2015), using a @@ -148,14 +148,14 @@ def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'): warnings.warn("Initializing zero-element tensors is a no-op") return tensor - if mode == 'fan_in': - assert fan_in is not None, 'Fan_in is not provided.' + if mode == "fan_in": + assert fan_in is not None, "Fan_in is not provided." fan = fan_in - elif mode == 'fan_out': - assert fan_out is not None, 'Fan_out is not provided.' + elif mode == "fan_out": + assert fan_out is not None, "Fan_out is not provided." fan = fan_out else: - raise ValueError(f'Invalid initialization mode \'{mode}\'') + raise ValueError(f"Invalid initialization mode '{mode}'") std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan) return nn.init.normal_(tensor, 0, std) @@ -163,7 +163,7 @@ def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'): return initializer -def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.): +def xavier_uniform_(a: float = math.sqrt(3.0), scale: float = 2.0, gain: float = 1.0): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform @@ -184,7 +184,7 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1 # adapted from torch.nn.init def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' + assert fan_in is not None, "Fan_in is not provided." fan = fan_in if fan_out is not None: @@ -197,7 +197,7 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1 return initializer -def xavier_normal_(scale: float = 2., gain: float = 1.): +def xavier_normal_(scale: float = 2.0, gain: float = 1.0): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal @@ -216,7 +216,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.): # adapted from torch.nn.init def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' + assert fan_in is not None, "Fan_in is not provided." fan = fan_in if fan_out is not None: @@ -224,7 +224,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.): std = gain * math.sqrt(scale / float(fan)) - return nn.init.normal_(tensor, 0., std) + return nn.init.normal_(tensor, 0.0, std) return initializer @@ -232,7 +232,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.): def lecun_uniform_(): # adapted from jax.nn.initializers def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' + assert fan_in is not None, "Fan_in is not provided." var = 1.0 / fan_in bound = math.sqrt(3 * var) @@ -244,9 +244,9 @@ def lecun_uniform_(): def lecun_normal_(): # adapted from jax.nn.initializers def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' + assert fan_in is not None, "Fan_in is not provided." std = math.sqrt(1.0 / fan_in) - return nn.init.trunc_normal_(tensor, std=std / .87962566103423978) + return nn.init.trunc_normal_(tensor, std=std / 0.87962566103423978) return initializer diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py index b705632f80407ceda019f4b905c724d1b2254f66..9aeab9f44a6d519a01d322b78982781628bcf0c5 100644 --- a/colossalai/nn/layer/__init__.py +++ b/colossalai/nn/layer/__init__.py @@ -1,10 +1,2 @@ -from .colossalai_layer import * -from .parallel_1d import * -from .parallel_2d import * -from .parallel_2p5d import * -from .parallel_3d import * -from .parallel_sequence import * -from .moe import * +# from .moe import * from .utils import * -from .vanilla import * -from .wrapper import * diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/nn/layer/base_layer.py deleted file mode 100644 index c85f53cc44c3660e39a1cf2995ca3e44a70b4c04..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/base_layer.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch.nn as nn - -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from contextlib import contextmanager - - -class ParallelLayer(nn.Module): - global_state_dict: bool = True - - def __init__(self): - super().__init__() - self.data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank( - ParallelMode.DATA) - self.data_parallel_size = 1 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_world_size( - ParallelMode.DATA) - - self.tensor_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_local_rank( - ParallelMode.TENSOR) - self.tensor_parallel_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size( - ParallelMode.TENSOR) - - self.pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) - - def _load_from_global_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - return super()._save_to_state_dict(destination, prefix, keep_vars) - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - if self.global_state_dict: - if gpc.get_local_rank(ParallelMode.TENSOR) != 0: - missing_keys.clear() - unexpected_keys.clear() - return self._load_from_global_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, - unexpected_keys, error_msgs) - return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if self.global_state_dict: - return self._save_to_global_state_dict(destination, prefix, keep_vars) - return super()._save_to_state_dict(destination, prefix, keep_vars) - - @classmethod - @contextmanager - def use_local_state_dict(cls): - try: - cls.global_state_dict = False - yield - finally: - cls.global_state_dict = True diff --git a/colossalai/nn/layer/colossalai_layer/__init__.py b/colossalai/nn/layer/colossalai_layer/__init__.py deleted file mode 100644 index 2ae1b07a75b2e7a231fc3512e8f46bccd0e9d4c6..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/colossalai_layer/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from ._utils import partition_batch -from .dropout import Dropout -from .embedding import Embedding, PatchEmbedding -from .linear import Classifier, Linear -from .normalization import LayerNorm - -__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch'] diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/nn/layer/colossalai_layer/_utils.py deleted file mode 100644 index 677cb0e7ac428856c6888ae195ab32c7c70d5758..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/colossalai_layer/_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch.nn as nn -from torch import Tensor - -from ..parallel_2d._operation import split_batch_2d -from ..parallel_2p5d._operation import split_batch_2p5d -from ..parallel_3d._operation import split_batch_3d -from ..utils import get_tensor_parallel_mode - -_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d} - - -def partition_batch(input_) -> Tensor: - tensor_parallel_mode = get_tensor_parallel_mode() - if tensor_parallel_mode in _parallel_split_batch: - if isinstance(input_, dict): - return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()} - else: - return _parallel_split_batch[tensor_parallel_mode](input_) - else: - return input_ - - -class ColossalaiModule(nn.Module): - - def __init__(self, module: nn.Module, **kwargs): - super().__init__() - self.module = module - for k, v in kwargs.items(): - setattr(self, k, v) - - def __getattr__(self, name: str): - if name == 'module': - return super().__getattr__(name) - elif hasattr(self.module, name): - return getattr(self.module, name) - elif name in self.__dict__: - return self.__dict__[name] - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name)) - - def forward(self, *args): - return self.module(*args) diff --git a/colossalai/nn/layer/colossalai_layer/dropout.py b/colossalai/nn/layer/colossalai_layer/dropout.py deleted file mode 100644 index 0c049cb3f408e22eb9ce2b67d354da305aeadbae..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/colossalai_layer/dropout.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch.nn as nn - -from colossalai.context import ParallelMode, seed - -from ..parallel_1d import * -from ..utils import get_tensor_parallel_mode -from ._utils import ColossalaiModule - - -class Dropout(ColossalaiModule): - """Dropout layer of colossalai. - - Args: - p (float, optional): probability of an element to be zeroed, defaults 0.5. - inplace (bool, optional): whether to do dropout in-place, default to be False. - """ - - def __init__(self, p: float = 0.5, inplace: bool = False) -> None: - tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel == "1d": - drop = Dropout1D(p, inplace) - else: - drop = nn.Dropout(p, inplace) - super().__init__(drop, tensor_parallel=tensor_parallel) - - def forward(self, *args): - if self.tensor_parallel in [None, '1d']: - return super().forward(*args) - else: - with seed(ParallelMode.TENSOR): - return super().forward(*args) diff --git a/colossalai/nn/layer/colossalai_layer/embedding.py b/colossalai/nn/layer/colossalai_layer/embedding.py deleted file mode 100644 index e5c9c46e0ff1e6fc3413809a493e7a1012273315..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/colossalai_layer/embedding.py +++ /dev/null @@ -1,151 +0,0 @@ -import math -from typing import Callable - -from colossalai.utils import get_current_device -from torch import dtype, nn - -from ... import init as init -from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D -from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D -from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D -from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D -from ..utils import get_tensor_parallel_mode -from ..vanilla import VanillaPatchEmbedding -from ._utils import ColossalaiModule - -_parallel_embedding = { - '1d': Embedding1D, - '2d': Embedding2D, - '2.5d': Embedding2p5D, - '3d': Embedding3D, -} - -_vocab_parallel_embedding = { - '1d': VocabParallelEmbedding1D, - '2d': VocabParallelEmbedding2D, - '2.5d': VocabParallelEmbedding2p5D, - '3d': VocabParallelEmbedding3D -} - -_parallel_patchembedding = { - None: VanillaPatchEmbedding, - '1d': PatchEmbedding1D, - '2d': PatchEmbedding2D, - '2.5d': PatchEmbedding2p5D, - '3d': PatchEmbedding3D -} - - -class Embedding(ColossalaiModule): - r"""Embedding for colossalai. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about ``initializer`` please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: dtype = None, - weight_initializer: Callable = init.normal_(), - vocab_parallel_limit: int = 2048, - *args, - **kwargs) -> None: - tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel is None: - embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, - **kwargs).to(dtype).to(get_current_device()) - weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) - elif num_embeddings <= vocab_parallel_limit: - embed = _parallel_embedding[tensor_parallel]( - num_embeddings, - embedding_dim, - padding_idx=padding_idx, - dtype=dtype, - weight_initializer=weight_initializer, - *args, - **kwargs, - ) - else: - embed = _vocab_parallel_embedding[tensor_parallel]( - num_embeddings, - embedding_dim, - padding_idx=padding_idx, - dtype=dtype, - weight_initializer=weight_initializer, - *args, - **kwargs, - ) - super().__init__(embed) - - -class PatchEmbedding(ColossalaiModule): - """2D Image to Patch Embedding. - - Args: - img_size (int): image size. - patch_size (int): patch size. - in_chans (int): number of channels of input image. - embed_size (int): size of embedding. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - flatten (bool, optional): whether to flatten output tensor, defaults to True. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - position_embed_initializer (:class:`typing.Callable`, optional): - The initializer of position embedding, defaults to zeros initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__( - self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - dtype: dtype = None, - flatten: bool = True, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_() - ) -> None: - tensor_parallel = get_tensor_parallel_mode() - embed = _parallel_patchembedding[tensor_parallel]( - img_size, - patch_size, - in_chans, - embed_size, - dtype=dtype, - flatten=flatten, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - position_embed_initializer=position_embed_initializer, - ) - super().__init__(embed) diff --git a/colossalai/nn/layer/colossalai_layer/linear.py b/colossalai/nn/layer/colossalai_layer/linear.py deleted file mode 100644 index 3e0c6e285c1c64ccfc25bf5eca6f636cc8744aea..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/colossalai_layer/linear.py +++ /dev/null @@ -1,141 +0,0 @@ -import inspect -import math -from typing import Callable - -from torch import dtype, nn - -from colossalai.utils import get_current_device - -from ... import init as init -from ..parallel_1d import * -from ..parallel_2d import * -from ..parallel_2p5d import * -from ..parallel_3d import * -from ..utils import get_tensor_parallel_mode -from ..vanilla import * -from ._utils import ColossalaiModule - -_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} - -_parallel_classifier = { - None: VanillaClassifier, - '1d': Classifier1D, - '2d': Classifier2D, - '2.5d': Classifier2p5D, - '3d': Classifier3D -} - -_vocab_parallel_classifier = { - '1d': VocabParallelClassifier1D, - '2d': VocabParallelClassifier2D, - '2.5d': VocabParallelClassifier2p5D, - '3d': VocabParallelClassifier3D -} - - -class Linear(ColossalaiModule): - """Linear layer of colossalai. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - Note: ``kwargs`` would contain different parameters when you use different parallelisms. - - The ``kwargs`` should contain parameters below: - :: - - Linear1D: - gather_output: bool (optional, default to be false) - skip_bias_add: bool (optional, default to be false) - Linear2D: - skip_bias_add: bool (optional, default to be false) - Linear2p5D: - skip_bias_add: bool (optional, default to be false) - Linear3D: - None - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - **kwargs) -> None: - tensor_parallel = get_tensor_parallel_mode() - linear_cls = _parallel_linear[tensor_parallel] - gather_output = kwargs.pop('gather_output', None) - if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available - kwargs['gather_output'] = gather_output - layer = linear_cls( - in_features, - out_features, - bias=bias, - dtype=dtype, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - **kwargs, - ) - super().__init__(layer) - - -class Classifier(ColossalaiModule): - """Classifier layer of colossalai. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: nn.Parameter = None, - bias: bool = True, - dtype: dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - vocab_parallel_limit: int = 2048) -> None: - tensor_parallel = get_tensor_parallel_mode() - if num_classes <= vocab_parallel_limit or tensor_parallel is None: - layer = _parallel_classifier[tensor_parallel]( - in_features, - num_classes, - weight=weight, - bias=bias, - dtype=dtype, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - ) - else: - layer = _vocab_parallel_classifier[tensor_parallel]( - in_features, - num_classes, - weight=weight, - bias=bias, - dtype=dtype, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - ) - super().__init__(layer) diff --git a/colossalai/nn/layer/colossalai_layer/normalization.py b/colossalai/nn/layer/colossalai_layer/normalization.py deleted file mode 100644 index 86861d30214a43a95192d0f179be6ba705e002c8..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/colossalai_layer/normalization.py +++ /dev/null @@ -1,41 +0,0 @@ -from colossalai.utils import get_current_device -from torch import nn - -from ..parallel_1d import LayerNorm1D -from ..parallel_2d import LayerNorm2D -from ..parallel_2p5d import LayerNorm2p5D -from ..parallel_3d import LayerNorm3D -from ..utils import get_tensor_parallel_mode -from ..vanilla import VanillaLayerNorm -from ._utils import ColossalaiModule - -_parallel_layernorm = { - None: VanillaLayerNorm, - "1d": LayerNorm1D, - "2d": LayerNorm2D, - "2.5d": LayerNorm2p5D, - "3d": LayerNorm3D, -} - - -class LayerNorm(ColossalaiModule): - r"""Layer Normalization for colossalai. - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None: - tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel is None: - norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) - else: - norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) - super().__init__(norm) diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index 05333fe965f1a4efff779550cb5e85946940fd7b..6a5ccff510be715e2d6ea30de3d70e43a4b8d8b7 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -5,6 +5,17 @@ from .routers import MoeRouter, Top1Router, Top2Router from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts __all__ = [ - 'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator', - 'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model' + "Experts", + "FFNExperts", + "TPExperts", + "Top1Router", + "Top2Router", + "MoeLayer", + "NormalNoiseGenerator", + "UniformNoiseGenerator", + "build_ffn_experts", + "MoeModule", + "MoeRouter", + "save_moe_model", + "load_moe_model", ] diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py index 37f31c16709b94bb6d2cebbabd850c03a85895eb..2f0b7e43673a2a2d7f09d39e13c087cc4a77a55b 100644 --- a/colossalai/nn/layer/moe/_operation.py +++ b/colossalai/nn/layer/moe/_operation.py @@ -18,18 +18,18 @@ def build_moe_if_not_prebuilt(): global moe if moe is None: from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() class AllGather(torch.autograd.Function): - @staticmethod def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - global moe if moe is None: from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() if ctx is not None: @@ -51,7 +51,6 @@ class AllGather(torch.autograd.Function): class ReduceScatter(torch.autograd.Function): - @staticmethod def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: if ctx is not None: @@ -98,7 +97,6 @@ class AllToAll(torch.autograd.Function): class MoeDispatch(torch.autograd.Function): - @staticmethod def forward(ctx, tokens, mask, dest_idx, ec): s = tokens.size(0) @@ -124,7 +122,6 @@ class MoeDispatch(torch.autograd.Function): class MoeCombine(torch.autograd.Function): - @staticmethod def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): assert logits.dtype == torch.float32 @@ -137,7 +134,7 @@ class MoeCombine(torch.autograd.Function): # load moe kernel during runtime if not pre-built build_moe_if_not_prebuilt() - fp16_flag = (expert_tokens.dtype == torch.float16) + fp16_flag = expert_tokens.dtype == torch.float16 cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) output = ctokens.to(torch.float16) if fp16_flag else ctokens @@ -155,8 +152,7 @@ class MoeCombine(torch.autograd.Function): def backward(ctx, tokens_grad): expert_tokens, logits, mask, dest_idx = ctx.saved_tensors - cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \ - else tokens_grad + cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx) d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/nn/layer/moe/checkpoint.py index efda1f22252d97db97219c9cc5a0a6c62c182bcd..adad19d581efe1478c4ab61f4b495770ef262c7f 100644 --- a/colossalai/nn/layer/moe/checkpoint.py +++ b/colossalai/nn/layer/moe/checkpoint.py @@ -16,7 +16,7 @@ def load_moe_model(model: nn.Module, load_path: str): state_dict = torch.load(load_path) for prefix, module in model.named_modules(): - if prefix.endswith('.moe_layer.experts'): + if prefix.endswith(".moe_layer.experts"): # this module should be an Experts instance assert isinstance(module, MoeExperts) @@ -25,16 +25,16 @@ def load_moe_model(model: nn.Module, load_path: str): for i in range(num_local): expert_id = ep_rank * num_local + i for name, _ in module.experts[i].named_parameters(): - cur_key = f'{prefix}.experts.{i}.{name}' - param_key = f'{prefix}.experts.{expert_id}.{name}' + cur_key = f"{prefix}.experts.{i}.{name}" + param_key = f"{prefix}.experts.{expert_id}.{name}" load_param = state_dict[param_key] state_dict[cur_key] = load_param for name, _ in module.experts[0].named_parameters(): - pop_pre = f'{prefix}.experts.' - pop_suf = f'.{name}' + pop_pre = f"{prefix}.experts." + pop_suf = f".{name}" for i in range(num_local, module.num_total_experts): - pop_key = f'{pop_pre}{i}{pop_suf}' + pop_key = f"{pop_pre}{i}{pop_suf}" state_dict.pop(pop_key) model.load_state_dict(state_dict) diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 56b11f4d9e08b7556268ea427a3fdda92d52ef61..4b2ecb24170267d7fb2c621ae6499ff61d1bdcdf 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -6,10 +6,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from colossalai.context import ParallelMode, seed from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.zero.init_ctx import no_shard_zero_decrator from colossalai.utils import get_current_device -from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator class MoeExperts(nn.Module): @@ -20,8 +20,10 @@ class MoeExperts(nn.Module): def __init__(self, comm_name: str, num_experts: int): super().__init__() - assert comm_name in {"all_to_all", "all_gather"}, \ - "This kind of communication has not been implemented yet.\n Please use Experts build function." + assert comm_name in { + "all_to_all", + "all_gather", + }, "This kind of communication has not been implemented yet.\n Please use Experts build function." self.comm_name = comm_name self.num_total_experts = num_experts # Get the configuration of experts' deployment and parallel information from moe context @@ -50,7 +52,7 @@ class Experts(MoeExperts): # Attach parallel information for all parameters in Experts for exp in self.experts: for param in exp.parameters(): - param.__setattr__('moe_info', self.dist_info) + param.__setattr__("moe_info", self.dist_info) def forward(self, inputs: torch.Tensor): # Split inputs for each expert @@ -65,7 +67,7 @@ class Experts(MoeExperts): output = torch.cat(expert_output, dim=1).contiguous() return output - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): assert keep_vars == False, "Only support keep_vars=False now" dp_rank = dist.get_rank(self.dist_info.dp_group) ep_rank = dist.get_rank(self.dist_info.ep_group) @@ -79,11 +81,11 @@ class Experts(MoeExperts): example_submodule = subm if dp_rank == 0: - local_prefix = prefix + 'experts.' + local_prefix = prefix + "experts." buffer_module = deepcopy(example_submodule) for i in range(self.num_total_experts): source_rank = i // self.num_local_experts - current_prefix = local_prefix + str(i) + '.' + current_prefix = local_prefix + str(i) + "." comm_module = submodule_dict.get(i, buffer_module) for name, param in comm_module.named_parameters(): dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group) @@ -94,8 +96,7 @@ class Experts(MoeExperts): class FFNExperts(MoeExperts): - """Use torch.bmm to speed up for multiple experts. - """ + """Use torch.bmm to speed up for multiple experts.""" def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): super().__init__("all_to_all", num_experts) @@ -119,10 +120,9 @@ class FFNExperts(MoeExperts): self.drop = nn.Dropout(p=drop_rate) for param in self.parameters(): - param.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs): # inputs [g, el, c, h] + param.__setattr__("moe_info", self.dist_info) + def forward(self, inputs): # inputs [g, el, c, h] el = inputs.size(1) h = inputs.size(-1) @@ -137,7 +137,7 @@ class FFNExperts(MoeExperts): out_model = torch.baddbmm(self.b2, out_inter, self.w2) with seed(ParallelMode.TENSOR): - outputs = self.drop(out_model) # outputs [el, gc, h] + outputs = self.drop(out_model) # outputs [el, gc, h] outputs = outputs.reshape(inshape) outputs = outputs.transpose(0, 1).contiguous() @@ -153,8 +153,7 @@ class TPExperts(MoeExperts): def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): super().__init__("all_gather", MOE_CONTEXT.max_ep_size) - assert d_ff % MOE_CONTEXT.max_ep_size == 0, \ - "d_ff should be divide by maximum expert parallel size" + assert d_ff % MOE_CONTEXT.max_ep_size == 0, "d_ff should be divide by maximum expert parallel size" p_ff = d_ff // MOE_CONTEXT.max_ep_size @@ -177,12 +176,11 @@ class TPExperts(MoeExperts): self.act = nn.GELU() if activation is None else activation self.drop = nn.Dropout(p=drop_rate) - self.w1.__setattr__('moe_info', self.dist_info) - self.w2.__setattr__('moe_info', self.dist_info) - self.b1.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs): # inputs [g, e, c, h] + self.w1.__setattr__("moe_info", self.dist_info) + self.w2.__setattr__("moe_info", self.dist_info) + self.b1.__setattr__("moe_info", self.dist_info) + def forward(self, inputs): # inputs [g, e, c, h] e = inputs.size(1) h = inputs.size(-1) @@ -196,8 +194,8 @@ class TPExperts(MoeExperts): out_inter = self.drop(out_act) out_model = torch.baddbmm(self.b2, out_inter, self.w2) - outputs = self.drop(out_model) # outputs [e, gc, h] + outputs = self.drop(out_model) # outputs [e, gc, h] outputs = outputs.reshape(inshape) outputs = outputs.transpose(0, 1).contiguous() - return outputs # outputs [g, e, c, h] + return outputs # outputs [g, e, c, h] diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 03f55d91f3a861fd552c6df63dbc6b171d43e53e..23d483e6a17a0a0ed54023be59e3adf1b0100233 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -6,6 +6,7 @@ import torch.nn as nn import torch.nn.functional as F from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.legacy.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator from colossalai.nn.layer.moe._operation import ( COL_MOE_KERNEL_FLAG, AllGather, @@ -18,7 +19,6 @@ from colossalai.nn.layer.moe.experts import Experts, MoeExperts from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator from colossalai.utils import get_current_device -from colossalai.zero.legacy.init_ctx import no_shard_zero_context, no_shard_zero_decrator @no_shard_zero_decrator(is_replicated=True) @@ -89,8 +89,9 @@ class MoeLayer(nn.Module): elif self.experts.comm_name == "all_gather": expert_output = self.tp_process(dispatch_data) else: - raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " - "build function.") + raise NotImplementedError( + "This kind of communication has not been implemented yet.\n Please use Experts " "build function." + ) # expert_output [e, c, h] if self.use_kernel: expert_output = expert_output.reshape(-1, self.d_model) @@ -135,27 +136,29 @@ class MoeModule(nn.Module): https://arxiv.org/abs/2201.05596 """ - def __init__(self, - dim_model: int, - num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - use_residual: bool = False, - residual_instance: Optional[nn.Module] = None, - expert_instance: Optional[MoeExperts] = None, - expert_cls: Optional[Type[nn.Module]] = None, - **expert_args): + def __init__( + self, + dim_model: int, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + use_residual: bool = False, + residual_instance: Optional[nn.Module] = None, + expert_instance: Optional[MoeExperts] = None, + expert_cls: Optional[Type[nn.Module]] = None, + **expert_args, + ): super().__init__() noisy_func = None if noisy_policy is not None: - if noisy_policy == 'Jitter': + if noisy_policy == "Jitter": noisy_func = UniformNoiseGenerator() - elif noisy_policy == 'Gaussian': + elif noisy_policy == "Gaussian": noisy_func = NormalNoiseGenerator(num_experts) else: raise NotImplementedError("Unsupported input noisy policy") @@ -167,18 +170,19 @@ class MoeModule(nn.Module): else: raise NotImplementedError("top_k > 2 is not supported yet") - self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + self.moe_router = moe_router_cls( + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) self.use_residual = use_residual if use_residual: if residual_instance is not None: self.residual_module = residual_instance else: - assert expert_cls is not None, \ - "Expert class can't be None when residual instance is not given" + assert expert_cls is not None, "Expert class can't be None when residual instance is not given" self.residual_module = expert_cls(**expert_args) with no_shard_zero_context(): @@ -187,14 +191,12 @@ class MoeModule(nn.Module): if expert_instance is not None: my_experts = expert_instance else: - assert expert_cls is not None, \ - "Expert class can't be None when experts instance is not given" + assert expert_cls is not None, "Expert class can't be None when experts instance is not given" my_experts = Experts(expert_cls, num_experts, **expert_args) - self.moe_layer = MoeLayer(dim_model=dim_model, - num_experts=num_experts, - router=self.moe_router, - experts=my_experts) + self.moe_layer = MoeLayer( + dim_model=dim_model, num_experts=num_experts, router=self.moe_router, experts=my_experts + ) def forward(self, inputs: torch.Tensor): moe_output, l_aux = self.moe_layer(inputs) diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py index c5b8390bf0472302d10ef96080a23007d02c9eb7..7ba83b2787a0c70a2c3dd457770850a0ed07ec85 100644 --- a/colossalai/nn/layer/moe/routers.py +++ b/colossalai/nn/layer/moe/routers.py @@ -1,226 +1,235 @@ -import math -from abc import ABC - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.distributed as dist -from colossalai.utils import get_current_device -from colossalai.context import MOE_CONTEXT -from colossalai.nn.layer.moe._operation import moe_cumsum -from typing import Callable, Optional -from torch.distributed import ProcessGroup - - -class MoeRouter(nn.Module, ABC): - """Base class for all MoE routers. - Args: - k_value (int): The value of top_k. - capacity_factor_train (float): Capacity factor in routing of training. - capacity_factor_eval (float): Capacity factor in routing of evaluation. - min_capacity (int): The minimum number of the capacity of each expert. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__(self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__() - self.k_value = k_value - self.capacity_factor_train = capacity_factor_train - self.capacity_factor_eval = capacity_factor_eval - self.min_capacity = min_capacity - self.noisy_func = noisy_func - self.drop_tks = drop_tks - self._routing_loss = None - - def get_capacity(self, logits_shape): - capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) - capacity += capacity % 2 - capacity = max(capacity, self.min_capacity) - assert capacity > 0 - return capacity - - def set_routing_loss(self, aux_loss: torch.Tensor) -> None: - assert self._routing_loss is None - self._routing_loss = aux_loss - - def pop_routing_loss(self) -> torch.Tensor: - assert self._routing_loss is not None - reservation = self._routing_loss - self._routing_loss = None - return reservation - - -class Top1Router(MoeRouter): - """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about Switch Transformer - of Google. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert. - select_policy (str, optional): The policy about tokens selection. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__(k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - self.select_policy = select_policy - assert select_policy in {"first", "random"} - if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, - device=get_current_device())).rsample - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(inputs, dim=-1) - mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(mask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(mask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - if self.select_policy == "random": - rand_mask = mask * self.uniform(mask.shape) - _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) - mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) - ranks = moe_cumsum(mask) - elif self.select_policy == "first": - ranks = moe_cumsum(mask) - mask = mask * torch.lt(ranks, capacity) - else: - raise NotImplementedError("Not support such select policy yet.") - - ranks = torch.sum(mask * ranks, dim=-1) - - if use_kernel: - mask = torch.sum(mask, dim=-1) - mask = torch.stack([mask], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) - return logits, mask, dest_idx, num_experts * capacity - else: - ranks = F.one_hot(ranks, num_classes=capacity) - weight = mask * logits.type_as(inputs) - combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) - sec_mask = combine_weights.bool() - return combine_weights, sec_mask - - -class Top2Router(MoeRouter): - """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about ViT-MoE. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation. - """ - - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__(k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - # inputs: [s, h] - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) # logits: [s, e] - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(logits, dim=-1) - mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) - top2_idx = torch.argmax(logits_except1, dim=-1) - mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - - cmask = (mask1 + mask2) # loss: [s, e] - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(cmask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(cmask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - rank1 = moe_cumsum(mask1) # rank1: [s, e] - rank2 = moe_cumsum(mask2) - rank2 += torch.sum(mask1, dim=-2, keepdim=True) - - mask1 *= torch.lt(rank1, capacity) - mask2 *= torch.lt(rank2, capacity) - - rank1 = torch.sum(mask1 * rank1, dim=-1) - rank2 = torch.sum(mask2 * rank2, dim=-1) - - if use_kernel: - mask1 = torch.sum(mask1, dim=-1) - mask2 = torch.sum(mask2, dim=-1) - - mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - - return logits, mask, dest_idx, num_experts * capacity - else: - weight1 = mask1 * logits.type_as(inputs) - weight2 = mask2 * logits.type_as(inputs) - rank1_sc = F.one_hot(rank1, num_classes=capacity) - rank2_sc = F.one_hot(rank2, num_classes=capacity) - - cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) - cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) - cb_weight = cb_weight1 + cb_weight2 - sec_mask = cb_weight.bool() - - return cb_weight, sec_mask +import math +from abc import ABC +from typing import Callable, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed import ProcessGroup + +from colossalai.nn.layer.moe._operation import moe_cumsum +from colossalai.utils import get_current_device + + +class MoeRouter(nn.Module, ABC): + """Base class for all MoE routers. + Args: + k_value (int): The value of top_k. + capacity_factor_train (float): Capacity factor in routing of training. + capacity_factor_eval (float): Capacity factor in routing of evaluation. + min_capacity (int): The minimum number of the capacity of each expert. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__( + self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Callable = None, + drop_tks: bool = True, + ): + super().__init__() + self.k_value = k_value + self.capacity_factor_train = capacity_factor_train + self.capacity_factor_eval = capacity_factor_eval + self.min_capacity = min_capacity + self.noisy_func = noisy_func + self.drop_tks = drop_tks + self._routing_loss = None + + def get_capacity(self, logits_shape): + capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval + capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity += capacity % 2 + capacity = max(capacity, self.min_capacity) + assert capacity > 0 + return capacity + + def set_routing_loss(self, aux_loss: torch.Tensor) -> None: + assert self._routing_loss is None + self._routing_loss = aux_loss + + def pop_routing_loss(self) -> torch.Tensor: + assert self._routing_loss is not None + reservation = self._routing_loss + self._routing_loss = None + return reservation + + +class Top1Router(MoeRouter): + """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More detailed function can be found in the paper about Switch Transformer + of Google. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert. + select_policy (str, optional): The policy about tokens selection. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Callable = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) + self.select_policy = select_policy + assert select_policy in {"first", "random"} + if select_policy == "random": + self.uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device()) + ).rsample + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(inputs, dim=-1) + mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(mask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(mask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + if self.select_policy == "random": + rand_mask = mask * self.uniform(mask.shape) + _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) + mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) + ranks = moe_cumsum(mask) + elif self.select_policy == "first": + ranks = moe_cumsum(mask) + mask = mask * torch.lt(ranks, capacity) + else: + raise NotImplementedError("Not support such select policy yet.") + + ranks = torch.sum(mask * ranks, dim=-1) + + if use_kernel: + mask = torch.sum(mask, dim=-1) + mask = torch.stack([mask], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) + return logits, mask, dest_idx, num_experts * capacity + else: + ranks = F.one_hot(ranks, num_classes=capacity) + weight = mask * logits.type_as(inputs) + combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) + sec_mask = combine_weights.bool() + return combine_weights, sec_mask + + +class Top2Router(MoeRouter): + """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More detailed function can be found in the paper about ViT-MoE. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation. + """ + + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Callable = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + # inputs: [s, h] + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) # logits: [s, e] + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(logits, dim=-1) + mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) + top2_idx = torch.argmax(logits_except1, dim=-1) + mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) + + cmask = mask1 + mask2 # loss: [s, e] + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(cmask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(cmask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + rank1 = moe_cumsum(mask1) # rank1: [s, e] + rank2 = moe_cumsum(mask2) + rank2 += torch.sum(mask1, dim=-2, keepdim=True) + + mask1 *= torch.lt(rank1, capacity) + mask2 *= torch.lt(rank2, capacity) + + rank1 = torch.sum(mask1 * rank1, dim=-1) + rank2 = torch.sum(mask2 * rank2, dim=-1) + + if use_kernel: + mask1 = torch.sum(mask1, dim=-1) + mask2 = torch.sum(mask2, dim=-1) + + mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) + + return logits, mask, dest_idx, num_experts * capacity + else: + weight1 = mask1 * logits.type_as(inputs) + weight2 = mask2 * logits.type_as(inputs) + rank1_sc = F.one_hot(rank1, num_classes=capacity) + rank2_sc = F.one_hot(rank2, num_classes=capacity) + + cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + cb_weight = cb_weight1 + cb_weight2 + sec_mask = cb_weight.bool() + + return cb_weight, sec_mask diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py index 4ca8bd7033868706cde62eb45a98f4d88ffa2a67..4f31dd5579dcc9f06e030e387d08e90d7080d478 100644 --- a/colossalai/nn/layer/moe/utils.py +++ b/colossalai/nn/layer/moe/utils.py @@ -1,68 +1,71 @@ -import torch -import torch.nn.functional as F -from colossalai.utils import get_current_device -from colossalai.context.moe_context import MOE_CONTEXT -from .experts import FFNExperts, TPExperts - - -class ForceFP32Parameter(torch.nn.Parameter): - - def half(self, memory_format=None): - return self.data.clone() - - -class NormalNoiseGenerator: - """Generates a random noisy mask for logits tensor. - - All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where - `E = the number of experts`. - - Args: - num_experts (int): The number of experts. - """ - - def __init__(self, num_experts: int): - self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, - device=get_current_device())).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.normal(inputs.shape) - return inputs + noisy - - -class UniformNoiseGenerator: - """Generates a random noisy mask for logits tensor. - copied from mesh tensorflow: - Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. - Makes models more resilient to rounding errors introduced by bfloat16. - This seems particularly important for logits. - - Args: - eps (float, optional): Epsilon in generator, defaults 1e-2. - """ - - def __init__(self, eps: float = 1e-2): - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()), - high=torch.tensor(1.0 + eps, - device=get_current_device())).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.uniform(inputs.shape) - return inputs * noisy - - -def autocast_softmax(logit: torch.Tensor, dim: int): - if logit.dtype != torch.float32: - logit = logit.float() - return F.softmax(logit, dim=dim) - - -def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - mep_size = MOE_CONTEXT.max_ep_size - if num_experts % mep_size == 0 or mep_size % num_experts == 0: - return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) - elif d_ff % mep_size == 0: - return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) - else: - raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") +import torch +import torch.nn.functional as F + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.utils import get_current_device + +from .experts import FFNExperts, TPExperts + + +class ForceFP32Parameter(torch.nn.Parameter): + def half(self, memory_format=None): + return self.data.clone() + + +class NormalNoiseGenerator: + """Generates a random noisy mask for logits tensor. + + All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where + `E = the number of experts`. + + Args: + num_experts (int): The number of experts. + """ + + def __init__(self, num_experts: int): + self.normal = torch.distributions.normal.Normal( + loc=torch.tensor(0.0, device=get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), + ).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.normal(inputs.shape) + return inputs + noisy + + +class UniformNoiseGenerator: + """Generates a random noisy mask for logits tensor. + copied from mesh tensorflow: + Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. + Makes models more resilient to rounding errors introduced by bfloat16. + This seems particularly important for logits. + + Args: + eps (float, optional): Epsilon in generator, defaults 1e-2. + """ + + def __init__(self, eps: float = 1e-2): + self.uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(1.0 - eps, device=get_current_device()), + high=torch.tensor(1.0 + eps, device=get_current_device()), + ).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.uniform(inputs.shape) + return inputs * noisy + + +def autocast_softmax(logit: torch.Tensor, dim: int): + if logit.dtype != torch.float32: + logit = logit.float() + return F.softmax(logit, dim=dim) + + +def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + mep_size = MOE_CONTEXT.max_ep_size + if num_experts % mep_size == 0 or mep_size % num_experts == 0: + return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) + elif d_ff % mep_size == 0: + return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) + else: + raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py deleted file mode 100644 index 2353851df665246251bb7ef0d884dd5f961b7aac..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_1d/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .layers import (Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row, - PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D) - -__all__ = [ - 'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D', - 'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D' -] diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py deleted file mode 100644 index 3943345582758edaa1298ac4d927eb132887e072..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_1d/_operation.py +++ /dev/null @@ -1,96 +0,0 @@ -import torch -import torch.distributed as dist -from colossalai.core import global_context as gpc - -try: - import fused_mix_prec_layer_norm_cuda -except: - fused_mix_prec_layer_norm_cuda = None - - -class FusedLayerNormAffineFunction1D(torch.autograd.Function): - r"""Layernorm - - Args: - input: input matrix. - weight: weight matrix. - bias: bias matrix. - normalized_shape: input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps: a value added to the denominator for numerical stability - """ - - @staticmethod - def forward(ctx, input, weight, bias, normalized_shape, eps): - ctx.normalized_shape = normalized_shape - ctx.eps = eps - input_ = input.contiguous() - weight_ = weight.contiguous() - bias_ = bias.contiguous() - output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, - bias_, ctx.eps) - ctx.save_for_backward(input_, weight_, bias_, mean, invvar) - return output - - @staticmethod - def backward(ctx, grad_output): - input_, weight_, bias_, mean, invvar = ctx.saved_tensors - grad_input = grad_weight = grad_bias = None - grad_input, grad_weight, grad_bias \ - = fused_mix_prec_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) - - return grad_input, grad_weight, grad_bias, None, None - - -class LinearWithAsyncCommunication(torch.autograd.Function): - """ - Linear layer execution with asynchronous communication in backprop. - """ - - @staticmethod - def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce): - ctx.save_for_backward(input_, weight) - ctx.use_bias = bias is not None - ctx.parallel_mode = parallel_mode - ctx.async_grad_allreduce = async_grad_allreduce - - output = torch.matmul(input_, weight.t()) - if bias is not None: - output = output + bias - return output - - @staticmethod - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - use_bias = ctx.use_bias - - total_input = input - grad_input = grad_output.matmul(weight) - - # Convert the tensor shapes to 2D for execution compatibility - grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) - total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) - - if ctx.async_grad_allreduce: - # Asynchronous all-reduce - handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) - # Delay the start of weight gradient computation shortly (3us) to have - # all-reduce scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 - - grad_weight = grad_output.t().matmul(total_input) - grad_bias = grad_output.sum(dim=0) if use_bias else None - - if ctx.async_grad_allreduce: - handle.wait() - - return grad_input, grad_weight, grad_bias, None, None, None - - -def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce): - return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce) diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/nn/layer/parallel_1d/_utils.py deleted file mode 100644 index 1212d595635d7c305132edb8c74e01ab165a903b..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_1d/_utils.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch -import torch.distributed as dist -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env - -from ..utils import divide - - -def set_parallel_input(input_parallel: bool): - env.parallel_input_1d = input_parallel - - -def get_parallel_input(): - return env.parallel_input_1d - - -def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): - index_f = rank * per_partition_vocab_size - index_l = index_f + per_partition_vocab_size - return index_f, index_l - - -def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): - per_partition_vocab_size = divide(global_vocab_size, world_size) - return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) - - -def _reduce(input_, parallel_mode): - # skip if only one rank involved - if gpc.get_world_size(parallel_mode) == 1: - return input_ - group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) - dist.all_reduce(input_, group=group) - - return input_ - - -def _split(input_, parallel_mode, dim=-1): - # skip if only one rank involved - world_size = gpc.get_world_size(parallel_mode) - if world_size == 1: - return input_ - - # Split along last dimension. - dim_size = input_.size(dim) - assert dim_size % world_size == 0, \ - f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ - f'cannot split tensor evenly' - - tensor_list = torch.split(input_, dim_size // world_size, dim=dim) - rank = gpc.get_local_rank(parallel_mode) - output = tensor_list[rank].contiguous() - - return output - - -def _gather(input_, parallel_mode, dim=-1): - # skip if only one rank involved - world_size = gpc.get_world_size(parallel_mode) - if world_size == 1: - return input_ - - # all gather - rank = gpc.get_local_rank(parallel_mode) - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) - torch.distributed.all_gather(tensor_list, input_, group=group) - - # concat - output = torch.cat(tensor_list, dim=dim).contiguous() - - return output - - -class _ReduceGrad(torch.autograd.Function): - """ - Pass the input to the model parallel region. - - Args: - input_: input matrix. - parallel_mode: parallel mode. - """ - - @staticmethod - def symbolic(graph, input_): - return input_ - - @staticmethod - def forward(ctx, input_, parallel_mode): - ctx.mode = parallel_mode - return input_ - - @staticmethod - def backward(ctx, grad_output): - return _reduce(grad_output, ctx.mode), None - - -class _ReduceInput(torch.autograd.Function): - """ - All-reduce the input from the model parallel region. - - Args: - input_: input matrix. - parallel_mode: parallel mode. - """ - - @staticmethod - def symbolic(graph, input_): - return _reduce(input_) - - @staticmethod - def forward(ctx, input_, parallel_mode): - return _reduce(input_, parallel_mode) - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - - -class _SplitForwardGatherBackward(torch.autograd.Function): - """ - Split the input and keep only the corresponding chuck to the rank. - - Args: - input_: input matrix. - parallel_mode: parallel mode. - dim: dimension - """ - - @staticmethod - def symbolic(graph, input_): - return _split(input_) - - @staticmethod - def forward(ctx, input_, parallel_mode, dim): - ctx.mode = parallel_mode - ctx.dim = dim - return _split(input_, parallel_mode, dim) - - @staticmethod - def backward(ctx, grad_output): - return _gather(grad_output, ctx.mode, ctx.dim), None, None - - -class _GatherForwardSplitBackward(torch.autograd.Function): - """Gather the input from model parallel region and concatenate. - - Args: - input_: input matrix. - parallel_mode: parallel mode. - dim: dimension - """ - - @staticmethod - def symbolic(graph, input_): - return _gather(input_) - - @staticmethod - def forward(ctx, input_, parallel_mode, dim): - ctx.mode = parallel_mode - ctx.dim = dim - return _gather(input_, parallel_mode, dim) - - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.mode, ctx.dim), None, None - - -def reduce_grad(input_, parallel_mode): - return _ReduceGrad.apply(input_, parallel_mode) - - -def reduce_input(input_, parallel_mode): - return _ReduceInput.apply(input_, parallel_mode) - - -def split_forward_gather_backward(input_, parallel_mode, dim): - return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) - - -def gather_forward_split_backward(input_, parallel_mode, dim): - return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py deleted file mode 100644 index 406173a18c6010de0b2004f641ec79df3de32dd3..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ /dev/null @@ -1,1040 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import math -from collections import OrderedDict -from typing import Callable, Tuple - -import torch -import torch.nn.functional as F -from torch import Tensor -from torch.nn.parameter import Parameter - -from colossalai.communication import broadcast -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.kernel import LayerNorm -from colossalai.nn import init as init -from colossalai.registry import LAYERS -from colossalai.utils.checkpointing import ( - broadcast_state_dict, - gather_tensor_parallel_state_dict, - partition_tensor_parallel_state_dict, -) -from colossalai.utils.cuda import get_current_device - -from ..base_layer import ParallelLayer -from ..colossalai_layer._utils import ColossalaiModule -from ..utils import divide, set_tensor_parallel_attribute_by_partition -from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding -from ._operation import linear_with_async_comm -from ._utils import ( - gather_forward_split_backward, - get_parallel_input, - reduce_grad, - reduce_input, - set_parallel_input, - split_forward_gather_backward, -) - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -@LAYERS.register_module -class Linear1D(ColossalaiModule): - r"""Linear layer for 1D parallelism. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - gather_output (bool, optional): Whether to call all-gather on output, defaults to False. - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - parallel_input = get_parallel_input() - if not parallel_input and not gather_output: - layer = Linear1D_Col(in_features, - out_features, - bias=bias, - dtype=dtype, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) - else: - layer = Linear1D_Row(in_features, - out_features, - bias=bias, - dtype=dtype, - parallel_input=parallel_input, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) - super().__init__(layer) - - -@LAYERS.register_module -class LayerNorm1D(ColossalaiModule): - r""" - Layer Normalization for colossalai - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - _fast_ln_supported_sizes = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, - 24576, 25600, 30720, 32768, 40960, 49152, 65536 - ] - - def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): - if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: - norm = Fast_LN(normalized_shape, eps=eps).to(dtype) - else: - norm = None - try: - from apex.normalization import FusedLayerNorm - norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) - except ImportError: - norm = LayerNorm(normalized_shape, eps=eps).to(dtype) - super().__init__(norm) - - def _load_from_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) - super()._load_from_state_dict(local_state, prefix, *args) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - super()._save_to_state_dict(destination, prefix, keep_vars) - - -@LAYERS.register_module -class Classifier1D(ParallelLayer): - r"""RowLinear with given weight. Classifier of 1D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - self.parallel_input = get_parallel_input() - - # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs)) - self.has_weight = True - if bias: - self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - env.vocab_parallel = False - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.num_classes - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) - - def _set_tensor_parallel_attributes(self): - if self.has_weight: - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. - if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) - input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) - - output_parallel = F.linear(input_, self.weight) - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - if self.bias is not None: - output = output + self.bias - return output - - -@LAYERS.register_module -class VocabParallelClassifier1D(ParallelLayer): - r"""ColLinear with given weight. Classifier of 1D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - self.gather_output = gather_output - self.parallel_input = get_parallel_input() - - # Divide the weight matrix along the last dimension. - self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size) - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs)) - self.has_weight = True - if bias: - self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - env.vocab_parallel = True - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.num_classes - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - # Matrix multiply. - output_parallel = F.linear(input_parallel, self.weight, self.bias) - if self.gather_output: - # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - else: - output = output_parallel - return output - - -@LAYERS.register_module -class Linear1D_Col(ParallelLayer): - r"""Linear layer with column parallelism. - - The linear layer is defined as :math:`Y = XA + b`. A is parallelized along - its second dimension as :math:`A = [A_1, ..., A_p]`. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - gather_output (bool, optional): If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is :math:`Y_i = XA_i`, defaults to False - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size) - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) - - if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - is_parallel_output = not self.gather_output - set_parallel_input(is_parallel_output) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - input_parallel = input_ - # Matrix multiply. - bias = self.bias if not self.skip_bias_add else None - # output_parallel = F.linear(input_parallel, self.weight, bias) - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True) - if self.gather_output: - # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - else: - output = output_parallel - - if self.skip_bias_add: - return output, self.bias - else: - return output - - -@LAYERS.register_module -class Linear1D_Row(ParallelLayer): - r""" Linear layer with row parallelism - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False. - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): - super().__init__() - - self.stream_chunk_num = stream_chunk_num - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.parallel_input = parallel_input - self.skip_bias_add = skip_bias_add - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) - - if self.stream_chunk_num > 1: - # TODO() work for inference only - self.chunk_weight() - if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - - def chunk_weight(self): - self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. - if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) - input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) - - if self.stream_chunk_num > 1: - if self.training: - raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") - with torch.no_grad(): - output_parallel_list = [None for i in range(self.stream_chunk_num)] - handle_list = [] - for i in range(self.stream_chunk_num): - output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=gpc.get_group(ParallelMode.PARALLEL_1D), - async_op=True) - handle_list.append(handle) - # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) - for handle in handle_list: - handle.wait() - output = torch.cat(output_parallel_list, dim=-1) - else: - output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - if not self.skip_bias_add: - if self.bias is not None: - output = output + self.bias - return output - else: - return output, self.bias - - -@LAYERS.register_module -class Embedding1D(ParallelLayer): - r"""Embedding for 1D parallelism. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about ``initializer`` please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size) - - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - - self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) - - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - - output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - - return output - - -@LAYERS.register_module -class VocabParallelEmbedding1D(ParallelLayer): - r"""Embedding parallelized in the vocabulary dimension. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about initializer please refer to - `init `_. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - - tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) - self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition - self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - - self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)) - - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - env.vocab_parallel = True - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: - with torch.no_grad(): - self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) - - # Mask the output embedding. - output_parallel[input_mask, :] = 0. - # Reduce across all the model parallel GPUs. - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - return output - - -@LAYERS.register_module -class Dropout1D(ParallelLayer): - """Dropout layer of 1D parallelism. - - Args: - p (float, optional): probability of an element to be zeroed, defaults 0.5. - inplace (bool, optional): whether to do dropout in-place, default to be False. - """ - - def __init__(self, p: float = 0.5, inplace: bool = False): - super().__init__() - self.parallel_input = get_parallel_input() - self.p = p - self.inplace = inplace - - def forward(self, input_: Tensor) -> Tensor: - if self.parallel_input: - with seed(ParallelMode.TENSOR): - output = F.dropout(input_, self.p, self.training, self.inplace) - else: - output = F.dropout(input_, self.p, self.training, self.inplace) - return output - - -@LAYERS.register_module -class PatchEmbedding1D(ColossalaiModule): - """ - 2D Image to Patch Embedding - - :param img_size: image size - :type img_size: int - :param patch_size: patch size - :type patch_size: int - :param in_chans: number of channels of input image - :type in_chans: int - :param embed_size: size of embedding - :type embed_size: int - :param dtype: The dtype of parameters, defaults to None - :type dtype: torch.dtype, optional - :param flatten: whether to flatten output tensor, defaults to True - :type flatten: bool, optional - :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer - :type weight_initializer: typing.Callable, optional - :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer - :type bias_initializer: typing.Callable, optional - :param position_embed_initializer: The initializer of position embedding, defaults to zero - :type position_embed_initializer: typing.Callable, optional - """ - - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - dtype: torch.dtype = None, - flatten: bool = True, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): - embed = VanillaPatchEmbedding(img_size, - patch_size, - in_chans, - embed_size, - dtype=dtype, - flatten=flatten, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - position_embed_initializer=position_embed_initializer) - super().__init__(embed) - - def _load_from_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed'] - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - for key in param_keys: - param = state_dict.pop(key, None) - if param is not None: - local_state[key] = param - - local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) - super()._load_from_state_dict(local_state, prefix, *args) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/nn/layer/parallel_2d/__init__.py b/colossalai/nn/layer/parallel_2d/__init__.py deleted file mode 100644 index 5562d1a700361c23bd8238848b5141c05c9b25aa..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_2d/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from ._operation import reduce_by_batch_2d, split_batch_2d -from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D, - VocabParallelEmbedding2D) - -__all__ = [ - 'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', - 'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D' -] diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py deleted file mode 100644 index 306577dbd9333987bb181d70ef21fcff1d548b7c..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_2d/_operation.py +++ /dev/null @@ -1,849 +0,0 @@ -from typing import Any, Optional, Tuple - -import torch -import torch.distributed as dist -from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter) -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device -from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd -from colossalai.global_variables import tensor_parallel_env as env - - -def matmul_2d( - a, - b, - summa_dim, - out_shape, - row_rank=None, - col_rank=None, - row_parallel_mode=ParallelMode.PARALLEL_2D_ROW, - col_parallel_mode=ParallelMode.PARALLEL_2D_COL, -): - r"""Matrix multiplication for 2D parallelism. - - Args: - a (:class:`torch.tensor`): matrix :math:`A`. - b (:class:`torch.tensor`): matrix :math:`B`. - summa_dim (int): dimension of SUMMA fo 2D parallelism. - out_shape (:class:`torch.size`): shape of output tensor. - row_rank (int, optional): the rank of row, defaults to None. - col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`, optional): - row parallel mode, defaults to ParallelMode.PARALLEL_2D_ROW. - col_parallel_mode (:class:`colossalai.context.ParallelMode`, optional): - column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL. - - Returns: - :class:`torch.tensor`: :math:`C = AB`. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - if row_rank is None: - row_rank = gpc.get_local_rank(col_parallel_mode) - if col_rank is None: - col_rank = gpc.get_local_rank(row_parallel_mode) - - data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) - tensor_parallel_size = summa_dim**2 - return Matmul_AB_2D(a, b, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, col_parallel_mode, - data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - -class _Classifier2D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward( - ctx: Any, - A: Tensor, - B: Tensor, - bias: Optional[Tensor], - summa_dim: int, - out_shape: Tuple[int, ...], - row_rank: int, - col_rank: int, - row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode, - data_parallel_rank: int, - pipeline_parallel_rank: int, - pipeline_parallel_size: int, - tensor_parallel_size: int, - ) -> Tensor: - A = A.clone().detach() - A_shape = A.shape - A = A.reshape((-1, A_shape[-1])) - B_shape = B.shape - B = B.reshape((-1, B_shape[-1])) - B_temp = all_gather(B, -1, col_parallel_mode) - if ctx: - ctx.save_for_backward(A, B_temp) - - C = torch.matmul(A, B_temp.transpose(0, 1)) - - C = all_reduce(C, row_parallel_mode) - - ctx.use_bias = bias is not None - if bias is not None: - C = C + bias - - out = C.reshape(out_shape) - - if ctx: - ctx.summa_dim = summa_dim - ctx.row_rank = row_rank - ctx.col_rank = col_rank - ctx.row_parallel_mode = row_parallel_mode - ctx.col_parallel_mode = col_parallel_mode - ctx.A_shape = A_shape - ctx.B_shape = B_shape - ctx.data_parallel_rank = data_parallel_rank - ctx.pipeline_parallel_rank = pipeline_parallel_rank - ctx.pipeline_parallel_size = pipeline_parallel_size - ctx.tensor_parallel_size = tensor_parallel_size - - return out - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - A, B = ctx.saved_tensors - - with torch.no_grad(): - A_grad = torch.matmul(output_grad, B) - A_grad = A_grad.reshape(ctx.A_shape) - B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A) - B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) - B_grad = B_grad.reshape(ctx.B_shape) - if ctx.use_bias: - bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) - bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) - else: - bias_grad = None - - return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None - - -def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int, out_shape: Tuple[int, ...], - row_rank: int, col_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: - r"""2D parallel classifier. - - Args: - A (:class:`torch.tensor`): matrix :math:`A`. - B (:class:`torch.tensor`): matrix :math:`B`. - bias (:class:`torch.tensor`, optional): matrix of bias. - summa_dim (int): dimension of SUMMA fo 2D parallelism. - out_shape (:class:`torch.size`): shape of output tensor. - row_rank (int, optional): the rank of row, defaults to None. - col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - data_parallel_rank (int): data parallel rank. - pipeline_parallel_rank (int): pipeline parallel rank - pipeline_parallel_size (int): pipeline parallel size. - tensor_parallel_size (int): tensor parallel size. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _Classifier2D.apply(A, B, bias, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, - col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, - tensor_parallel_size) - - -class Matmul_AB_2D(torch.autograd.Function): - r"""Matrix multiplication for :math:`C = AB`. - - Args: - A (:class:`torch.tensor`): matrix :math:`A`. - B (:class:`torch.tensor`): matrix :math:`B`. - summa_dim (int): dimension of SUMMA fo 2D parallelism. - out_shape (:class:`torch.size`): shape of output tensor. - row_rank (int, optional): the rank of row, defaults to None. - col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - data_parallel_rank (int): data parallel rank. - pipeline_parallel_rank (int): pipeline parallel rank - pipeline_parallel_size (int): pipeline parallel size. - tensor_parallel_size (int): tensor parallel size. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward( - ctx: Any, - A: Tensor, - B: Tensor, - summa_dim: int, - out_shape: Tuple[int, ...], - row_rank: int, - col_rank: int, - row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode, - data_parallel_rank: int, - pipeline_parallel_rank: int, - pipeline_parallel_size: int, - tensor_parallel_size: int, - ) -> Tensor: - # A: [b / q, s, h / q] -> [(b * s) / q, h / q] - # B: [h / q, s / q] - # C: [b / q, s, s / q] -> [(b * s) / q, s / q] - - assert A.shape[-1] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) - - if ctx: - ctx.save_for_backward(A, B) - - A_shape = A.shape - A = A.reshape((-1, A_shape[-1])) - B_shape = B.shape - B = B.reshape((-1, B_shape[-1])) - C_shape = (A.shape[0], B.shape[-1]) - C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) - - # use circular buffer to store the communication tensor - # 2 is enough for all cases - A_list = [torch.empty_like(A) for _ in range(2)] - B_list = [torch.empty_like(B) for _ in range(2)] - - row_group = gpc.get_group(row_parallel_mode) - col_group = gpc.get_group(col_parallel_mode) - - src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - - opa = [None] * 2 - opb = [None] * 2 - - A_list[0].copy_(A) - B_list[0].copy_(B) - opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) - opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) - cur = 0 - - for i in range(summa_dim): - if i != summa_dim - 1: - A_list[1 - cur].copy_(A) - opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) - B_list[1 - cur].copy_(B) - opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True) - - if opa[cur] is not None: - opa[cur].wait() - if opb[cur] is not None: - opb[cur].wait() - - torch.addmm(C, A_list[cur], B_list[cur], out=C) - cur = 1 - cur - src_a += 1 - src_b += summa_dim - - out = C.reshape(out_shape) - - if ctx: - ctx.summa_dim = summa_dim - ctx.row_rank = row_rank - ctx.col_rank = col_rank - ctx.row_parallel_mode = row_parallel_mode - ctx.col_parallel_mode = col_parallel_mode - ctx.A_shape = A_shape - ctx.B_shape = B_shape - ctx.data_parallel_rank = data_parallel_rank - ctx.pipeline_parallel_rank = pipeline_parallel_rank - ctx.pipeline_parallel_size = pipeline_parallel_size - ctx.tensor_parallel_size = tensor_parallel_size - return out - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - A, B = ctx.saved_tensors - with torch.no_grad(): - A_grad = Matmul_ABT_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None - - -class Matmul_ABT_2D(torch.autograd.Function): - r"""Matrix multiplication for :math:`C = AB^T` - - Args: - A (:class:`torch.tensor`): matrix :math:`A`. - B (:class:`torch.tensor`): matrix :math:`B`. - summa_dim (int): dimension of SUMMA fo 2D parallelism. - out_shape (:class:`torch.size`): shape of output tensor. - row_rank (int, optional): the rank of row, defaults to None. - col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL. - data_parallel_rank (int): data parallel rank. - pipeline_parallel_rank (int): pipeline parallel rank - pipeline_parallel_size (int): pipeline parallel size. - tensor_parallel_size (int): tensor parallel size. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_. - """ - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward( - ctx: Any, - A: Tensor, - B: Tensor, - summa_dim: int, - out_shape: Tuple[int, ...], - row_rank: int, - col_rank: int, - row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode, - data_parallel_rank: int, - pipeline_parallel_rank: int, - pipeline_parallel_size: int, - tensor_parallel_size: int, - ) -> Tensor: - - assert A.shape[-1] == B.shape[-1], \ - 'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) - - if ctx: - ctx.save_for_backward(A, B) - - A_shape = A.shape - A = A.reshape((-1, A_shape[-1])) - B_shape = B.shape - B = B.reshape((-1, B_shape[-1])) - C_shape = (A.shape[0], B.shape[0]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) - - # use circular buffer to store the communication tensor - # 2 is enough for all cases - B_list = [torch.empty_like(B) for _ in range(2)] - C_list = [torch.empty_like(C) for _ in range(2)] - - row_group = gpc.get_group(row_parallel_mode) - col_group = gpc.get_group(col_parallel_mode) - - src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - - opb = [None] * 2 - opr = [None] * 2 - - B_list[0].copy_(B) - opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) - cur = 0 - - for i in range(summa_dim): - if i != summa_dim - 1: - B_list[1 - cur].copy_(B) - opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True) - - if opr[cur] is not None: - opr[cur].wait() - if i - 2 == col_rank: - C.copy_(C_list[cur]) - - if opb[cur] is not None: - opb[cur].wait() - - torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur]) - opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True) - cur = 1 - cur - src_b += summa_dim - src_c += 1 - - for op in opr: - op.wait() - - if summa_dim - 2 == col_rank: - C.copy_(C_list[cur]) - if summa_dim - 1 == col_rank: - C.copy_(C_list[1 - cur]) - out = C.reshape(out_shape) - - if ctx: - ctx.summa_dim = summa_dim - ctx.row_rank = row_rank - ctx.col_rank = col_rank - ctx.row_parallel_mode = row_parallel_mode - ctx.col_parallel_mode = col_parallel_mode - ctx.A_shape = A_shape - ctx.B_shape = B_shape - ctx.data_parallel_rank = data_parallel_rank - ctx.pipeline_parallel_rank = pipeline_parallel_rank - ctx.pipeline_parallel_size = pipeline_parallel_size - ctx.tensor_parallel_size = tensor_parallel_size - - return out - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - A, B = ctx.saved_tensors - - with torch.no_grad(): - A_grad = Matmul_AB_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2D.apply(output_grad, A, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None - - -class Matmul_ATB_2D(torch.autograd.Function): - r"""Matrix multiplication for :math:`C = A^TB`. - - Args: - A (:class:`torch.tensor`): matrix :math:`A`. - B (:class:`torch.tensor`): matrix :math:`B`. - summa_dim (int): dimension of SUMMA fo 2D parallelism. - out_shape (:class:`torch.size`): shape of output tensor. - row_rank (int, optional): the rank of row, defaults to None. - col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - data_parallel_rank (int): data parallel rank. - pipeline_parallel_rank (int): pipeline parallel rank - pipeline_parallel_size (int): pipeline parallel size. - tensor_parallel_size (int): tensor parallel size. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_. - """ - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward( - ctx: Any, - A: Tensor, - B: Tensor, - summa_dim: int, - out_shape: Tuple[int, ...], - row_rank: int, - col_rank: int, - row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode, - data_parallel_rank: int, - pipeline_parallel_rank: int, - pipeline_parallel_size: int, - tensor_parallel_size: int, - ) -> Tensor: - - assert A.shape[-2] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) - - if ctx: - ctx.save_for_backward(A, B) - - A_shape = A.shape - A = A.reshape((-1, A_shape[-1])) - B_shape = B.shape - B = B.reshape((-1, B_shape[-1])) - C_shape = (A.shape[-1], B.shape[-1]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) - - # use circular buffer to store the communication tensor - # 2 is enough for all cases - A_list = [torch.empty_like(A) for _ in range(2)] - C_list = [torch.empty_like(C) for _ in range(2)] - - row_group = gpc.get_group(row_parallel_mode) - col_group = gpc.get_group(col_parallel_mode) - - src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - - opa = [None] * 2 - opr = [None] * 2 - - A_list[0].copy_(A) - opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) - cur = 0 - - for i in range(summa_dim): - if i != summa_dim - 1: - A_list[1 - cur].copy_(A) - opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) - - if opr[cur] is not None: - opr[cur].wait() - if i - 2 == row_rank: - C.copy_(C_list[cur]) - - if opa[cur] is not None: - opa[cur].wait() - - torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur]) - opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True) - cur = 1 - cur - src_a += 1 - src_c += summa_dim - - for op in opr: - op.wait() - - if summa_dim - 2 == row_rank: - C.copy_(C_list[cur]) - if summa_dim - 1 == row_rank: - C.copy_(C_list[1 - cur]) - out = C.reshape(out_shape) - - if ctx: - ctx.summa_dim = summa_dim - ctx.row_rank = row_rank - ctx.col_rank = col_rank - ctx.row_parallel_mode = row_parallel_mode - ctx.col_parallel_mode = col_parallel_mode - ctx.A_shape = A_shape - ctx.B_shape = B_shape - ctx.data_parallel_rank = data_parallel_rank - ctx.pipeline_parallel_rank = pipeline_parallel_rank - ctx.pipeline_parallel_size = pipeline_parallel_size - ctx.tensor_parallel_size = tensor_parallel_size - - return out - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - A, B = ctx.saved_tensors - - with torch.no_grad(): - A_grad = Matmul_ABT_2D.apply(B, output_grad, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - B_grad = Matmul_AB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None - - -class _Add_Bias_2D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward( - ctx: Any, - input_: Tensor, - bias: Tensor, - output_size_per_partition: int, - row_rank: int, - col_rank: int, - row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode, - skip_bias_add: bool, - data_parallel_rank: int, - pipeline_parallel_rank: int, - pipeline_parallel_size: int, - tensor_parallel_size: int, - ) -> Tensor: - bias_temp = all_gather(bias, -1, col_parallel_mode) - - ctx.row_rank = row_rank - ctx.col_rank = col_rank - ctx.row_parallel_mode = row_parallel_mode - ctx.col_parallel_mode = col_parallel_mode - ctx.bias = skip_bias_add - ctx.data_parallel_rank = data_parallel_rank - ctx.pipeline_parallel_rank = pipeline_parallel_rank - ctx.pipeline_parallel_size = pipeline_parallel_size - ctx.tensor_parallel_size = tensor_parallel_size - - if skip_bias_add: - return bias_temp - else: - output = input_ + bias_temp - return output - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - col_parallel_mode = ctx.col_parallel_mode - - if ctx.bias: - grad = reduce_scatter(output_grad, -1, col_parallel_mode) - return None, grad, None, None, None, None, None, None, None, None, None, None - else: - reduce_dim = tuple(range(output_grad.ndim - 1)) - reduce = torch.sum(output_grad, dim=reduce_dim) - grad = reduce_scatter(reduce, -1, col_parallel_mode) - return output_grad, grad, None, None, None, None, None, None, None, None, None, None - - -def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, row_rank: int, col_rank: int, - row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, skip_bias_add: bool, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: - r"""Matrix add bias: :math:`C = A + b`. - - Args: - input_ (:class:`torch.tensor`): matrix :math:`A`. - bias (:class:`torch.tensor`): matrix :math:`B`. - output_size_per_partition (int): size of output per partition. - row_rank (int, optional): the rank of row, defaults to None. - col_rank (int, optional): the rank of column, defaults to None. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - skip_bias_add (bool): - If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion. - data_parallel_rank (int): data parallel rank. - pipeline_parallel_rank (int): pipeline parallel rank - pipeline_parallel_size (int): pipeline parallel size. - tensor_parallel_size (int): tensor parallel size. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _Add_Bias_2D.apply(input_, bias, output_size_per_partition, row_rank, col_rank, row_parallel_mode, - col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, - pipeline_parallel_size, tensor_parallel_size) - - -class _Layernorm_2D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode) -> Tensor: - input_ = input_ - E_x - # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) - ctx.normalized_shape = hidden_size - output = input_ * Var_x - ctx.save_for_backward(output, Var_x) - ctx.row_parallel_mode = row_parallel_mode - ctx.col_parallel_mode = col_parallel_mode - return output - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - row_parallel_mode = ctx.row_parallel_mode - col_parallel_mode = ctx.col_parallel_mode - x, Var_x = ctx.saved_tensors - # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x - output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) - torch.distributed.all_reduce(output_grad_sum, group=gpc.get_group(row_parallel_mode)) - output_grad_sum /= ctx.normalized_shape - - output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True) - torch.distributed.all_reduce(output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode)) - output_grad_mul_x_sum /= ctx.normalized_shape - - input_grad = output_grad.clone() - input_grad -= x * output_grad_mul_x_sum - input_grad -= output_grad_sum - input_grad *= Var_x - - return input_grad, None, None, None, None, None - - -def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode) -> Tensor: - r"""Layernorm. - - Args: - input_ (:class:`torch.tensor`): input matrix. - E_x (:class:`torch.tensor`): mean. - Var_x (:class:`torch.tensor`): variance. - hidden_size (int): hidden size. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _Layernorm_2D.apply(input_, E_x, Var_x, hidden_size, row_parallel_mode, col_parallel_mode) - - -class _AllGatherTensor2D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: - ctx.dim = dim - ctx.parallel_mode = parallel_mode - - outputs = all_gather(inputs, dim, parallel_mode) - return outputs - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode) - return grad.contiguous(), None, None - - -def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: - r"""All gather the tensor of 2D parallelism. - - Args: - tensor (:class:`torch.tensor`): Input tensor. - dim (int): Dimension to gather. - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _AllGatherTensor2D.apply(tensor, dim, parallel_mode) - - -def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor: - """Splits 2D tensor in specified dimension across cols. - - Args: - input_ (:class:`torch.tensor`): Input tensor. - dim (int): Specified dimension in which to split. - - Returns: - :class:`torch.tensor`: The tensor has been split. - """ - dim_size = input_.size(dim) - world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) - - if world_size <= 1: - return input_ - - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' - - return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), - dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous() - - -class _ReduceTensor2D(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_, parallel_mode): - return all_reduce(input_, parallel_mode) - - @staticmethod - def backward(ctx, output_grad): - return output_grad, None - - -def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: - r"""All-reduce the input. - - Args: - input_ (:class:`torch.tensor`): Input tensor. - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _ReduceTensor2D.apply(input_, parallel_mode) - - -class _ReduceScatterTensor2D(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_, dim, parallel_mode): - ctx.dim = dim - ctx.parallel_mode = parallel_mode - return reduce_scatter(input_, dim, parallel_mode) - - @staticmethod - def backward(ctx, output_grad): - return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None - - -def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: - r"""Reduce-scatter the input. - - Args: - tensor (:class:`torch.tensor`): Input tensor. - dim (int): Dimension to reduce. - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - dim_size = tensor.size(dim) - world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' - - return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode) - - -class _ReduceByBatch2D(torch.autograd.Function): - - @staticmethod - def symbolic(graph, input_, reduce_mean: bool = False): - output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) - if reduce_mean: - reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) - return output / reduce_size - return output - - @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, input_, reduce_mean: bool = False): - output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) - ctx.reduce_mean = reduce_mean - if reduce_mean: - reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) - ctx.reduce_size = reduce_size - return output.clone() / reduce_size - return output.clone() - - @staticmethod - @custom_bwd - def backward(ctx, output_grad): - if ctx.reduce_mean: - return output_grad / ctx.reduce_size, None - else: - return output_grad, None - - -def reduce_by_batch_2d(input_, reduce_mean: bool = False) -> Tensor: - r"""All-reduce the input from the model parallel region. - - Args: - input_ (:class:`torch.tensor`): input matrix. - reduce_mean (bool, optional): - If set to ``True``, it will divide the output by column parallel size, default to False. - """ - return _ReduceByBatch2D.apply(input_, reduce_mean) diff --git a/colossalai/nn/layer/parallel_2d/_utils.py b/colossalai/nn/layer/parallel_2d/_utils.py deleted file mode 100644 index 012fec41c80231165ceb92e57e2f449e61fdb8b2..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_2d/_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env - - -def get_summa_dim_from_env() -> int: - try: - summa_dim = env.summa_dim - assert summa_dim > 0, 'SUMMA_DIM must be larger than zero' - return summa_dim - - except KeyError as e: - raise EnvironmentError('SUMMA_DIM is not found in the current environment, ' - 'please make sure that you have used the correct process group initializer') - - -def assert_summa_initialization(): - assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW), \ - 'Both TWO_DIMENSION_COL and TWO_DIMENSION_ROW must be initialized by the process group initializer' diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py deleted file mode 100644 index f3a4d2bbbc32f8f13815ea36a7c47268217446d5..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_2d/layers.py +++ /dev/null @@ -1,1201 +0,0 @@ -import math -from collections import OrderedDict -from typing import Callable - -import torch -import torch.nn as nn -import torch.nn.functional as F -from colossalai.communication import broadcast -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.nn import init as init -from colossalai.registry import LAYERS -from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict -from colossalai.utils.cuda import get_current_device -from torch import Tensor -from torch.nn import Parameter - -from ..base_layer import ParallelLayer -from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple -from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d, - reduce_scatter_tensor_2d, split_batch_2d) -from ._utils import assert_summa_initialization, get_summa_dim_from_env - - -@LAYERS.register_module -class Linear2D(ParallelLayer): - r"""Linear layer for 2D parallelism - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.skip_bias_add = skip_bias_add - - # parallel settings - assert_summa_initialization() - self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) - self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) - self.summa_dim = get_summa_dim_from_env() - - # partitioning dimension - self.input_size_per_partition = divide(self.in_features, self.summa_dim) - self.hidden_size_per_partition = divide(self.out_features, self.summa_dim) - - # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.weight = Parameter( - torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)) - - # create bias, shape: [h/q] - if bias: - self.bias = Parameter(torch.empty(divide(self.out_features, self.summa_dim**2), **factory_kwargs)) - else: - self.register_parameter('bias', None) - - # initialize parameters - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight.transpose(0, 1) - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - local_state[weight_key] = local_state[weight_key].transpose(0, 1) - destination.update(local_state) - - def forward(self, x: Tensor) -> Tensor: - # input: [m/q, n/q, k/q] - # output: [m/q, n/q, h/q] - out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) - - output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, - self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) - - if self.bias is not None: - if self.skip_bias_add: - bias = add_bias_2d(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - return output, bias - else: - output = add_bias_2d(output, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - return output - else: - return output - - -@LAYERS.register_module -class LayerNorm2D(ParallelLayer): - r"""Layer Normalization for 2D parallelism. - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None): - super().__init__() - - # layer norm config - self.normalized_shape = normalized_shape - self.variance_epsilon = eps - - # parallel setting - assert_summa_initialization() - self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) - self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) - self.summa_dim = get_summa_dim_from_env() - - # partitioning dimension - self.partitioned_partition = divide(normalized_shape, self.summa_dim**2) - - # create parameters - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - - self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) - if bias: - self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs)) - else: - self.bias = None - - self._set_tensor_parallel_attributes() - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, x: Tensor) -> Tensor: - with torch.no_grad(): - E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] - torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) - E_x /= self.normalized_shape - - # Var_x in the block below is the sum of input^2 - Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] - torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) - Var_x /= self.normalized_shape - - Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] - # this time 1/sqrt(Var_x + epsilon) - Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) - - output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, - ParallelMode.PARALLEL_2D_COL) - scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank, - self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) - if self.bias is not None: - bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - output = torch.addcmul(bias, scale, output) - else: - output = torch.mul(scale, output) - return output - - -@LAYERS.register_module -class PatchEmbedding2D(ParallelLayer): - r"""2D Image to Patch Embedding. - - Args: - img_size (int): image size. - patch_size (int): patch size. - in_chans (int): number of channels of input image. - embed_size (int): size of embedding. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - flatten (bool, optional): whether to flatten output tensor, defaults to True. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - position_embed_initializer (:class:`typing.Callable`, optional): - The initializer of position embedding, defaults to zeros initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - - assert_summa_initialization() - self.summa_dim = get_summa_dim_from_env() - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - self.embed_size = embed_size - self.embed_size_per_partition = embed_size // (self.summa_dim**2) - - with seed(ParallelMode.TENSOR): - self.weight = Parameter( - torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), - dtype=dtype)) - self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) - - self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) - self.pos_embed = Parameter( - torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition), - device=get_current_device(), - dtype=dtype)) - - self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) - self._set_tensor_parallel_attribute() - - def _set_tensor_parallel_attribute(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) - set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) - set_tensor_parallel_attribute_by_partition(self.cls_token, self.summa_dim**2) - set_tensor_parallel_attribute_by_partition(self.pos_embed, self.summa_dim**2) - - def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): - with seed(ParallelMode.TENSOR): - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - fan_out = self.embed_size - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - bias_initializer(self.bias, fan_in=fan_in) - position_embed_initializer(self.pos_embed) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - # cls token - cls_token = state_dict.pop(cls_token_key, None) - if cls_token is not None: - local_state[cls_token_key] = cls_token - # pos embed - pos_embed = state_dict.pop(pos_embed_key, None) - if pos_embed is not None: - local_state[pos_embed_key] = pos_embed - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - local_state = OrderedDict({ - weight_key: self.weight, - bias_key: self.bias, - cls_token_key: self.cls_token, - pos_embed_key: self.pos_embed - }) - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_2d(input_) - - B, C, H, W = input_.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - - weight = all_gather_tensor_2d(self.weight, 0, ParallelMode.PARALLEL_2D_COL) - bias = all_gather_tensor_2d(self.bias, 0, ParallelMode.PARALLEL_2D_COL) - - output = F.conv2d(input_, weight, bias, stride=self.patch_size) - if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC - - cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL) - pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL) - cls_token = cls_token.expand(output.shape[0], -1, -1) - output = torch.cat((cls_token, output), dim=1) - output = output + pos_embed - - return output - - -@LAYERS.register_module -class Embedding2D(ParallelLayer): - r"""Embedding for 2D parallelism. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about initializer please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - - assert_summa_initialization() - self.summa_dim = get_summa_dim_from_env() - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - embed_dim_per_partition = divide(embedding_dim, self.summa_dim**2) - - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - - self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) - - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={weight_key: -1}, - partition_states={weight_key: True}, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={weight_key: -1}, - partition_states={weight_key: True}, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_2d(input_) - - weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL) - output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - return output - - -@LAYERS.register_module -class VocabParallelEmbedding2D(ParallelLayer): - r"""Embedding parallelized in the vocabulary dimension. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about initializer please refer to - `init `_. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - - assert_summa_initialization() - self.summa_dim = get_summa_dim_from_env() - self.num_embeddings_per_partition = divide(self.num_embeddings, self.summa_dim) - self.embed_dim_per_partition = divide(self.embed_dim, self.summa_dim) - tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) - self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition - self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - - self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), - dtype=dtype)) - - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - env.vocab_parallel = True - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: - with torch.no_grad(): - self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={weight_key: -1}, - partition_states={weight_key: True}, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={weight_key: 0}, - partition_states={weight_key: True}, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) - - output_parallel[input_mask, :] = 0. - output = reduce_scatter_tensor_2d(output_parallel, 0, ParallelMode.PARALLEL_2D_COL) - return output - - -@LAYERS.register_module -class Classifier2D(ParallelLayer): - r"""Classifier for 2D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - assert_summa_initialization() - self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) - self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) - self.summa_dim = get_summa_dim_from_env() - - # partitioning dimension - self.input_size_per_partition = divide(self.in_features, self.summa_dim**2) - - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)) - self.has_weight = True - if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) - else: - self.bias = None - - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - - def _set_tensor_parallel_attributes(self): - if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.in_features, self.num_classes - col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)[0] - row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_ROW)[0] - - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL) - broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - out_shape = input_.shape[:-1] + (self.num_classes,) - - return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, - self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) - - -@LAYERS.register_module -class VocabParallelClassifier2D(ParallelLayer): - r"""Vocab parallel classifier layer for 2D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - self.in_features = in_features - self.num_classes = num_classes - - # parallel setting - assert_summa_initialization() - self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) - self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) - self.summa_dim = get_summa_dim_from_env() - - # partitioning dimension - self.input_size_per_partition = divide(in_features, self.summa_dim) - self.output_size_per_partition = divide(num_classes, self.summa_dim) - - # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter( - torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs)) - self.has_weight = True - # create bias, shape: [h/q] - if bias: - self.bias = Parameter(torch.empty(divide(self.num_classes, self.summa_dim**2), **factory_kwargs)) - else: - self.bias = None - - # initialize parameters - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - env.vocab_parallel = True - - def _set_tensor_parallel_attributes(self): - if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.num_classes - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - local_state[weight_key] = local_state[weight_key].transpose(0, 1) - destination.update(local_state) - - def forward(self, x: Tensor) -> Tensor: - # input: [m/q, n/q, k/q] - # output: [m/q, n/q, h/q] - out_shape = x.shape[:-1] + (self.output_size_per_partition,) - - output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - - if self.bias is not None: - output = add_bias_2d(output, self.bias, self.output_size_per_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - return output diff --git a/colossalai/nn/layer/parallel_2p5d/__init__.py b/colossalai/nn/layer/parallel_2p5d/__init__.py deleted file mode 100644 index bec3b1c4b0b87e8db497b627207ca6f30b1fff49..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_2p5d/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from ._operation import reduce_by_batch_2p5d, split_batch_2p5d -from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D, - VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D) - -__all__ = [ - 'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', - 'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D' -] diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/nn/layer/parallel_2p5d/_operation.py deleted file mode 100644 index 5a0f537cd6d9ca0d6104390a8cdb4634f25d204c..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_2p5d/_operation.py +++ /dev/null @@ -1,880 +0,0 @@ -from typing import Any, Tuple - -import torch -import torch.distributed as dist -from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter) -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device -from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd - - -def get_parallel_group(parallel_mode: ParallelMode): - return gpc.get_group(parallel_mode) - - -def get_global_rank(): - return gpc.get_global_rank() - - -def get_parallel_rank(parallel_mode: ParallelMode): - return gpc.get_local_rank(parallel_mode) - - -class _Classifier2p5D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward( - ctx: Any, - A: Tensor, - B: Tensor, - bias, - tesseract_dim: int, - out_shape: Tuple[int, ...], - row_rank: int, - col_rank: int, - row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode, - data_parallel_rank: int, - pipeline_parallel_rank: int, - pipeline_parallel_size: int, - tensor_parallel_size: int, - ) -> Tensor: - A = A.clone().detach() - A_shape = A.shape - A = A.reshape((-1, A_shape[-1])) - B_shape = B.shape - B = B.reshape((-1, B_shape[-1])) - B_temp = all_gather(B, -1, col_parallel_mode) - if ctx: - ctx.save_for_backward(A, B_temp) - - C = torch.matmul(A, B_temp.transpose(0, 1)) - - C = all_reduce(C, row_parallel_mode) - - ctx.use_bias = bias is not None - if bias is not None: - C = C + bias - - out = C.reshape(out_shape) - - if ctx: - ctx.tesseract_dim = tesseract_dim - ctx.row_rank = row_rank - ctx.col_rank = col_rank - ctx.row_parallel_mode = row_parallel_mode - ctx.col_parallel_mode = col_parallel_mode - ctx.A_shape = A_shape - ctx.B_shape = B_shape - ctx.data_parallel_rank = data_parallel_rank - ctx.pipeline_parallel_rank = pipeline_parallel_rank - ctx.pipeline_parallel_size = pipeline_parallel_size - ctx.tensor_parallel_size = tensor_parallel_size - - return out - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - A, B = ctx.saved_tensors - - with torch.no_grad(): - A_grad = torch.matmul(output_grad, B) - A_grad = A_grad.reshape(ctx.A_shape) - B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A) - B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) - B_grad = B_grad.reshape(ctx.B_shape) - - if ctx.use_bias: - bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) - bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) - else: - bias_grad = None - - return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None - - -def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: Tuple[int, - ...], row_rank: int, col_rank: int, - row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, data_parallel_rank: int, - pipeline_parallel_rank: int, pipeline_parallel_size: int, tensor_parallel_size: int) -> Tensor: - r"""Classifier. - - Args: - A (:class:`torch.tensor`): matrix :math:`A`. - B (:class:`torch.tensor`): matrix :math:`B`. - bias (:class:`torch.tensor`): matrix of bias. - tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. - out_shape (:class:`torch.size`): shape of output tensor. - row_rank (int): the rank of row. - col_rank (int): the rank of column. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - data_parallel_rank (int): data parallel rank. - pipeline_parallel_rank (int): pipeline parallel rank - pipeline_parallel_size (int): pipeline parallel size. - tensor_parallel_size (int): tensor parallel size. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _Classifier2p5D.apply(A, B, bias, tesseract_dim, out_shape, row_rank, col_rank, row_parallel_mode, - col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, - tensor_parallel_size) - - -class Matmul_AB_2p5D(torch.autograd.Function): - r"""Matrix multiplication for :math:`C = AB`. - - Args: - A (:class:`torch.tensor`): matrix :math:`A`. - B (:class:`torch.tensor`): matrix :math:`B`. - tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. - out_shape (:class:`torch.size`): shape of output tensor. - row_rank (int): the rank of row. - col_rank (int): the rank of column. - dep_rank (int): the rank of depth. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - data_parallel_rank (int): data parallel rank. - pipeline_parallel_rank (int): pipeline parallel rank - pipeline_parallel_size (int): pipeline parallel size. - tensor_parallel_size (int): tensor parallel size. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, - col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: - # A: [b / dq, s, h / q] -> [(b * s) / dq, h / q] - # B: [h / dq, s / q] - # C: [b / dq, s, s / q] -> [(b * s) / dq, s / q] - - assert A.shape[-1] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) - - if ctx: - ctx.save_for_backward(A, B) - - A_shape = A.shape - A = A.reshape((-1, A_shape[-1])) - B_shape = B.shape - B = B.reshape((-1, B_shape[-1])) - C_shape = (A.shape[0], B.shape[-1]) - C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) - - # use circular buffer to store the communication tensor - # 2 is enough for all cases - A_list = [torch.empty_like(A) for _ in range(2)] - B_list = [torch.empty_like(B) for _ in range(2)] - - row_group = gpc.get_group(row_parallel_mode) - col_group = gpc.get_group(col_parallel_mode) - - src_a = \ - tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_b = \ - col_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - - opa = [None] * 2 - opb = [None] * 2 - - A_list[0].copy_(A) - B_list[0].copy_(B) - opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) - opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) - cur = 0 - - for i in range(tesseract_dim): - if i != tesseract_dim - 1: - A_list[1 - cur].copy_(A) - opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) - B_list[1 - cur].copy_(B) - opb[1 - cur] = dist.broadcast(B_list[1 - cur], - src=src_b + tesseract_dim, - group=col_group, - async_op=True) - - if opa[cur] is not None: - opa[cur].wait() - if opb[cur] is not None: - opb[cur].wait() - - torch.addmm(C, A_list[cur], B_list[cur], out=C) - cur = 1 - cur - src_a += 1 - src_b += tesseract_dim - out = C.reshape(out_shape) - - if ctx: - ctx.tesseract_dim = tesseract_dim - ctx.row_rank = row_rank - ctx.col_rank = col_rank - ctx.dep_rank = dep_rank - ctx.row_parallel_mode = row_parallel_mode - ctx.col_parallel_mode = col_parallel_mode - ctx.A_shape = A_shape - ctx.B_shape = B_shape - ctx.data_parallel_rank = data_parallel_rank - ctx.pipeline_parallel_rank = pipeline_parallel_rank - ctx.pipeline_parallel_size = pipeline_parallel_size - ctx.tensor_parallel_size = tensor_parallel_size - - return out - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - A, B = ctx.saved_tensors - with torch.no_grad(): - A_grad = Matmul_ABT_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None - - -class Matmul_ABT_2p5D(torch.autograd.Function): - r"""Matrix multiplication for :math:`C = AB^T`. - - Args: - A (:class:`torch.tensor`): matrix :math:`A`. - B (:class:`torch.tensor`): matrix :math:`B`. - tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. - out_shape (:class:`torch.size`): shape of output tensor. - row_rank (int): the rank of row. - col_rank (int): the rank of column. - dep_rank (int): the rank of depth. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - data_parallel_rank (int): data parallel rank. - pipeline_parallel_rank (int): pipeline parallel rank - pipeline_parallel_size (int): pipeline parallel size. - tensor_parallel_size (int): tensor parallel size. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, - col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: - - assert A.shape[-1] == B.shape[-1], \ - 'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) - - if ctx: - ctx.save_for_backward(A, B) - - A_shape = A.shape - A = A.reshape((-1, A_shape[-1])) - B_shape = B.shape - B = B.reshape((-1, B_shape[-1])) - C_shape = (A.shape[0], B.shape[0]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) - - # use circular buffer to store the communication tensor - # 2 is enough for all cases - B_list = [torch.empty_like(B) for _ in range(2)] - C_list = [torch.empty_like(C) for _ in range(2)] - - row_group = gpc.get_group(row_parallel_mode) - col_group = gpc.get_group(col_parallel_mode) - - src_b = \ - col_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_c = \ - tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - - opb = [None] * 2 - opr = [None] * 2 - - B_list[0].copy_(B) - opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True) - cur = 0 - - for i in range(tesseract_dim): - if i != tesseract_dim - 1: - B_list[1 - cur].copy_(B) - opb[1 - cur] = dist.broadcast(B_list[1 - cur], - src=src_b + tesseract_dim, - group=col_group, - async_op=True) - - if opr[cur] is not None: - opr[cur].wait() - if i - 2 == col_rank: - C.copy_(C_list[cur]) - - if opb[cur] is not None: - opb[cur].wait() - - torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur]) - opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True) - cur = 1 - cur - src_b += tesseract_dim - src_c += 1 - - for op in opr: - op.wait() - - if tesseract_dim - 2 == col_rank: - C.copy_(C_list[cur]) - if tesseract_dim - 1 == col_rank: - C.copy_(C_list[1 - cur]) - out = C.reshape(out_shape) - - if ctx: - ctx.tesseract_dim = tesseract_dim - ctx.row_rank = row_rank - ctx.col_rank = col_rank - ctx.dep_rank = dep_rank - ctx.row_parallel_mode = row_parallel_mode - ctx.col_parallel_mode = col_parallel_mode - ctx.A_shape = A_shape - ctx.B_shape = B_shape - ctx.data_parallel_rank = data_parallel_rank - ctx.pipeline_parallel_rank = pipeline_parallel_rank - ctx.pipeline_parallel_size = pipeline_parallel_size - ctx.tensor_parallel_size = tensor_parallel_size - - return out - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - A, B = ctx.saved_tensors - with torch.no_grad(): - A_grad = Matmul_AB_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2p5D.apply(output_grad, A, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None - - -class Matmul_ATB_2p5D(torch.autograd.Function): - r"""Matrix multiplication for :math:`C = A^TB` - - Args: - A (:class:`torch.tensor`): matrix :math:`A`. - B (:class:`torch.tensor`): matrix :math:`B`. - tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. - out_shape (:class:`torch.size`): shape of output tensor. - row_rank (int): the rank of row. - col_rank (int): the rank of column. - dep_rank (int): the rank of depth. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - data_parallel_rank (int): data parallel rank. - pipeline_parallel_rank (int): pipeline parallel rank - pipeline_parallel_size (int): pipeline parallel size. - tensor_parallel_size (int): tensor parallel size. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, - col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int): - - assert A.shape[-2] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) - - if ctx: - ctx.save_for_backward(A, B) - - A_shape = A.shape - A = A.reshape((-1, A_shape[-1])) - B_shape = B.shape - B = B.reshape((-1, B_shape[-1])) - C_shape = (A.shape[-1], B.shape[-1]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) - - # use circular buffer to store the communication tensor - # 2 is enough for all cases - A_list = [torch.empty_like(A) for _ in range(2)] - C_list = [torch.empty_like(C) for _ in range(2)] - - row_group = gpc.get_group(row_parallel_mode) - col_group = gpc.get_group(col_parallel_mode) - - src_a = \ - tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_c = \ - col_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - - opa = [None] * 2 - opr = [None] * 2 - - A_list[0].copy_(A) - opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True) - cur = 0 - - for i in range(tesseract_dim): - if i != tesseract_dim - 1: - A_list[1 - cur].copy_(A) - opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) - - if opr[cur] is not None: - opr[cur].wait() - if i - 2 == row_rank: - C.copy_(C_list[cur]) - - if opa[cur] is not None: - opa[cur].wait() - - torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur]) - opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True) - cur = 1 - cur - src_a += 1 - src_c += tesseract_dim - - for op in opr: - op.wait() - - if tesseract_dim - 2 == row_rank: - C.copy_(C_list[cur]) - if tesseract_dim - 1 == row_rank: - C.copy_(C_list[1 - cur]) - out = C.reshape(out_shape) - - if ctx: - ctx.tesseract_dim = tesseract_dim - ctx.row_rank = row_rank - ctx.col_rank = col_rank - ctx.dep_rank = dep_rank - ctx.row_parallel_mode = row_parallel_mode - ctx.col_parallel_mode = col_parallel_mode - ctx.A_shape = A_shape - ctx.B_shape = B_shape - ctx.data_parallel_rank = data_parallel_rank - ctx.pipeline_parallel_rank = pipeline_parallel_rank - ctx.pipeline_parallel_size = pipeline_parallel_size - ctx.tensor_parallel_size = tensor_parallel_size - - return out - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - A, B = ctx.saved_tensors - with torch.no_grad(): - A_grad = Matmul_ABT_2p5D.apply(B, output_grad, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - B_grad = Matmul_AB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None - - -class _Add_Bias_2p5D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, - row_rank: int, col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: - if row_rank == 0: - bias_temp = bias.clone() - else: - bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) - src_rank = \ - col_rank + dep_rank * tesseract_dim ** 2 + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode)) - - ctx.row_rank = row_rank - ctx.col_rank = col_rank - ctx.dep_rank = dep_rank - ctx.tesseract_dim = tesseract_dim - ctx.col_parallel_mode = col_parallel_mode - ctx.bias = skip_bias_add - ctx.data_parallel_rank = data_parallel_rank - ctx.pipeline_parallel_rank = pipeline_parallel_rank - ctx.pipeline_parallel_size = pipeline_parallel_size - ctx.tensor_parallel_size = tensor_parallel_size - - if skip_bias_add: - return bias_temp - else: - output = input + bias_temp - return output - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - row_rank = ctx.row_rank - col_rank = ctx.col_rank - dep_rank = ctx.dep_rank - tesseract_dim = ctx.tesseract_dim - col_parallel_mode = ctx.col_parallel_mode - data_parallel_rank = ctx.data_parallel_rank - pipeline_parallel_rank = ctx.pipeline_parallel_rank - pipeline_parallel_size = ctx.pipeline_parallel_size - tensor_parallel_size = ctx.tensor_parallel_size - - if ctx.bias: - dst_rank = \ - col_rank + dep_rank * (tesseract_dim ** 2) + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) - if row_rank == 0: - return \ - None, output_grad, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None - else: - grad_tmp = torch.zeros_like(output_grad) - return \ - None, grad_tmp, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None - else: - reduce_dim = tuple(range(output_grad.ndim - 1)) - reduce = torch.sum(output_grad, dim=reduce_dim) - dst_rank = \ - col_rank + dep_rank * (tesseract_dim ** 2) + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) - if row_rank == 0: - return \ - output_grad, reduce, None, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None - else: - reduce_tmp = torch.zeros_like(reduce) - return \ - output_grad, reduce_tmp, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None, None - - -def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int, - col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: - r"""Matrix add bias: :math:`C = A + b`. - - Args: - input (:class:`torch.tensor`): matrix :math:`A`. - bias (:class:`torch.tensor`): matrix :math:`B`. - tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism. - output_size_per_partition (int): output size in each partition. - row_rank (int): the rank of row. - col_rank (int): the rank of column. - dep_rank (int): the rank of depth. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion. - data_parallel_rank (int): data parallel rank. - pipeline_parallel_rank (int): pipeline parallel rank - pipeline_parallel_size (int): pipeline parallel size. - tensor_parallel_size (int): tensor parallel size. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _Add_Bias_2p5D.apply(input, bias, output_size_per_partition, tesseract_dim, row_rank, col_rank, dep_rank, - col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, - pipeline_parallel_size, tensor_parallel_size) - - -class _Layernorm2p5D(torch.autograd.Function): - r"""Layernorm. - - Args: - input (:class:`torch.tensor`): input matrix. - E_x (:class:`torch.tensor`): mean. - Var_x (:class:`torch.tensor`): variance. - hidden_size (int): hidden size. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - - @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, - row_parallel_mode: ParallelMode) -> Tensor: - input = input - E_x - # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) - ctx.hidden_size = hidden_size - output = input * Var_x - ctx.save_for_backward(output, Var_x) - ctx.row_parallel_mode = row_parallel_mode - return output - - @staticmethod - @custom_bwd - def backward(ctx, output_grad): - row_parallel_mode = ctx.row_parallel_mode - x, Var_x = ctx.saved_tensors - # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x - with torch.no_grad(): - output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) - torch.distributed.all_reduce(output_grad_sum, group=get_parallel_group(row_parallel_mode)) - output_grad_sum /= ctx.hidden_size - - output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True) - torch.distributed.all_reduce(output_grad_mul_x_sum, group=get_parallel_group(row_parallel_mode)) - output_grad_mul_x_sum /= ctx.hidden_size - - input_grad = output_grad.clone() - input_grad -= x * output_grad_mul_x_sum - input_grad -= output_grad_sum - input_grad *= Var_x - - return input_grad, None, None, None, None, None, None - - -def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, - row_parallel_mode: ParallelMode) -> Tensor: - r"""Layernorm. - - Args: - input (:class:`torch.tensor`): input matrix. - E_x (:class:`torch.tensor`): mean. - Var_x (:class:`torch.tensor`): variance. - hidden_size (int): hidden size. - row_parallel_mode (:class:`colossalai.context.ParallelMode`): row parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_. - """ - return _Layernorm2p5D.apply(input, E_x, Var_x, hidden_size, row_parallel_mode) - - -class _AllGatherTensor2p5D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: - ctx.dim = dim - ctx.col_parallel_mode = col_parallel_mode - - outputs = all_gather(inputs, dim, col_parallel_mode) - return outputs - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - grad = reduce_scatter(output_grad, ctx.dim, ctx.col_parallel_mode) - return grad.contiguous(), None, None - - -def all_gather_tensor_2p5d(inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: - r"""all gather the weight of 2.5D parallelism. - - Args: - inputs (:class:`torch.tensor`): input tensor. - dim (int): dimension of all-gather. - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_. - """ - return _AllGatherTensor2p5D.apply(inputs, dim, col_parallel_mode) - - -class SplitFirst(torch.autograd.Function): - r""" - - Args: - inputs (:class:`torch.tensor`): input tensor. - tesseract_dim (int): dimension of TESSERACT fo 2.5D parallelism - col_parallel_mode (:class:`colossalai.context.ParallelMode`): column parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_. - """ - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor: - ctx.tesseract_dim = tesseract_dim - ctx.batch_size = inputs.size(0) - ctx.para_mode = col_parallel_mode - row_rank = gpc.get_local_rank(col_parallel_mode) - - outputs = inputs.chunk(tesseract_dim, dim=0)[row_rank] - return outputs - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - grad_shape = (ctx.batch_size,) + output_grad.shape[1:] - grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device()) - dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)), - output_grad.contiguous(), - group=gpc.get_group(ctx.para_mode)) - return grad, None, None - - -def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: - """Splits 2P5D tensor in specified dimension across cols. - - Args: - input_ (:class:`torch.tensor`): Input tensor. - dim (int): Specified dimension in which to split. - - Returns: - :class:`torch.tensor`: The tensor has been split. - """ - dim_size = input_.size(dim) - world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) - - if world_size <= 1: - return input_ - - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' - - return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), - dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() - - -class _ReduceTensor2p5D(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_, parallel_mode): - return all_reduce(input_, parallel_mode) - - @staticmethod - def backward(ctx, output_grad): - return output_grad, None - - -def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: - r"""All-reduce the input. - - Args: - input_ (:class:`torch.tensor`): Input tensor. - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _ReduceTensor2p5D.apply(input_, parallel_mode) - - -class _ReduceScatterTensor2p5D(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_, dim, parallel_mode): - ctx.dim = dim - ctx.parallel_mode = parallel_mode - return reduce_scatter(input_, dim, parallel_mode) - - @staticmethod - def backward(ctx, output_grad): - return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None - - -def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: - r"""Reduce-scatter the input. - - Args: - input_ (:class:`torch.tensor`): Input tensor. - dim (int): Dimension to reduce. - parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - dim_size = input_.size(dim) - world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' - - return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode) - - -class _RreduceByBatch2p5D(torch.autograd.Function): - - @staticmethod - def symbolic(graph, input_, reduce_mean: bool = False): - output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) - if reduce_mean: - reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) - return output / reduce_size - return output - - @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, input_, reduce_mean: bool = False): - output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) - ctx.reduce_mean = reduce_mean - if reduce_mean: - reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) - ctx.reduce_size = reduce_size - return output.clone() / reduce_size - return output.clone() - - @staticmethod - @custom_bwd - def backward(ctx, output_grad): - if ctx.reduce_mean: - return output_grad / ctx.reduce_size, None - else: - return output_grad, None - - -def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor: - r"""All-reduce the input from the model parallel region. - - Args: - input_ (:class:`torch.tensor`): input matrix. - reduce_mean (bool, optional): - If set to ``True``, it will divide the output by column parallel size, default to False. - """ - return _RreduceByBatch2p5D.apply(input_, reduce_mean) diff --git a/colossalai/nn/layer/parallel_2p5d/_utils.py b/colossalai/nn/layer/parallel_2p5d/_utils.py deleted file mode 100644 index 1478b25de618978ef2c7e060da33edcb47ecff0b..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_2p5d/_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env - - -def get_tesseract_dim_dep_from_env(): - try: - tesseract_dim = env.tesseract_dim - tesseract_dep = env.tesseract_dep - assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero' - assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero' - return tesseract_dim, tesseract_dep - - except KeyError as e: - raise EnvironmentError('TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, ' - 'please make sure that you have used the correct process group initializer') - - -def assert_tesseract_initialization(): - assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \ - 'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \ - 'must be initialized by the process group initializer' diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py deleted file mode 100644 index f849cbbe7b0d0f27067941a1c74ccae2ac6fa6f8..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_2p5d/layers.py +++ /dev/null @@ -1,1198 +0,0 @@ -import math -from collections import OrderedDict -from typing import Callable - -import torch -import torch.nn as nn -import torch.nn.functional as F -from colossalai.communication import broadcast -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.nn import init as init -from colossalai.registry import LAYERS -from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict, - partition_tensor_parallel_state_dict) -from colossalai.utils.cuda import get_current_device -from torch import Tensor -from torch.nn import Parameter - -from ..base_layer import ParallelLayer -from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple -from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d, - layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d) -from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env - - -@LAYERS.register_module -class Linear2p5D(ParallelLayer): - r"""Linear layer for 2.5D parallelism. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.skip_bias_add = skip_bias_add - - # parallel setting - assert_tesseract_initialization() - self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) - self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() - - # partitioning dimension - self.input_size_per_partition = divide(in_features, self.tesseract_dim) - self.hidden_size_per_partition = divide(out_features, self.tesseract_dim) - - # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.weight = Parameter( - torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)) - - # create bias, shape: [h/q] - if bias: - self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs)) - else: - self.register_parameter('bias', None) - - # initialize parameters - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight.transpose(0, 1) - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # broadcast in dep groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0 and \ - gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: - broadcast_state_dict(local_state, ParallelMode.PARALLEL_2P5D_DEP) - # partition in column groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - ) - # partition in row groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0: - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - - # gather in row groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - # gather in column groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - local_state[weight_key] = local_state[weight_key].transpose(0, 1) - destination.update(local_state) - - def forward(self, x: Tensor) -> Tensor: - # input: [m/dq, n/q, k/q] - # output: [m/dq, n/q, h/q] - out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) - - output = Matmul_AB_2p5D.apply( - x, - self.weight, - self.tesseract_dim, - out_shape, - self.row_rank, - self.col_rank, - self.dep_rank, - ParallelMode.PARALLEL_2P5D_ROW, - ParallelMode.PARALLEL_2P5D_COL, - self.data_parallel_rank, - self.pipeline_parallel_rank, - self.pipeline_parallel_size, - self.tensor_parallel_size, - ) - - if self.bias is not None: - if self.skip_bias_add: - bias = add_bias_2p5d(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - return output, bias - else: - output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, - self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, - False, self.data_parallel_rank, self.pipeline_parallel_rank, - self.pipeline_parallel_size, self.tensor_parallel_size) - return output - else: - return output - - -@LAYERS.register_module -class LayerNorm2p5D(ParallelLayer): - r"""Layer Normalization for 2.5D parallelism. - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None): - super().__init__() - - # layer norm config - self.normalized_shape = normalized_shape - self.variance_epsilon = eps - - # parallel setting - assert_tesseract_initialization() - self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) - self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() - - # partitioning dimension - self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * - - # create parameters - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - - self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) - if bias: - self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs)) - else: - self.bias = None - - self._set_tensor_parallel_attribute() - - def _set_tensor_parallel_attribute(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, x: Tensor) -> Tensor: - with torch.no_grad(): - E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] - torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) - E_x /= self.normalized_shape - - # Var_x in the block below is the sum of input^2 - Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] - torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) - Var_x /= self.normalized_shape - - Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] - # this time 1/sqrt(Var_x + epsilon) - Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) - - output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW) - scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - if self.bias is not None: - bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - output = torch.addcmul(bias, scale, output) - else: - output = torch.mul(scale, output) - return output - - -@LAYERS.register_module -class PatchEmbedding2p5D(ParallelLayer): - r"""2D Image to Patch Embedding. - - Args: - img_size (int): image size. - patch_size (int): patch size. - in_chans (int): number of channels of input image. - embed_size (int): size of embedding. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - flatten (bool, optional): whether to flatten output tensor, defaults to True. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - position_embed_initializer (:class:`typing.Callable`, optional): - The initializer of position embedding, defaults to zeros initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - - assert_tesseract_initialization() - self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - self.embed_size = embed_size - self.embed_size_per_partition = embed_size // self.tesseract_dim**2 - - with seed(ParallelMode.TENSOR): - self.weight = Parameter( - torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), - dtype=dtype)) - self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) - - self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) - self.pos_embed = Parameter( - torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition), - device=get_current_device(), - dtype=dtype)) - - self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) - self._set_tensor_parallel_attribute() - - def _set_tensor_parallel_attribute(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) - set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2) - set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2) - set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2) - - def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): - with seed(ParallelMode.TENSOR): - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - fan_out = self.embed_size - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - bias_initializer(self.bias, fan_in=fan_in) - position_embed_initializer(self.pos_embed) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - # cls token - cls_token = state_dict.pop(cls_token_key, None) - if cls_token is not None: - local_state[cls_token_key] = cls_token - # pos embed - pos_embed = state_dict.pop(pos_embed_key, None) - if pos_embed is not None: - local_state[pos_embed_key] = pos_embed - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - local_state = OrderedDict({ - weight_key: self.weight, - bias_key: self.bias, - cls_token_key: self.cls_token, - pos_embed_key: self.pos_embed - }) - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_2p5d(input_, 0) - - B, C, H, W = input_.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - - weight = all_gather_tensor_2p5d(self.weight, 0, ParallelMode.PARALLEL_2P5D_COL) - bias = all_gather_tensor_2p5d(self.bias, 0, ParallelMode.PARALLEL_2P5D_COL) - - output = F.conv2d(input_, weight, bias, stride=self.patch_size) - if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC - - cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL) - pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL) - cls_token = cls_token.expand(output.shape[0], -1, -1) - output = torch.cat((cls_token, output), dim=1) - output = output + pos_embed - - return output - - -@LAYERS.register_module -class Embedding2p5D(ParallelLayer): - r"""Embedding for 2.5D parallelism. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about initializer please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - - assert_tesseract_initialization() - self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - embed_dim_per_partition = embedding_dim // self.tesseract_dim**2 - - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - - self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) - - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={weight_key: -1}, - partition_states={weight_key: True}, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={weight_key: -1}, - partition_states={weight_key: True}, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_2p5d(input_, 0) - - weight = all_gather_tensor_2p5d(self.weight, -1, ParallelMode.PARALLEL_2P5D_COL) - - output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - return output - - -@LAYERS.register_module -class VocabParallelEmbedding2p5D(ParallelLayer): - """Embedding parallelized in the vocabulary dimension. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about initializer please refer to - `init `_. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - - assert_tesseract_initialization() - self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() - self.num_embeddings_per_partition = divide(self.num_embeddings, self.tesseract_dim) - self.embed_dim_per_partition = divide(self.embed_dim, self.tesseract_dim) - tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition - self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - - self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), - dtype=dtype)) - - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - env.vocab_parallel = True - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.vocab_start_index <= self.padding_idx < self.vocab_end_index: - with torch.no_grad(): - self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={weight_key: -1}, - partition_states={weight_key: True}, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={weight_key: 0}, - partition_states={weight_key: True}, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) - - # Mask the output embedding. - output_parallel[input_mask, :] = 0. - # Reduce across all the model parallel GPUs. - output = reduce_scatter_tensor_2p5d(output_parallel, 0, ParallelMode.PARALLEL_2P5D_COL) - return output - - -@LAYERS.register_module -class Classifier2p5D(ParallelLayer): - r"""Classifier for 2.5D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - assert_tesseract_initialization() - self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) - self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() - - # partitioning dimension - self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2) - - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)) - self.has_weight = True - if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) - else: - self.bias = None - - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - - def _set_tensor_parallel_attributes(self): - if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.in_features, self.num_classes - col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_COL)[0] - row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_ROW)[0] - - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL) - broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - - # gather in column groups - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars, - ) - # gather in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - out_shape = input_.shape[:-1] + (self.num_classes,) - - return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank, - self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - - -@LAYERS.register_module -class VocabParallelClassifier2p5D(ParallelLayer): - r"""Vocab parallel classifier layer for 2.5D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - self.in_features = in_features - self.num_classes = num_classes - - # parallel setting - assert_tesseract_initialization() - self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) - self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() - - # partitioning dimension - self.input_size_per_partition = divide(in_features, self.tesseract_dim) - self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim) - - # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter( - torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs)) - self.has_weight = True - # create bias, shape: [h/q] - if bias: - self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs)) - else: - self.bias = None - - # initialize parameters - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - env.vocab_parallel = True - - def _set_tensor_parallel_attributes(self): - if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.num_classes - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # partition in row groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - # partition in column groups - local_state = partition_tensor_parallel_state_dict( - local_state, - ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def forward(self, x: Tensor) -> Tensor: - # input: [m/dq, n/q, k/q] - # output: [m/dq, n/q, h/q] - out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) - - output = Matmul_ABT_2p5D.apply( - x, - self.weight, - self.tesseract_dim, - out_shape, - self.row_rank, - self.col_rank, - self.dep_rank, - ParallelMode.PARALLEL_2P5D_ROW, - ParallelMode.PARALLEL_2P5D_COL, - self.data_parallel_rank, - self.pipeline_parallel_rank, - self.pipeline_parallel_size, - self.tensor_parallel_size, - ) - - if self.bias is not None: - output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, False, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - return output diff --git a/colossalai/nn/layer/parallel_3d/__init__.py b/colossalai/nn/layer/parallel_3d/__init__.py deleted file mode 100644 index 9ae255b449ee7f57a08a3bb596102860bb1b60d3..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_3d/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from ._operation import reduce_by_batch_3d, split_batch_3d, split_tensor_3d -from .layers import (Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VocabParallelClassifier3D, - VocabParallelEmbedding3D) - -__all__ = [ - 'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', - 'Classifier3D', 'Embedding3D', 'VocabParallelEmbedding3D', 'VocabParallelClassifier3D' -] diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py deleted file mode 100755 index 5dc9a242851fa79af244f018e4c9e4d2e57e84fc..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ /dev/null @@ -1,590 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Optional, Tuple - -import torch -from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd - -from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc - -from ._utils import get_parallel_mode_from_env, push_async_grad - - -class _Linear3D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward( - ctx, - input_: Tensor, - weight: Tensor, - weight_id: int, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode, - ) -> Tensor: - ctx.weight_id = weight_id - ctx.input_parallel_mode = input_parallel_mode - ctx.weight_parallel_mode = weight_parallel_mode - ctx.output_parallel_mode = output_parallel_mode - - input_ = all_gather(input_, 0, input_parallel_mode) - weight = all_gather(weight, 0, weight_parallel_mode) - ctx.save_for_backward(input_, weight) - - output = torch.matmul(input_, weight) - output = reduce_scatter(output, 0, output_parallel_mode) - - return output - - @staticmethod - @custom_bwd - def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: - input_, weight = ctx.saved_tensors - output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) - - input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) - input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) - - weight_grad = torch.matmul( - input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) - weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True) - weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) - - input_op.wait() - - return input_grad, weight_grad, None, None, None, None - - -def linear_3d( - input_: Tensor, - weight: Tensor, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode, -) -> Tensor: - r"""Linear layer for 3D parallelism. - - Args: - input_ (:class:`torch.tensor`): input matrix. - weight (:class:`torch.tensor`): matrix of weight. - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. - output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _Linear3D.apply( - input_, - weight, - id(weight), - input_parallel_mode, - weight_parallel_mode, - output_parallel_mode, - ) - - -class _Classifier3D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward( - ctx, - input_: Tensor, - weight: Tensor, - bias: Optional[Tensor], - weight_id: int, - bias_id: Optional[int], - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode, - ) -> Tensor: - ctx.use_bias = bias is not None - ctx.weight_id = weight_id - - src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)] - weight = broadcast(weight, src_rank, input_parallel_mode) - ctx.save_for_backward(input_, weight) - - output = torch.matmul(input_, weight.transpose(0, 1)) - output = all_reduce(output, output_parallel_mode) - - if bias is not None: - ctx.bias_id = bias_id - output += bias - - ctx.src_rank = src_rank - ctx.input_parallel_mode = input_parallel_mode - ctx.weight_parallel_mode = weight_parallel_mode - ctx.output_parallel_mode = output_parallel_mode - return output - - @staticmethod - @custom_bwd - def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: - input_, weight = ctx.saved_tensors - weight_grad = torch.matmul( - output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])) - weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode) - if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): - weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) - weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) - else: - weight_grad = None - - if ctx.use_bias: - bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) - bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) - bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) - bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) - else: - bias_grad = None - - input_grad = torch.matmul(output_grad, weight) - - return input_grad, weight_grad, bias_grad, None, None, None, None, None - - -def classifier_3d( - input_: Tensor, - weight: Tensor, - bias: Optional[Tensor], - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode, -) -> Tensor: - r"""3D parallel classifier. - - Args: - input_ (:class:`torch.tensor`): input matrix. - weight (:class:`torch.tensor`): matrix of weight. - bias (:class:`torch.tensor`): matrix of bias. - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. - output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _Classifier3D.apply( - input_, - weight, - bias, - id(weight), - id(bias) if bias is not None else None, - input_parallel_mode, - weight_parallel_mode, - output_parallel_mode, - ) - - -class _VocabParallelClassifier3D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward( - ctx, - input_: Tensor, - weight: Tensor, - bias: Optional[Tensor], - weight_id: int, - bias_id: Optional[int], - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode, - ) -> Tensor: - ctx.use_bias = bias is not None - ctx.weight_id = weight_id - - input_ = all_gather(input_, 0, input_parallel_mode) - weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1) - ctx.save_for_backward(input_, weight) - - output = torch.matmul(input_, weight) - output = reduce_scatter(output, 0, output_parallel_mode) - - if bias is not None: - ctx.bias_id = bias_id - output += bias - - ctx.input_parallel_mode = input_parallel_mode - ctx.weight_parallel_mode = weight_parallel_mode - ctx.output_parallel_mode = output_parallel_mode - return output - - @staticmethod - @custom_bwd - def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: - input_, weight = ctx.saved_tensors - output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) - - input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) - input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) - - weight_grad = torch.matmul( - input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) - weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True) - weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) - - if ctx.use_bias: - bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) - bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) - bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) - else: - bias_grad = None - - input_op.wait() - - return input_grad, weight_grad, bias_grad, None, None, None, None, None - - -def vocab_parallel_classifier_3d( - input_: Tensor, - weight: Tensor, - bias: Optional[Tensor], - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode, -) -> Tensor: - r"""3D vocab parallel classifier. - - Args: - input_ (:class:`torch.tensor`): input matrix. - weight (:class:`torch.tensor`): matrix of weight. - bias (:class:`torch.tensor`): matrix of bias. - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. - output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _VocabParallelClassifier3D.apply( - input_, - weight, - bias, - id(weight), - id(bias) if bias is not None else None, - input_parallel_mode, - weight_parallel_mode, - output_parallel_mode, - ) - - -@torch.jit.script -def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float): - mu = x - mean - var = sqr_mean - mean**2 - sigma = torch.sqrt(var + eps) - z = mu / sigma - output = weight * z + bias - - return output, mu, sigma - - -@torch.jit.script -def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor): - # dbias, dweight = grad, grad * mu / sigma - dz = grad * weight - dmu = dz / sigma - dvar = dz * mu * (-0.5) * sigma**(-3) - dmean = -dmu - dvar = torch.sum(dvar, -1, keepdim=True) - dmean = torch.sum(dmean, -1, keepdim=True) - - return dmu, dmean, dvar - - -class _Layernorm3D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward( - ctx, - input_: Tensor, - weight: Tensor, - bias: Tensor, - weight_id: int, - bias_id: int, - normalized_shape: int, - eps: float, - output_parallel_mode: ParallelMode, - input_x_weight_parallel_mode: ParallelMode, - ) -> Tensor: - ctx.weight_id = weight_id - ctx.bias_id = bias_id - - sum_ = torch.sum(input_, dim=-1, keepdim=True) - sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True) - mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape - - output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps) - - ctx.save_for_backward(mu, sigma, weight) - - ctx.normalized_shape = normalized_shape - ctx.output_parallel_mode = output_parallel_mode - ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode - - return output - - @staticmethod - @custom_bwd - def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: - mu, sigma, weight = ctx.saved_tensors - - bias_grad, weight_grad = output_grad, output_grad * mu / sigma - bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1])) - bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True) - bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) - weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1])) - weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True) - weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) - - dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight) - dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode) - input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape - - return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None - - -def layernorm_3d( - input_: Tensor, - weight: Tensor, - bias: Tensor, - normalized_shape: int, - eps: float, - output_parallel_mode: ParallelMode, - input_x_weight_parallel_mode: ParallelMode, -) -> Tensor: - r"""3D parallel Layernorm. - - Args: - input_ (:class:`torch.tensor`): input matrix. - weight (:class:`torch.tensor`): matrix of weight. - bias (:class:`torch.tensor`): matrix of bias. - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability - output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. - input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _Layernorm3D.apply( - input_, - weight, - bias, - id(weight), - id(bias), - normalized_shape, - eps, - output_parallel_mode, - input_x_weight_parallel_mode, - ) - - -def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: - r"""Splits 3D parallel tensor in specified dimension. - - Args: - tensor (:class:`torch.tensor`): Input tensor. - dim (int): Specified dimension in which to split. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode. - - Returns: - :class:`torch.tensor`: The tensor has been split. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_. - """ - dim_size = tensor.size(dim) - world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \ - f'cannot split tensor evenly' - if tensor.size(dim) <= 1: - return tensor - output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), - dim=dim)[gpc.get_local_rank(parallel_mode)].contiguous() - return output - - -def split_batch_3d(input_: Tensor, - dim: int = 0, - input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, - weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor: - r"""Splits 3D tensor in batch. - - Args: - input_ (:class:`torch.tensor`): Input tensor. - dim (int): Specified dimension in which to split. - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): weight parallel mode. - - Returns: - :class:`torch.tensor`: The tensor has been split. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_. - """ - if input_.size(dim) <= 1: - return input_ - weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) - weight_world_size = gpc.get_world_size(weight_parallel_mode) - input_world_size = gpc.get_world_size(input_parallel_mode) - output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() - output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() - return output - - -class _ReduceTensor3D(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_, parallel_mode): - return all_reduce(input_, parallel_mode) - - @staticmethod - def backward(ctx, output_grad): - return output_grad, None - - -def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor: - r"""All-reduce the input - - Args: - tensor (:class:`torch.tensor`): Input tensor. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_. - """ - return _ReduceTensor3D.apply(tensor, parallel_mode) - - -class _AllGatherTensor3D(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_, dim, parallel_mode): - ctx.dim = dim - ctx.parallel_mode = parallel_mode - output = all_gather(input_, dim, parallel_mode) - return output - - @staticmethod - def backward(ctx, output_grad): - input_grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode) - return input_grad, None, None - - -def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: - r"""All-reduce the gradient in backward pass. - - Args: - tensor (:class:`torch.tensor`): Input tensor. - dim (int): Dimension to gather. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_. - """ - return _AllGatherTensor3D.apply(tensor, dim, parallel_mode) - - -class _ReduceScatterTensor3D(torch.autograd.Function): - - @staticmethod - def forward(ctx, input_, dim, parallel_mode): - ctx.dim = dim - ctx.parallel_mode = parallel_mode - return reduce_scatter(input_, dim, parallel_mode) - - @staticmethod - def backward(ctx, output_grad): - input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode) - return input_grad, None, None - - -def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: - r"""Reduce-scatter the input. - - Args: - tensor (:class:`torch.tensor`): Input tensor. - dim (int): Dimension to scatter. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - dim_size = tensor.size(dim) - world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size}).' - - return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode) - - -class _ReduceByBatch3D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, - input_: Tensor, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - reduce_mean: bool = False) -> Tensor: - output = all_reduce(input_, input_parallel_mode) - output = all_reduce(output, weight_parallel_mode) - ctx.reduce_mean = reduce_mean - if reduce_mean: - reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode) - ctx.reduce_size = reduce_size - return output.clone() / reduce_size - return output.clone() - - @staticmethod - @custom_bwd - def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: - if ctx.reduce_mean: - return output_grad / ctx.reduce_size, None, None, None - else: - return output_grad, None, None, None - - -def reduce_by_batch_3d(tensor: Tensor, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - reduce_mean: bool = False) -> Tensor: - r"""All-reduce the input from the model parallel region. - - Args: - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. - reduce_mean (bool, optional): If set to ``True``, it will divide the output by - (input parallel size * weight parallel size), default to False. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean) diff --git a/colossalai/nn/layer/parallel_3d/_utils.py b/colossalai/nn/layer/parallel_3d/_utils.py deleted file mode 100644 index 364191a79f88450ca8701d96258f8e34b7b9b784..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_3d/_utils.py +++ /dev/null @@ -1,99 +0,0 @@ -from collections import OrderedDict -from functools import partial - -import torch -from torch import Tensor - -from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env - - -def get_depth_from_env() -> int: - try: - depth = env.depth_3d - assert depth > 0, 'DEPTH must be greater than zero' - return depth - - except KeyError as e: - raise EnvironmentError('DEPTH is not found in the current environment, ' - 'please make sure that you have used the correct process group initializer') - - -def get_parallel_mode_from_env(group): - assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \ - f'{group} is not valid for 3D tensor parallelism.' - return getattr(env, group) - - -def swap_in_out_group(): - env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d - env.input_x_weight_group_3d, env.output_x_weight_group_3d = ( - env.output_x_weight_group_3d, - env.input_x_weight_group_3d, - ) - - -def dbg_check_shape(tensor: Tensor, shape: tuple): - rank = gpc.get_global_rank() - if rank == 0: - print(tensor.shape) - assert tensor.shape == shape, \ - '{} does not match {}'.format(tensor.shape, shape) - - -class AsyncGradientBucket(object): - - def __init__(self): - self.bucket = OrderedDict() - - def __len__(self): - return len(self.bucket) - - def push(self, async_op, grad_tensor, param_id): - self.bucket[param_id] = tuple((async_op, grad_tensor)) - return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device) - - def pop(self, param_id): - grad = None - if param_id in self.bucket: - op, grad = self.bucket.pop(param_id) - if op is not None: - op.wait() - return grad - - def synchronize(self, params): - for p in params: - i = id(p) - if i in self.bucket: - op, grad = self.bucket.pop(i) - if op is not None: - op.wait() - p.grad.add_(grad) - - -_async_grad_bucket = AsyncGradientBucket() - - -def push_async_grad(op, grad, param_id): - return _async_grad_bucket.push(op, grad, param_id) - - -def pop_async_grad(param_id): - return _async_grad_bucket.pop(param_id) - - -def _async_grad_hook(grad, param_id): - grad.add_(pop_async_grad(param_id)) - return grad - - -def register_async_grad_hook(param): - param.register_hook(partial(_async_grad_hook, param_id=id(param))) - - -def synchronize(params=list()): - _async_grad_bucket.synchronize(params) - torch.cuda.default_stream().synchronize() - if len(_async_grad_bucket) > 0: - raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.") diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py deleted file mode 100644 index 99b0c3f8b7ec339190b28d69842a694cc318fbc7..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_3d/layers.py +++ /dev/null @@ -1,1218 +0,0 @@ -import math -from collections import OrderedDict -from typing import Callable - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.nn import Parameter - -from colossalai.communication import all_reduce, broadcast -from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.nn import init as init -from colossalai.nn.layer.base_layer import ParallelLayer -from colossalai.registry import LAYERS -from colossalai.utils.checkpointing import ( - broadcast_state_dict, - gather_tensor_parallel_state_dict, - partition_tensor_parallel_state_dict, -) -from colossalai.utils.cuda import get_current_device - -from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple -from ._operation import ( - all_gather_tensor_3d, - classifier_3d, - layernorm_3d, - linear_3d, - reduce_scatter_tensor_3d, - split_batch_3d, - split_tensor_3d, - vocab_parallel_classifier_3d, -) -from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group - - -@LAYERS.register_module -class LayerNorm3D(ParallelLayer): - r"""Layer Normalization for 3D parallelism. - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None): - - super().__init__() - self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) - self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) - self.depth = get_depth_from_env() - self.normalized_shape = normalized_shape - self.normalized_shape_per_partition = divide(normalized_shape, self.depth) - - self.weight = Parameter( - torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) - if bias: - self.bias = Parameter( - torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) - else: - self.bias = None - self.variance_epsilon = eps - self.reset_parameters() - self._set_tensor_parallel_attributes() - - def _set_tensor_parallel_attributes(self) -> None: - set_tensor_parallel_attribute_by_partition(self.weight, self.depth) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, self.depth) - - def reset_parameters(self) -> None: - init.ones_()(self.weight) - register_async_grad_hook(self.weight) - if self.bias is not None: - init.zeros_()(self.bias) - register_async_grad_hook(self.bias) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight.transpose(0, 1) - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True, - }, - ) - # broadcast in input groups - if gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = broadcast_state_dict(local_state, self.input_parallel_mode) - # broadcast in weight groups - local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - - # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - return layernorm_3d( - input_, - self.weight, - self.bias, - self.normalized_shape, - self.variance_epsilon, - self.output_parallel_mode, - self.input_x_weight_parallel_mode, - ) - - -@LAYERS.register_module -class Linear3D(ParallelLayer): - r"""Linear layer for 3D parallelism. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) - self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) - self.depth = get_depth_from_env() - self.skip_bias_add = skip_bias_add - self.in_features_per_partition = divide(in_features, self.depth**2) - self.out_features_per_partition = divide(out_features, self.depth) - self.bias_features_per_partition = divide(out_features, self.depth) - - self.weight = Parameter( - torch.empty(self.in_features_per_partition, - self.out_features_per_partition, - device=get_current_device(), - dtype=dtype)) - if bias: - self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) - else: - self.bias = None - - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - swap_in_out_group() - - def _set_tensor_parallel_attributes(self) -> None: - set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, self.depth) - - def _sync_grad_hook(self, grad) -> Tensor: - grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode) - return grad - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.in_features, self.out_features - - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - register_async_grad_hook(self.weight) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, - gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], - self.output_x_weight_parallel_mode) - self.bias.register_hook(self._sync_grad_hook) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight.transpose(0, 1) - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - ) - # partition in input groups - if gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - self.input_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - # partition in weight groups - local_state = partition_tensor_parallel_state_dict( - local_state, - self.weight_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - - # gather in weight groups - local_state = gather_tensor_parallel_state_dict( - local_state, - self.weight_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars, - ) - # gather in input groups - if gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - self.input_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - local_state[weight_key] = local_state[weight_key].transpose(0, 1) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - output = linear_3d( - input_, - self.weight, - self.input_parallel_mode, - self.weight_parallel_mode, - self.output_parallel_mode, - ) - - if not self.skip_bias_add: - if self.bias is not None: - output = output + self.bias - return output - else: - return output, self.bias - - -@LAYERS.register_module -class Classifier3D(ParallelLayer): - r"""Classifier for 3D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) - self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - self.depth = get_depth_from_env() - self.in_features_per_partition = divide(in_features, self.depth) - - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter( - torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)) - self.has_weight = True - if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) - else: - self.bias = None - - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - - def _set_tensor_parallel_attributes(self) -> None: - if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, self.depth) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.in_features, self.num_classes - - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode) - - register_async_grad_hook(self.weight) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR) - register_async_grad_hook(self.bias) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - ) - # broadcast in input groups - if gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = broadcast_state_dict(local_state, self.input_parallel_mode) - # broadcast in weight groups - local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - - # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - return classifier_3d( - input_, - self.weight, - self.bias, - self.input_parallel_mode, - self.weight_parallel_mode, - self.output_parallel_mode, - ) - - -@LAYERS.register_module -class VocabParallelClassifier3D(ParallelLayer): - r"""Vocab parallel classifier layer for 3D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) - self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) - self.depth = get_depth_from_env() - self.in_features_per_partition = divide(in_features, self.depth) - self.out_features_per_partition = divide(num_classes, self.depth**2) - self.bias_features_per_partition = divide(num_classes, self.depth) - - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter( - torch.empty(self.out_features_per_partition, - self.in_features_per_partition, - device=get_current_device(), - dtype=dtype)) - self.has_weight = True - if bias: - self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) - else: - self.bias = None - - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - swap_in_out_group() - env.vocab_parallel = True - - def _set_tensor_parallel_attributes(self) -> None: - if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, self.depth) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.in_features, self.num_classes - - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - - register_async_grad_hook(self.weight) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, - gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], - self.output_x_weight_parallel_mode) - register_async_grad_hook(self.bias) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - ) - # partition in input groups - if gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - self.input_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - ) - # partition in weight groups - local_state = partition_tensor_parallel_state_dict( - local_state, - self.weight_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - - # gather in weight groups - local_state = gather_tensor_parallel_state_dict( - local_state, - self.weight_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars, - ) - # gather in input groups - if gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - self.input_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars, - ) - # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - return vocab_parallel_classifier_3d( - input_, - self.weight, - self.bias, - self.input_parallel_mode, - self.weight_parallel_mode, - self.output_parallel_mode, - ) - - -@LAYERS.register_module -class PatchEmbedding3D(ParallelLayer): - r"""2D Image to Patch Embedding. - - Args: - img_size (int): image size. - patch_size (int): patch size. - in_chans (int): number of channels of input image. - embed_size (int): size of embedding. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - flatten (bool, optional): whether to flatten output tensor, defaults to True. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - position_embed_initializer (:class:`typing.Callable`, optional): - The initializer of position embedding, defaults to zeros initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): - super().__init__() - self.depth = get_depth_from_env() - self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) - self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.embed_size = embed_size - embed_size_per_partition = embed_size // self.depth - self.flatten = flatten - - self.weight = nn.Parameter( - torch.empty((embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), - dtype=dtype)) - self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype)) - - self.cls_token = nn.Parameter( - torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) - self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) - - self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) - self._set_tensor_parallel_attributes() - - def _set_tensor_parallel_attributes(self) -> None: - set_tensor_parallel_attribute_by_partition(self.weight, self.depth) - set_tensor_parallel_attribute_by_partition(self.bias, self.depth) - set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth) - set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth) - - def _sync_grad_hook(self, grad) -> Tensor: - grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) - return grad - - def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - fan_out = self.embed_size - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - bias_initializer(self.bias, fan_in=fan_in) - position_embed_initializer(self.pos_embed) - - src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0] - broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode) - broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode) - broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode) - - self.weight.register_hook(self._sync_grad_hook) - self.bias.register_hook(self._sync_grad_hook) - self.cls_token.register_hook(self._sync_grad_hook) - self.pos_embed.register_hook(self._sync_grad_hook) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - # cls token - cls_token = state_dict.pop(cls_token_key, None) - if cls_token is not None: - local_state[cls_token_key] = cls_token - # pos embed - pos_embed = state_dict.pop(pos_embed_key, None) - if pos_embed is not None: - local_state[pos_embed_key] = pos_embed - - # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, - ) - # broadcast in input groups - if gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = broadcast_state_dict(local_state, self.input_parallel_mode) - # broadcast in weight groups - local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - local_state = OrderedDict({ - weight_key: self.weight, - bias_key: self.bias, - cls_token_key: self.cls_token, - pos_embed_key: self.pos_embed - }) - - # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_3d(input_, - input_parallel_mode=self.input_parallel_mode, - weight_parallel_mode=self.weight_parallel_mode) - output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) - if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC - - cls_token = self.cls_token.expand(output.shape[0], -1, -1) - output = torch.cat((cls_token, output), dim=1) - output = output + self.pos_embed - - return output - - -@LAYERS.register_module -class Embedding3D(ParallelLayer): - r"""Embedding for 3D parallelism. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about initializer please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - self.depth = get_depth_from_env() - self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) - self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) - - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - embed_dim_per_partition = divide(embedding_dim, self.depth) - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - - self.weight = nn.Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) - - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - - def _set_tensor_parallel_attributes(self) -> None: - set_tensor_parallel_attribute_by_partition(self.weight, self.depth) - - def _sync_grad_hook(self, grad) -> Tensor: - grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) - return grad - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - broadcast(self.weight, - gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode) - self.weight.register_hook(self._sync_grad_hook) - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={weight_key: 0}, - partition_states={weight_key: True}, - ) - # broadcast in input groups - if gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = broadcast_state_dict(local_state, self.input_parallel_mode) - # broadcast in weight groups - local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - - # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_3d(input_, - input_parallel_mode=self.input_parallel_mode, - weight_parallel_mode=self.weight_parallel_mode) - output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - return output - - -@LAYERS.register_module -class VocabParallelEmbedding3D(ParallelLayer): - r"""Embedding parallelized in the vocabulary dimension. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about initializer please refer to - `init `_. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - - self.depth = get_depth_from_env() - self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) - self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2) - self.embed_dim_per_partition = divide(self.embed_dim, self.depth) - vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode) - self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition * self.depth - self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth - - self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), - dtype=dtype)) - - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - env.vocab_parallel = True - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: - with torch.no_grad(): - self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={weight_key: -1}, - partition_states={weight_key: True}, - ) - # partition in input groups - if gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = partition_tensor_parallel_state_dict( - local_state, - self.input_parallel_mode, - dims={weight_key: 0}, - partition_states={weight_key: True}, - ) - # partition in weight groups - local_state = partition_tensor_parallel_state_dict( - local_state, - self.weight_parallel_mode, - dims={weight_key: 0}, - partition_states={weight_key: True}, - ) - - super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - - # gather in weight groups - local_state = gather_tensor_parallel_state_dict( - local_state, - self.weight_parallel_mode, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - # gather in input groups - if gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - self.input_parallel_mode, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: - local_state = gather_tensor_parallel_state_dict( - local_state, - self.output_parallel_mode, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars, - ) - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) - - input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - - weight = all_gather_tensor_3d(self.weight, 0, self.weight_parallel_mode) - - output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - output_parallel[input_mask, :] = 0. - output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode) - - return output diff --git a/colossalai/nn/layer/parallel_sequence/__init__.py b/colossalai/nn/layer/parallel_sequence/__init__.py deleted file mode 100644 index 4fa9eed6f34b8ccdcf03935337bc96ba705530d0..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_sequence/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from ._operation import RingQK, RingAV -from .layers import TransformerSelfAttentionRing - -__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK'] diff --git a/colossalai/nn/layer/parallel_sequence/_operation.py b/colossalai/nn/layer/parallel_sequence/_operation.py deleted file mode 100644 index fc80494224c6d2ec3176f40c47733133a422b88c..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_sequence/_operation.py +++ /dev/null @@ -1,151 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch -from torch import distributed as dist - -from colossalai.communication import ring_forward -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range -from colossalai.utils import get_current_device -from torch.cuda.amp import custom_bwd, custom_fwd - - -class RingQK(torch.autograd.Function): - """ - Calculate QK in a ring-exchange style - """ - - @staticmethod - @custom_fwd - def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length): - # save tensor for backward - ctx.save_for_backward(sub_q, sub_k) - ctx.sub_seq_length = sub_seq_length - - # create local segment of attention score - attention_score = torch.empty(batch_size * num_attention_heads, - sub_seq_length, - sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), - dtype=sub_q.dtype, - device=get_current_device()) - - # compute local QK^T - part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) - local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) - local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) - start_idx = local_rank * sub_seq_length - end_idx = (local_rank + 1) * sub_seq_length - attention_score[:, :, start_idx:end_idx] = part_a - - # compute QK^T in ring-all-reduce style - for i in range(local_world_size - 1): - sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE) - start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length) - part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) - attention_score[:, :, start_idx:end_idx] = part_a - - return attention_score - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - sub_q, sub_k, = ctx.saved_tensors - local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) - local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) - - # calculate gradient of sub_k - grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q) - - dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE)) - grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length] - grad_k /= local_world_size - - # calculate gradient for sub_q - grad_q = torch.zeros_like( - sub_q, - dtype=sub_q.dtype, - device=get_current_device(), - ) - - # compute with local sub_k - start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length) - grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k) - - # compute QK^T in ring-all-reduce style - for i in range(local_world_size - 1): - sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE) - start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length) - grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k) - - grad_q /= local_world_size - - return grad_q, grad_k, None, None, None - - -class RingAV(torch.autograd.Function): - """ - Calculate AV in a ring-exchange style - """ - - @staticmethod - @custom_fwd - def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attention_head_size, sub_seq_length): - local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) - local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) - local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length) - - sub_attention_result = torch.zeros(batch_size * num_attention_heads, - sub_seq_length, - attention_head_size, - device=get_current_device(), - dtype=attention_score.dtype) - - # save tensors for backward - ctx.save_for_backward(attention_score, sub_v) - ctx.sub_seq_length = sub_seq_length - - # compute local AV - part_av = torch.matmul(attention_score[:, :, local_start_idx:local_end_idx], sub_v) - sub_attention_result += part_av - - # compute AV in ring - all - reduce style - for i in range(local_world_size - 1): - sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE) - start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length) - - # compute QK^T - part_av = torch.matmul(attention_score[:, :, start_idx:end_idx], sub_v) - sub_attention_result += part_av - return sub_attention_result - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) - local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) - local_start_idx, local_end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length) - attention_scores, sub_v = ctx.saved_tensors - - # calculate gradient of v - grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output) - dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE)) - grad_v = grad_v[:, local_start_idx:local_end_idx] - grad_v /= local_world_size - - # calculate gradient for attention score - grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device()) - - # compute with local sub_k - grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) - - # compute QK^T in ring-all-reduce style - for i in range(local_world_size - 1): - sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE) - start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length) - - # compute grad_q - grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) - - return grad_attention_score, grad_v, None, None, None, None diff --git a/colossalai/nn/layer/parallel_sequence/layers.py b/colossalai/nn/layer/parallel_sequence/layers.py deleted file mode 100644 index d9486217bbc93a9f75cfda809701cafaa15957cc..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/parallel_sequence/layers.py +++ /dev/null @@ -1,237 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import math -import colossalai - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import Parameter - -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV -from colossalai.registry import LAYERS -from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType -from colossalai.kernel import FusedScaleMaskSoftmax -from colossalai.context import seed - - -@LAYERS.register_module -class TransformerSelfAttentionRing(nn.Module): - """Parallel self-attention layer abstract class. - Self-attention layer takes input with size [b, s, h] - and returns output of the same size. - - Args: - hidden_size (int): hidden size. - num_attention_heads (int): number of attention heads. - attention_dropout (float): dropout probability for attention layer. - attention_mask_func (:class:`typing.Callable`): Mask function to be applied. - layer_number (int): number of layers. - - """ - - def __init__(self, - hidden_size, - num_attention_heads, - attention_dropout, - attention_mask_func, - layer_number, - apply_query_key_layer_scaling: bool = False, - convert_fp16_to_fp32_in_softmax: bool = False, - attn_mask_type=AttnMaskType.padding, - masked_softmax_fusion=True, - fp16=False, - bf16=False): - super().__init__() - self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax - self.apply_query_key_layer_scaling = apply_query_key_layer_scaling - self.attention_mask_func = attention_mask_func - self.layer_number = layer_number - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.attn_mask_type = attn_mask_type - assert self.layer_number > 0 - self.attention_dropout = attention_dropout - - if self.apply_query_key_layer_scaling: - self.convert_fp16_to_fp32_in_softmax = True - - assert self.hidden_size % self.num_attention_heads == 0, \ - 'hidden size is not divisible by the number of attention heads' - - self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads - - self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE) - - # Strided linear layer. - self.query_key_value = _Linear( - hidden_size, - 3 * self.hidden_size, - ) - - self.coeff = None - self.norm_factor = math.sqrt(self.hidden_size) - - if self.apply_query_key_layer_scaling: - self.coeff = layer_number - self.norm_factor *= self.coeff - - self.scale_mask_softmax = FusedScaleMaskSoftmax(fp16, bf16, self.attn_mask_type, masked_softmax_fusion, - self.attention_mask_func, self.convert_fp16_to_fp32_in_softmax, - self.coeff) - - self.attention_dropout = nn.Dropout(attention_dropout) - - # Output. - self.dense = _Linear(hidden_size, hidden_size, bias=True, skip_bias_add=True) - - def forward(self, hidden_states, attention_mask): - # hidden_states: [sub_seq_len, batch_size, hidden_size] - # attention_mask: [batch_size, 1, sub_seq_len, seq_len] - sub_seq_length, batch_size, hidden_size = hidden_states.size() - - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads shape change: - # [sub_seq_len, batch_size, hidden_size] --> [sub_seq_len, batch_size, (3 * head_size * num_heads)] - mixed_x_layer = self.query_key_value(hidden_states) - - # [sub_seq_len, batch_size, num_heads, 3 * head_size] --> 3 [sub_seq_len, batch_size, num_heads, head_size] - new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # split into query, key and value - last_dim = mixed_x_layer.dim() - 1 - last_dim_value = mixed_x_layer.size(-1) - assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \ - 'cannot be divided into query, key and value' - partition_size = last_dim_value // 3 - (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, partition_size, dim=last_dim) - - # attention scores: [batch_size, num_heads, sub_seq_len, seq_len] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), - key_layer.size(0) * self.world_size) - - # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] - key_layer = key_layer.view(key_layer.size(0), output_size[0] * output_size[1], -1) - - # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len] - attention_scores = RingQK.apply( - query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size] - key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size], - batch_size, - self.num_attention_heads, - sub_seq_length) - - attention_scores /= self.norm_factor - - # change view to [batch_size, num_heads, sub_seq_len, seq_len] - attention_scores = attention_scores.view(*output_size) - - # change shape to [batch_size, num_heads, sub_seq_len, seq_len] - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - with seed(ParallelMode.TENSOR): - attention_probs = self.attention_dropout(attention_probs) - - # context layer shape: [batch_size, num_heads, sub_seq_len, head_size] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - - # change view [sub_seq_len, batch_size * num_heads, head_size] - value_layer = value_layer.contiguous().view(value_layer.size(0), output_size[0] * output_size[1], -1) - - # # change view [b * num_heads, sub_seq_len, seq_len] - attention_probs = attention_probs.view( - attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3)) - - # matmul: [batch_size * num_heads, sub_seq_len, head_size] - context_layer = RingAV.apply(attention_probs, - value_layer.transpose(0, 1).contiguous(), batch_size, self.num_attention_heads, - self.hidden_size_per_attention_head, sub_seq_length) - - # change view [batch_size, num_heads, sub_seq_len, head_size] - context_layer = context_layer.view(*output_size) - - # [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, num_heads, head_size] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_attention_head * - self.num_attention_heads,) - context_layer = context_layer.view(*new_context_layer_shape) - - output, bias = self.dense(context_layer) - - return output, bias - - def __repr__(self): - return f'TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, ' \ - f'layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, ' \ - f'attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, ' \ - f'hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, ' \ - f'convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})' - - -class _Linear(nn.Module): - """Linear layer with column parallelism. - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - bias: If true, add bias - init_method: method to initialize weights. Note that bias is always set - to zero. - stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. - skip_bias_add: This was added to enable performance optimations where bias - can be fused with other elementwise operations. we skip - adding bias but instead return it. - """ - - def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): - super(_Linear, self).__init__() - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.skip_bias_add = skip_bias_add - - self.weight = Parameter(torch.empty( - self.output_size, - self.input_size, - )) - nn.init.xavier_normal_(self.weight) - - if bias: - self.bias = Parameter(torch.empty(self.output_size)) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter('bias', None) - - def forward(self, input_): - # Matrix multiply. - bias = self.bias if not self.skip_bias_add else None - output = F.linear(input_, self.weight, bias) - - if self.skip_bias_add: - return output, self.bias - else: - return output - - def __repr__(self): - return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ - f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' diff --git a/colossalai/nn/layer/utils.py b/colossalai/nn/layer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ff9b5c8f2b5b0f6d8917fe3f1c380c05897c5e19 --- /dev/null +++ b/colossalai/nn/layer/utils.py @@ -0,0 +1,13 @@ +def divide(numerator, denominator): + """Only allow exact division. + + Args: + numerator (int): Numerator of the division. + denominator (int): Denominator of the division. + + Returns: + int: the result of exact division. + """ + assert denominator != 0, "denominator can not be zero" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) + return numerator // denominator diff --git a/colossalai/nn/layer/utils/__init__.py b/colossalai/nn/layer/utils/__init__.py deleted file mode 100644 index 7e999ee8214916d9d2b5465333262d05cad198ec..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/utils/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode, - set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple) - -__all__ = [ - 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size', - 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple' -] diff --git a/colossalai/nn/layer/utils/common.py b/colossalai/nn/layer/utils/common.py deleted file mode 100644 index f2297304fdc939c6a57e9eaa0f52ee268ead46fd..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/utils/common.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import collections.abc -from itertools import repeat - -import numpy as np -import torch -from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.utils import checkpoint -from torch import Tensor, nn - - -class CheckpointModule(nn.Module): - - def __init__(self, checkpoint: bool = True, offload: bool = False): - super().__init__() - self.checkpoint = checkpoint - self._use_checkpoint = checkpoint - self._offload = offload - - def _forward(self, *args, **kwargs): - raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward') - - def forward(self, *args, **kwargs): - if self._use_checkpoint: - return checkpoint(self._forward, self._offload, *args, **kwargs) - else: - return self._forward(*args, **kwargs) - - def train(self, mode: bool = True): - self._use_checkpoint = self.checkpoint - return super().train(mode=mode) - - def eval(self): - self._use_checkpoint = False - return super().eval() - - -def divide(numerator, denominator): - """Only allow exact division. - - Args: - numerator (int): Numerator of the division. - denominator (int): Denominator of the division. - - Returns: - int: the result of exact division. - """ - assert denominator != 0, 'denominator can not be zero' - assert numerator % denominator == 0, \ - '{} is not divisible by {}'.format(numerator, denominator) - return numerator // denominator - - -def swish(x: Tensor) -> Tensor: - return x * torch.sigmoid(x) - - -ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} - - -def set_tensor_parallel_attribute_by_size(param, size): - setattr(param, IS_TENSOR_PARALLEL, True) - setattr(param, NUM_PARTITIONS, size // np.prod(param.shape)) - - -def set_tensor_parallel_attribute_by_partition(param, num_partitions): - setattr(param, IS_TENSOR_PARALLEL, True) - setattr(param, NUM_PARTITIONS, num_partitions) - - -def get_tensor_parallel_mode(): - return env.mode - - -# From PyTorch internals - - -def _ntuple(n): - - def parse(x): - if isinstance(x, collections.abc.Iterable): - return x - return tuple(repeat(x, n)) - - return parse - - -to_2tuple = _ntuple(2) diff --git a/colossalai/nn/layer/vanilla/__init__.py b/colossalai/nn/layer/vanilla/__init__.py deleted file mode 100644 index 3d767b8886f53fe5ddb697fd9fb4ed261cfd05c3..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/vanilla/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from .layers import ( - DropPath, - VanillaClassifier, - VanillaLayerNorm, - VanillaLinear, - VanillaPatchEmbedding, - WrappedDropout, - WrappedDropPath, -) - -__all__ = [ - "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath", - "VanillaLinear" -] diff --git a/colossalai/nn/layer/vanilla/layers.py b/colossalai/nn/layer/vanilla/layers.py deleted file mode 100644 index 225aed3916a6dbacaa8876ea994c2d2441713338..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/vanilla/layers.py +++ /dev/null @@ -1,341 +0,0 @@ -import math -from typing import Callable - -import torch -import torch.nn.functional as F -from torch import Tensor -from torch import nn as nn -from torch.nn.parameter import Parameter - -from colossalai.context import seed -from colossalai.nn import init as init -from colossalai.registry import LAYERS -from colossalai.utils.cuda import get_current_device - -from ..utils import to_2tuple - - -def drop_path(x, drop_prob: float = 0., training: bool = False): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - - This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, - the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for - changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use - 'survival rate' as the argument. - - Args: - drop_prob (float, optional): probability of dropping path, defaults 0.0. - training (bool, optional): whether in training progress, defaults False. - """ - if drop_prob == 0. or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class DropPath(nn.Module): - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py - - Args: - drop_prob (float, optional): probability of dropping path, defaults None. - """ - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - -class WrappedDropout(nn.Module): - r"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes - some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each - channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of - 1/(1-p) during training. This means that during evaluation the module simply computes an identity function. - - Args: - p (float, optional): probability of an element to be zeroed, defaults 0.5. - inplace (bool, optional): whether to do dropout in-place, default to be False. - mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - - def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): - super().__init__() - if p < 0 or p > 1: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - self.p = p - self.inplace = inplace - if mode is None: - self.func = self.nonefunc - else: - self.func = self.normalfunc - self.mode = mode - - def nonefunc(self, inputs): - return F.dropout(inputs, self.p, self.training, self.inplace) - - def normalfunc(self, inputs): - with seed(self.mode): - return F.dropout(inputs, self.p, self.training, self.inplace) - - def forward(self, inputs): - return self.func(inputs) - - -class WrappedDropPath(nn.Module): - r"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - Here, it is wrapped with the context of seed manager. - - Args: - p (float, optional): probability of dropping path, defaults 0.0. - mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - - def __init__(self, p: float = 0., mode=None): - super().__init__() - self.p = p - self.mode = mode - if self.mode is None: - self.func = self.nonefunc - else: - self.func = self.normalfunc - self.mode = mode - - def nonefunc(self, inputs): - return drop_path(inputs, self.p, self.training) - - def normalfunc(self, inputs): - with seed(self.mode): - return drop_path(inputs, self.p, self.training) - - def forward(self, inputs): - return self.func(inputs) - - -@LAYERS.register_module -class VanillaPatchEmbedding(nn.Module): - r""" - 2D Image to Patch Embedding - - Args: - img_size (int): image size. - patch_size (int): patch size. - in_chans (int): number of channels of input image. - embed_size (int): size of embedding. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - flatten (bool, optional): whether to flatten output tensor, defaults to True. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - position_embed_initializer (:class:`typing.Callable`, optional): - The initializer of position embedding, defaults to zeros initializer. - - More details about initializer please refer to - `init `_. - """ - - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - - self.weight = nn.Parameter( - torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) - self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) - self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) - self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)) - - self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) - - def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): - fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight) - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - bias_initializer(self.bias, fan_in=fan_in) - position_embed_initializer(self.pos_embed) - - def forward(self, input_: Tensor) -> Tensor: - B, C, H, W = input_.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) - if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC - - cls_token = self.cls_token.expand(output.shape[0], -1, -1) - output = torch.cat((cls_token, output), dim=1) - output = output + self.pos_embed - return output - - -@LAYERS.register_module -class VanillaClassifier(nn.Module): - r"""Dense linear classifier. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - flatten (bool, optional): whether to flatten output tensor, defaults to True. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about initializer please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: nn.Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = nn.Parameter( - torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)) - self.has_weight = True - if bias: - self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) - else: - self.bias = None - - self.reset_parameters(weight_initializer, bias_initializer) - - def reset_parameters(self, weight_initializer, bias_initializer): - fan_in, fan_out = self.in_features, self.num_classes - - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def forward(self, input_: Tensor) -> Tensor: - return F.linear(input_, self.weight, self.bias) - - -@LAYERS.register_module -class VanillaLayerNorm(nn.Module): - r""" - Layer Normalization for colossalai - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): - super().__init__() - - self.normalized_shape = (normalized_shape,) - self.variance_epsilon = eps - - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - - self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) - if bias: - self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs)) - else: - self.bias = None - - def forward(self, x: Tensor) -> Tensor: - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon) - - -@LAYERS.register_module -class VanillaLinear(nn.Module): - """Linear layer. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - skip_bias_add: bool (optional, default to be false). - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - **kwargs) -> None: - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.skip_bias_add = skip_bias_add - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) - if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) - else: - self.bias = None - weight_initializer(self.weight, fan_in=in_features, fan_out=out_features) - if self.bias is not None: - bias_initializer(self.bias, fan_in=in_features) - - def forward(self, input: Tensor) -> Tensor: - if not self.skip_bias_add: - return F.linear(input, self.weight, self.bias) - else: - return F.linear(input, self.weight), self.bias diff --git a/colossalai/nn/layer/wrapper/__init__.py b/colossalai/nn/layer/wrapper/__init__.py deleted file mode 100644 index c7d90d887ec6612e351713e508d96b106b767a81..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/wrapper/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .pipeline_wrapper import PipelineSharedModuleWrapper - -__all__ = ['PipelineSharedModuleWrapper'] diff --git a/colossalai/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/nn/layer/wrapper/pipeline_wrapper.py deleted file mode 100644 index ef1d794cc68f15a1885776016601e1b850d933b8..0000000000000000000000000000000000000000 --- a/colossalai/nn/layer/wrapper/pipeline_wrapper.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch.nn as nn -import torch.distributed as dist -from typing import List, Tuple, Union -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc - - -class PipelineSharedModuleWrapper: - - def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None: - assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}' - self.pipeline_ranks = pipeline_ranks - self.group = None - self.ranks_in_group = None - self._init_group() - - def _init_group(self): - world_size = gpc.get_world_size(ParallelMode.GLOBAL) - dp_size = gpc.get_world_size(ParallelMode.DATA) - pp_size = gpc.get_world_size(ParallelMode.PIPELINE) - rank = gpc.get_global_rank() - num_dp_groups = world_size // dp_size - num_pp_stages = num_dp_groups // pp_size - for i in range(dp_size): - for j in range(num_pp_stages): - pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages)) - sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks] - group = dist.new_group(sub_ranks) - if rank in sub_ranks: - self.group = group - self.ranks_in_group = sub_ranks - - def register_module(self, module: nn.Module): - assert self.ranks_in_group is not None,\ - f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' - src = self.ranks_in_group[self.pipeline_ranks[0]] - for p in module.parameters(): - setattr(p, 'pipeline_shared_module_pg', self.group) - dist.broadcast(p, src, group=self.group) - - def register_parameter(self, param: nn.Parameter): - assert self.ranks_in_group is not None,\ - f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' - src = self.ranks_in_group[self.pipeline_ranks[0]] - setattr(param, 'pipeline_shared_module_pg', self.group) - dist.broadcast(param, src, group=self.group) diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py index 373e4ec9468bc13317d74c19b5922073a5cb8c0c..7c6fb099d272c1fd9db228370a72b2c8755ab3e6 100644 --- a/colossalai/nn/loss/__init__.py +++ b/colossalai/nn/loss/__init__.py @@ -1,41 +1 @@ -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.nn.layer.utils import get_tensor_parallel_mode -from torch import nn -from torch.nn.modules.loss import * -from torch.nn.modules.loss import _Loss - -from .loss_1d import VocabParallelCrossEntropyLoss1D -from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D -from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D -from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D -from .loss_moe import MoeCrossEntropyLoss, MoeLoss - -_parallel_cross_entropy = { - '2d': CrossEntropyLoss2D, - '2.5d': CrossEntropyLoss2p5D, - '3d': CrossEntropyLoss3D, -} - -_vocab_parallel_cross_entropy = { - '1d': VocabParallelCrossEntropyLoss1D, - '2d': VocabParallelCrossEntropyLoss2D, - '2.5d': VocabParallelCrossEntropyLoss2p5D, - '3d': VocabParallelCrossEntropyLoss3D, -} - - -class CrossEntropyLoss(_Loss): - - def __init__(self, reduction: bool = True, *args, **kwargs): - super().__init__() - tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel is not None and env.vocab_parallel: - self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) - elif tensor_parallel is None or tensor_parallel == '1d': - reduction = 'mean' if reduction else 'none' - self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs) - else: - self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) - - def forward(self, *args): - return self.loss(*args) +# from .loss_moe import MoeCrossEntropyLoss, MoeLoss diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/nn/loss/loss_moe.py index a8b18a3e37ee20821df21fa7b2e309c4dc9f02e5..40cea788c3c3cf0c1ede4f9e460d28470950efd2 100644 --- a/colossalai/nn/loss/loss_moe.py +++ b/colossalai/nn/loss/loss_moe.py @@ -1,80 +1,81 @@ -import torch.nn as nn -from colossalai.registry import LOSSES -from torch.nn.modules.loss import _Loss -from colossalai.context.moe_context import MOE_CONTEXT - - -@LOSSES.register_module -class MoeCrossEntropyLoss(_Loss): - r"""torch.nn.CrossEntropyLoss added with auxiliary loss. - - Args: - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. - aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01. - - The ``args`` and ``kwargs`` should include parameters below: - :: - - weight (Tensor, optional) - size_average (bool, optional) - ignore_index (int, optional) - reduce (bool, optional) - reduction (str, optional) - label_smoothing (float, optional) - - More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in - `Cross_entropy `_. - """ - - def __init__(self, aux_weight: float = 0.01, *args, **kwargs): - super().__init__() - self.loss = nn.CrossEntropyLoss(*args, **kwargs) - self.aux_weight = aux_weight - - def forward(self, *args): - """ - The ``args`` should at least include parameters below: - :: - - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. - - More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in - `Cross_entropy `_. - """ - main_loss = self.loss(*args) - aux_loss = MOE_CONTEXT.get_loss() - return main_loss + self.aux_weight * aux_loss - - -@LOSSES.register_module -class MoeLoss(_Loss): - """A wrapper class for any loss module to add with auxiliary loss. - - Args: - aux_weight (float): Weight of auxiliary loss in total loss. - loss_fn (``Callable``): Loss function. - args (list): Args in loss function. - kwargs (dict): Kwargs in loss function - """ - - def __init__(self, aux_weight: float, loss_fn, *args, **kwargs): - super().__init__() - self.loss_fn = loss_fn(*args, **kwargs) - self.aux_weight = aux_weight - - def forward(self, *args, **kwargs): - """ - The ``args`` and ``kwargs`` should at least include parameters below: - :: - - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. - - Note: - The ``args`` and ``kwargs`` may include different parameters varying with different loss function. - """ - main_loss = self.loss_fn(*args, **kwargs) - aux_loss = MOE_CONTEXT.get_loss() - return main_loss + self.aux_weight * aux_loss +import torch.nn as nn +from torch.nn.modules.loss import _Loss + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.legacy.registry import LOSSES + + +@LOSSES.register_module +class MoeCrossEntropyLoss(_Loss): + r"""torch.nn.CrossEntropyLoss added with auxiliary loss. + + Args: + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01. + + The ``args`` and ``kwargs`` should include parameters below: + :: + + weight (Tensor, optional) + size_average (bool, optional) + ignore_index (int, optional) + reduce (bool, optional) + reduction (str, optional) + label_smoothing (float, optional) + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + + def __init__(self, aux_weight: float = 0.01, *args, **kwargs): + super().__init__() + self.loss = nn.CrossEntropyLoss(*args, **kwargs) + self.aux_weight = aux_weight + + def forward(self, *args): + """ + The ``args`` should at least include parameters below: + :: + + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + main_loss = self.loss(*args) + aux_loss = MOE_CONTEXT.get_loss() + return main_loss + self.aux_weight * aux_loss + + +@LOSSES.register_module +class MoeLoss(_Loss): + """A wrapper class for any loss module to add with auxiliary loss. + + Args: + aux_weight (float): Weight of auxiliary loss in total loss. + loss_fn (``Callable``): Loss function. + args (list): Args in loss function. + kwargs (dict): Kwargs in loss function + """ + + def __init__(self, aux_weight: float, loss_fn, *args, **kwargs): + super().__init__() + self.loss_fn = loss_fn(*args, **kwargs) + self.aux_weight = aux_weight + + def forward(self, *args, **kwargs): + """ + The ``args`` and ``kwargs`` should at least include parameters below: + :: + + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + + Note: + The ``args`` and ``kwargs`` may include different parameters varying with different loss function. + """ + main_loss = self.loss_fn(*args, **kwargs) + aux_loss = MOE_CONTEXT.get_loss() + return main_loss + self.aux_weight * aux_loss diff --git a/colossalai/nn/lr_scheduler/__init__.py b/colossalai/nn/lr_scheduler/__init__.py index 34731ee901a0d37a3296e57915df08c38f0b648b..783f12f8c7c40579d6ae2f67c7b24e579c678dcc 100644 --- a/colossalai/nn/lr_scheduler/__init__.py +++ b/colossalai/nn/lr_scheduler/__init__.py @@ -3,10 +3,21 @@ from .linear import LinearWarmupLR from .multistep import MultiStepLR, MultiStepWarmupLR from .onecycle import OneCycleLR from .poly import PolynomialLR, PolynomialWarmupLR -from .torch import LambdaLR, MultiplicativeLR, StepLR, ExponentialLR +from .torch import ExponentialLR, LambdaLR, MultiplicativeLR, StepLR __all__ = [ - 'CosineAnnealingLR', 'CosineAnnealingWarmupLR', 'FlatAnnealingLR', 'FlatAnnealingWarmupLR', 'LinearWarmupLR', - 'MultiStepLR', 'MultiStepWarmupLR', 'OneCycleLR', 'PolynomialLR', 'PolynomialWarmupLR', 'LambdaLR', - 'MultiplicativeLR', 'StepLR', 'ExponentialLR' + "CosineAnnealingLR", + "CosineAnnealingWarmupLR", + "FlatAnnealingLR", + "FlatAnnealingWarmupLR", + "LinearWarmupLR", + "MultiStepLR", + "MultiStepWarmupLR", + "OneCycleLR", + "PolynomialLR", + "PolynomialWarmupLR", + "LambdaLR", + "MultiplicativeLR", + "StepLR", + "ExponentialLR", ] diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py index aab523bef8b30dafc65f60a9475a8b9a70326738..f563825de0d5d36ede336f0c7cdd60535bb7aa0e 100644 --- a/colossalai/nn/lr_scheduler/cosine.py +++ b/colossalai/nn/lr_scheduler/cosine.py @@ -1,10 +1,8 @@ from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR -from colossalai.registry import LR_SCHEDULERS from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler -@LR_SCHEDULERS.register_module class CosineAnnealingLR(_CosineAnnealingLR): r"""Set the learning rate of each parameter group using a cosine annealing schedule, where :math:`\eta_{max}` is set to the initial lr and @@ -48,7 +46,6 @@ class CosineAnnealingLR(_CosineAnnealingLR): super().__init__(optimizer, total_steps, eta_min=eta_min, last_epoch=last_epoch) -@LR_SCHEDULERS.register_module class CosineAnnealingWarmupLR(WarmupScheduler): """Cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be applied. @@ -61,15 +58,13 @@ class CosineAnnealingWarmupLR(WarmupScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0., last_epoch: int = -1): - base_scheduler = _CosineAnnealingLR(optimizer, - total_steps - warmup_steps, - eta_min=eta_min, - last_epoch=last_epoch) - super().__init__(optimizer, warmup_steps, base_scheduler) + def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0.0, last_epoch: int = -1): + base_scheduler = _CosineAnnealingLR( + optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch + ) + super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) -@LR_SCHEDULERS.register_module class FlatAnnealingLR(DelayerScheduler): """Flat and cosine annealing learning rate scheduler. The learning rate will be a fixed value before starting decay. @@ -83,14 +78,13 @@ class FlatAnnealingLR(DelayerScheduler): def __init__(self, optimizer, total_steps: int, pct_start: float = 0.72, last_epoch: int = -1, **kwargs): if not (0.0 <= pct_start <= 1.0): - raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}') + raise ValueError(f"pct_start must >= 0.0 and <= 1.0, got {pct_start}") flat_steps = int(total_steps * pct_start) anneal_steps = total_steps - flat_steps base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps) super().__init__(optimizer, flat_steps, base_scheduler, last_epoch=last_epoch) -@LR_SCHEDULERS.register_module class FlatAnnealingWarmupLR(WarmupDelayerScheduler): """Flat and cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be applied, and then the learning rate will be a fixed value before starting decay. @@ -105,16 +99,18 @@ class FlatAnnealingWarmupLR(WarmupDelayerScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - warmup_steps: int = 0, - pct_start: float = 0.72, - eta_min: int = 0, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + pct_start: float = 0.72, + eta_min: int = 0, + last_epoch: int = -1, + **kwargs, + ): if not (0.0 <= pct_start <= 1.0): - raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}') + raise ValueError(f"pct_start must >= 0.0 and <= 1.0, got {pct_start}") flat_steps = int((total_steps - warmup_steps) * pct_start) anneal_steps = total_steps - warmup_steps - flat_steps base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps, eta_min=eta_min) diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py index a73ff8ae37ace4f82433c55048dac9cf729389fe..ce7f126d6101ac7d6ad751dcae440377f7a55643 100644 --- a/colossalai/nn/lr_scheduler/delayed.py +++ b/colossalai/nn/lr_scheduler/delayed.py @@ -2,7 +2,6 @@ from torch.optim.lr_scheduler import _LRScheduler class _enable_get_lr_call: - def __init__(self, o): self.o = o @@ -28,18 +27,18 @@ class DelayerScheduler(_LRScheduler): def __init__(self, optimizer, delay_epochs, after_scheduler, last_epoch=-1): if delay_epochs < 0: - raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}') + raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") self.delay_epochs = delay_epochs self.after_scheduler = after_scheduler self.finished = False super().__init__(optimizer, last_epoch) def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} - if isinstance(state_dict['after_scheduler'], _LRScheduler): - state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ - state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() - del state_dict['after_scheduler'] + state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ + state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() + del state_dict["after_scheduler"] else: raise NotImplementedError() return state_dict @@ -85,11 +84,11 @@ class WarmupScheduler(_LRScheduler): super().__init__(optimizer, last_epoch) def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} - if isinstance(state_dict['after_scheduler'], _LRScheduler): - state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ - state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() - del state_dict['after_scheduler'] + state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ + state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() + del state_dict["after_scheduler"] else: raise NotImplementedError() return state_dict @@ -130,9 +129,9 @@ class WarmupDelayerScheduler(_LRScheduler): def __init__(self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last_epoch=-1): if delay_epochs < 0: - raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}') + raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") if warmup_epochs < 0: - raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}') + raise ValueError(f"warmup_epochs must >= 0, got {warmup_epochs}") self.warmup_epochs = warmup_epochs self.delay_epochs = delay_epochs self.after_scheduler = after_scheduler @@ -140,11 +139,11 @@ class WarmupDelayerScheduler(_LRScheduler): super().__init__(optimizer, last_epoch) def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} - if isinstance(state_dict['after_scheduler'], _LRScheduler): - state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ - state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() - del state_dict['after_scheduler'] + state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ + state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() + del state_dict["after_scheduler"] else: raise NotImplementedError() return state_dict @@ -155,7 +154,7 @@ class WarmupDelayerScheduler(_LRScheduler): self.after_scheduler.base_lrs = self.base_lrs # reset lr to base_lr for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs): - group['lr'] = base_lr + group["lr"] = base_lr self.finished = True with _enable_get_lr_call(self.after_scheduler): return self.after_scheduler.get_lr() diff --git a/colossalai/nn/lr_scheduler/linear.py b/colossalai/nn/lr_scheduler/linear.py index 556938b8a60c8ae5eaf116e17710f5253f4bce28..1251c261d51ff8f6a2d7e2852a8bdb0b31b64623 100644 --- a/colossalai/nn/lr_scheduler/linear.py +++ b/colossalai/nn/lr_scheduler/linear.py @@ -1,9 +1,6 @@ from torch.optim.lr_scheduler import _LRScheduler -from colossalai.registry import LR_SCHEDULERS - -@LR_SCHEDULERS.register_module class LinearWarmupLR(_LRScheduler): """Linearly warmup learning rate and then linearly decay. @@ -24,5 +21,7 @@ class LinearWarmupLR(_LRScheduler): if self.last_epoch < self.warmup_steps: return [(self.last_epoch + 1) / (self.warmup_steps + 1) * lr for lr in self.base_lrs] else: - return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr - for lr in self.base_lrs] + return [ + (self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr + for lr in self.base_lrs + ] diff --git a/colossalai/nn/lr_scheduler/multistep.py b/colossalai/nn/lr_scheduler/multistep.py index 29531a9e385524913b9b6bfb4c89c7ce40b05792..86589d74662da9b66cad59b3a1df4af3694966a8 100644 --- a/colossalai/nn/lr_scheduler/multistep.py +++ b/colossalai/nn/lr_scheduler/multistep.py @@ -2,11 +2,9 @@ from typing import List from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR -from colossalai.registry import LR_SCHEDULERS from .delayed import WarmupScheduler -@LR_SCHEDULERS.register_module class MultiStepLR(_MultiStepLR): """Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones. Notice that such decay can @@ -22,17 +20,18 @@ class MultiStepLR(_MultiStepLR): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - milestones: List[int] = None, - gamma: float = 0.1, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs, + ): super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch) -@LR_SCHEDULERS.register_module class MultiStepWarmupLR(WarmupScheduler): """Multistep learning rate scheduler with warmup. @@ -47,16 +46,18 @@ class MultiStepWarmupLR(WarmupScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - warmup_steps: int = 0, - milestones: List[int] = None, - gamma: float = 0.1, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs, + ): if len(milestones) == 0: - raise ValueError('milestones cannot be empty') + raise ValueError("milestones cannot be empty") milestones = [v - warmup_steps for v in milestones if v >= warmup_steps] base_scheduler = _MultiStepLR(optimizer, milestones=milestones, gamma=gamma) super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/onecycle.py b/colossalai/nn/lr_scheduler/onecycle.py index 8007fd36008ea01830500a1f3272d0407d2ca3f2..a8e551526dbd78cadf3f9144c43411baad1eee8d 100644 --- a/colossalai/nn/lr_scheduler/onecycle.py +++ b/colossalai/nn/lr_scheduler/onecycle.py @@ -1,9 +1,6 @@ from torch.optim.lr_scheduler import OneCycleLR as _OneCycleLR -from colossalai.registry import LR_SCHEDULERS - -@LR_SCHEDULERS.register_module class OneCycleLR(_OneCycleLR): r"""Sets the learning rate of each parameter group according to the 1cycle learning rate policy. The 1cycle policy anneals the learning @@ -68,27 +65,31 @@ class OneCycleLR(_OneCycleLR): https://arxiv.org/abs/1708.07120 """ - def __init__(self, - optimizer, - total_steps: int, - pct_start=0.3, - anneal_strategy='cos', - cycle_momentum=True, - base_momentum=0.85, - max_momentum=0.95, - div_factor=25.0, - final_div_factor=10000.0, - last_epoch=-1, - **kwargs): - max_lrs = list(map(lambda group: group['lr'], optimizer.param_groups)) - super().__init__(optimizer, - max_lrs, - total_steps=total_steps, - pct_start=pct_start, - anneal_strategy=anneal_strategy, - cycle_momentum=cycle_momentum, - base_momentum=base_momentum, - max_momentum=max_momentum, - div_factor=div_factor, - final_div_factor=final_div_factor, - last_epoch=last_epoch) + def __init__( + self, + optimizer, + total_steps: int, + pct_start=0.3, + anneal_strategy="cos", + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25.0, + final_div_factor=10000.0, + last_epoch=-1, + **kwargs, + ): + max_lrs = list(map(lambda group: group["lr"], optimizer.param_groups)) + super().__init__( + optimizer, + max_lrs, + total_steps=total_steps, + pct_start=pct_start, + anneal_strategy=anneal_strategy, + cycle_momentum=cycle_momentum, + base_momentum=base_momentum, + max_momentum=max_momentum, + div_factor=div_factor, + final_div_factor=final_div_factor, + last_epoch=last_epoch, + ) diff --git a/colossalai/nn/lr_scheduler/poly.py b/colossalai/nn/lr_scheduler/poly.py index 16352bc5175ff022f3111bec8503443cdae772b9..4a3814461ea9dacde9d8027c7ef2bc8cc6d37571 100644 --- a/colossalai/nn/lr_scheduler/poly.py +++ b/colossalai/nn/lr_scheduler/poly.py @@ -1,10 +1,8 @@ from torch.optim.lr_scheduler import _LRScheduler -from colossalai.registry import LR_SCHEDULERS from .delayed import WarmupScheduler -@LR_SCHEDULERS.register_module class PolynomialLR(_LRScheduler): """Polynomial learning rate scheduler. @@ -17,15 +15,11 @@ class PolynomialLR(_LRScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - end_lr: float = 0.0001, - power: float = 1.0, - last_epoch: int = -1, - **kwargs): + def __init__( + self, optimizer, total_steps: int, end_lr: float = 0.0001, power: float = 1.0, last_epoch: int = -1, **kwargs + ): if end_lr < 0: - raise ValueError(f'end_lr must >= 0, got {end_lr}') + raise ValueError(f"end_lr must >= 0, got {end_lr}") self.total_steps = total_steps self.end_lr = end_lr self.power = power @@ -35,12 +29,13 @@ class PolynomialLR(_LRScheduler): return self._get_closed_form_lr() def _get_closed_form_lr(self): - return [(base_lr - self.end_lr) * - ((1 - min(self.last_epoch, self.total_steps) / self.total_steps)**self.power) + self.end_lr - for base_lr in self.base_lrs] + return [ + (base_lr - self.end_lr) * ((1 - min(self.last_epoch, self.total_steps) / self.total_steps) ** self.power) + + self.end_lr + for base_lr in self.base_lrs + ] -@LR_SCHEDULERS.register_module class PolynomialWarmupLR(WarmupScheduler): """Polynomial learning rate scheduler with warmup. @@ -54,13 +49,15 @@ class PolynomialWarmupLR(WarmupScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - warmup_steps: int = 0, - end_lr: float = 0.0001, - power: float = 1.0, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + end_lr: float = 0.0001, + power: float = 1.0, + last_epoch: int = -1, + **kwargs, + ): base_scheduler = PolynomialLR(optimizer, total_steps - warmup_steps, end_lr=end_lr, power=power) super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/torch.py b/colossalai/nn/lr_scheduler/torch.py index 05d2a49c1ea5b0363f82dc61bc3f737c5c2a845a..8846e13c7511dbeed419b7740572814d1e017edc 100644 --- a/colossalai/nn/lr_scheduler/torch.py +++ b/colossalai/nn/lr_scheduler/torch.py @@ -1,12 +1,9 @@ +from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR from torch.optim.lr_scheduler import LambdaLR as _LambdaLR from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR from torch.optim.lr_scheduler import StepLR as _StepLR -from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR - -from colossalai.registry import LR_SCHEDULERS -@LR_SCHEDULERS.register_module class LambdaLR(_LambdaLR): """Sets the learning rate of each parameter group to the initial lr times a given function. When last_epoch=-1, sets initial lr as lr. @@ -24,7 +21,6 @@ class LambdaLR(_LambdaLR): super().__init__(optimizer, lr_lambda, last_epoch=last_epoch) -@LR_SCHEDULERS.register_module class MultiplicativeLR(_MultiplicativeLR): """Multiply the learning rate of each parameter group by the factor given in the specified function. When last_epoch=-1, sets initial lr as lr. @@ -42,7 +38,6 @@ class MultiplicativeLR(_MultiplicativeLR): super().__init__(optimizer, lr_lambda, last_epoch=last_epoch) -@LR_SCHEDULERS.register_module class StepLR(_StepLR): """Decays the learning rate of each parameter group by gamma every step_size epochs. Notice that such decay can happen simultaneously with @@ -61,7 +56,6 @@ class StepLR(_StepLR): super().__init__(optimizer, step_size, gamma=gamma, last_epoch=last_epoch) -@LR_SCHEDULERS.register_module class ExponentialLR(_ExponentialLR): """Decays the learning rate of each parameter group by gamma every epoch. When last_epoch=-1, sets initial lr as lr diff --git a/colossalai/nn/metric/__init__.py b/colossalai/nn/metric/__init__.py deleted file mode 100644 index 00833b6119c161be0bd2855a1b44b333f2b93f66..0000000000000000000000000000000000000000 --- a/colossalai/nn/metric/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -from torch import nn - -from ._utils import calc_acc -from .accuracy_2d import Accuracy2D -from .accuracy_2p5d import Accuracy2p5D -from .accuracy_3d import Accuracy3D -from colossalai.nn.layer.utils import get_tensor_parallel_mode - -_parallel_accuracy = { - '2d': Accuracy2D, - '2.5d': Accuracy2p5D, - '3d': Accuracy3D, -} - - -class Accuracy(nn.Module): - def __init__(self): - super().__init__() - tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel not in _parallel_accuracy: - self.acc = calc_acc - else: - self.acc = _parallel_accuracy[tensor_parallel]() - - def forward(self, *args): - return self.acc(*args) diff --git a/colossalai/nn/metric/_utils.py b/colossalai/nn/metric/_utils.py deleted file mode 100644 index eac591b64c65cd835d4212698b1809c29710b425..0000000000000000000000000000000000000000 --- a/colossalai/nn/metric/_utils.py +++ /dev/null @@ -1,7 +0,0 @@ -import torch - - -def calc_acc(logits, targets): - preds = torch.argmax(logits, dim=-1) - correct = torch.sum(targets == preds) - return correct diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index 09395d08b93e9d80d3e966772e7197f022f7a851..e89e6217d596dd25a4d33c26012202aa12a61b94 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -3,7 +3,8 @@ ## Introduction Welcome to the large-scale deep learning optimization techniques of [Colossal-AI](https://github.com/hpcaitech/ColossalAI), -which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc. +which has been accepted as official tutorials by top conference [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), +[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) ,etc. [Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates @@ -17,7 +18,7 @@ quickly deploy large AI model training and inference, reducing large AI model tr [**Paper**](https://arxiv.org/abs/2110.14883) | [**Documentation**](https://www.colossalai.org/) | [**Forum**](https://github.com/hpcaitech/ColossalAI/discussions) | -[**Slack**](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) +[**Slack**](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack) ## Table of Content diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py index 06072648beba974c2e6f638bdc5f3d1162141056..26f152da20d37fad297d0a17b95246afad3c8f5f 100644 --- a/colossalai/nn/optimizer/__init__.py +++ b/colossalai/nn/optimizer/__init__.py @@ -1,10 +1,9 @@ -from .colossalai_optimizer import ColossalaiOptimizer +from .cpu_adam import CPUAdam from .fused_adam import FusedAdam from .fused_lamb import FusedLAMB from .fused_sgd import FusedSGD +from .hybrid_adam import HybridAdam from .lamb import Lamb from .lars import Lars -from .cpu_adam import CPUAdam -from .hybrid_adam import HybridAdam -__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam'] +__all__ = ["FusedLAMB", "FusedAdam", "FusedSGD", "Lamb", "Lars", "CPUAdam", "HybridAdam"] diff --git a/colossalai/nn/optimizer/colossalai_optimizer.py b/colossalai/nn/optimizer/colossalai_optimizer.py deleted file mode 100644 index 34f5a9541975aa3029032d0346947c3524cd5a69..0000000000000000000000000000000000000000 --- a/colossalai/nn/optimizer/colossalai_optimizer.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import torch.nn as nn -from torch import Tensor -from torch.optim import Optimizer -from colossalai.utils import clip_grad_norm_fp32 - - -class ColossalaiOptimizer(Optimizer): - - def __init__(self, optim: Optimizer): - self.optim = optim - - @property - def param_groups(self): - return self.optim.param_groups - - @property - def defaults(self): - return self.optim.defaults - - def add_param_group(self, *args, **kwargs): - return self.optim.add_param_group(*args, **kwargs) - - def step(self, *args, **kwargs): - return self.optim.step(*args, **kwargs) - - def zero_grad(self, *args, **kwargs): - self.optim.zero_grad(*args, **kwargs) - - def load_state_dict(self, *args, **kwargs): - self.optim.load_state_dict(*args, **kwargs) - - def state_dict(self): - return self.optim.state_dict() - - def backward(self, loss: Tensor): - loss.backward() - - def backward_by_grad(self, tensor: Tensor, grad: Tensor): - torch.autograd.backward(tensors=tensor, grad_tensors=grad) - - def clip_grad_norm(self, model: nn.Module, max_norm: float): - if max_norm > 0.0: - clip_grad_norm_fp32(model.parameters(), max_norm) diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 54036973e1e31441f594f1a774693b5f93b01dd3..f35dc0200237a93f860fc77c0025291a613b3f5a 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -4,16 +4,14 @@ from typing import Optional import torch from colossalai.kernel.op_builder import CPUAdamBuilder -from colossalai.registry import OPTIMIZERS from .nvme_optimizer import NVMeOptimizer -@OPTIMIZERS.register_module class CPUAdam(NVMeOptimizer): """Implements Adam algorithm. - Supports parameters updating on both GPU and CPU, depanding on the device of paramters. + Supports parameters updating on both GPU and CPU, depending on the device of parameters. But the parameters and gradients should on the same device: * Parameters on CPU and gradients on CPU is allowed. * Parameters on GPU and gradients on GPU is allowed. @@ -21,7 +19,7 @@ class CPUAdam(NVMeOptimizer): `CPUAdam` requires CUDA extensions which can be built during installation or runtime. - This version of CPU Adam accelates parameters updating on CPU with SIMD. + This version of CPU Adam accelerates parameters updating on CPU with SIMD. Support of AVX2 or AVX512 is required. The GPU part is implemented in an naive way. @@ -63,38 +61,40 @@ class CPUAdam(NVMeOptimizer): # Param weight, grad, momentum and variance num_fp32_shards_per_param = 4 - def __init__(self, - model_params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - adamw_mode=True, - nvme_offload_fraction: float = 0.0, - nvme_offload_dir: Optional[str] = None): - + def __init__( + self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + adamw_mode=True, + nvme_offload_fraction: float = 0.0, + nvme_offload_dir: Optional[str] = None, + ): default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode cpu_adam = CPUAdamBuilder().load() self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) - def torch_adam_update(self, - data, - grad, - exp_avg, - exp_avg_sq, - lr, - beta1, - beta2, - eps, - weight_decay, - bias_correction1, - bias_correction2, - use_adamw=False): - # FIXME(ver217): remove the below line when replace torch adam with fused adam - grad = grad.float() + def torch_adam_update( + self, + data, + grad, + exp_avg, + exp_avg_sq, + lr, + beta1, + beta2, + eps, + weight_decay, + bias_correction1, + bias_correction2, + use_adamw=False, + ): + grad = grad.to(data.dtype) if weight_decay != 0: if use_adamw: @@ -120,10 +120,9 @@ class CPUAdam(NVMeOptimizer): with torch.enable_grad(): loss = closure() - self._pre_step('exp_avg', 'exp_avg_sq') + self._pre_step("exp_avg", "exp_avg_sq") for _, group in enumerate(self.param_groups): - for _, p in enumerate(group['params']): - + for _, p in enumerate(group["params"]): if p.grad is None: continue @@ -131,38 +130,81 @@ class CPUAdam(NVMeOptimizer): target_device = p.device if len(state) == 0: - state['step'] = 0 + state["step"] = 0 + # FIXME(ver217): CPU adam kernel only supports fp32 states now + assert p.dtype is torch.float, "CPUAdam only support fp32 parameters" # gradient momentums - state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state["exp_avg"] = torch.zeros_like(p, device=target_device) # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state["exp_avg_sq"] = torch.zeros_like(p, device=target_device) self._post_state_init(p) - state['step'] += 1 - beta1, beta2 = group['betas'] + state["step"] += 1 + beta1, beta2 = group["betas"] - if target_device.type == 'cpu': + if target_device.type == "cpu": assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size" - assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" - assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" - self._pre_update(p, 'exp_avg', 'exp_avg_sq') - self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], - group['bias_correction'], p.data, p.grad.data, state['exp_avg'], - state['exp_avg_sq'], div_scale) - self._post_update(p, 'exp_avg', 'exp_avg_sq') - elif target_device.type == 'cuda': + assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" + assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" + self._pre_update(p, "exp_avg", "exp_avg_sq") + if p.grad.dtype is torch.bfloat16: + # cpu adam kernel does not support bf16 now + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + self.torch_adam_update( + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + bias_correction1, + bias_correction2, + self.adamw_mode, + ) + else: + self.cpu_adam_op.step( + state["step"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + group["bias_correction"], + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + div_scale, + ) + self._post_update(p, "exp_avg", "exp_avg_sq") + elif target_device.type == "cuda": assert div_scale == -1, "div_scale should remain default" - assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" - assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" + assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda" + assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda" - bias_correction1 = 1 - beta1**state['step'] - bias_correction2 = 1 - beta2**state['step'] + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] # adam on cuda - self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], - beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, - bias_correction2, self.adamw_mode) + self.torch_adam_update( + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + bias_correction1, + bias_correction2, + self.adamw_mode, + ) else: raise RuntimeError self._post_step() diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index 987af8a968b7d883178b145c121edd2a0fc1033a..fcdd3257d7008f2619c1a0015f161f347816136e 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -1,18 +1,16 @@ # modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py -''' +""" Copyright 2020 The Microsoft DeepSpeed Team Copyright NVIDIA/apex This file is adapted from fused adam in NVIDIA/apex, commit a109f85 Licensed under the MIT License. -''' +""" import torch -from colossalai.registry import OPTIMIZERS from colossalai.utils import multi_tensor_applier -@OPTIMIZERS.register_module class FusedAdam(torch.optim.Optimizer): """Implements Adam algorithm. @@ -53,37 +51,39 @@ class FusedAdam(torch.optim.Optimizer): https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - adamw_mode=True, - weight_decay=0., - amsgrad=False, - set_grad_none=True): - + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + adamw_mode=True, + weight_decay=0.0, + amsgrad=False, + set_grad_none=True, + ): if amsgrad: - raise RuntimeError('FusedAdam does not support the AMSGrad variant.') + raise RuntimeError("FusedAdam does not support the AMSGrad variant.") defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) super(FusedAdam, self).__init__(params, defaults) self.adamw_mode = 1 if adamw_mode else 0 self.set_grad_none = set_grad_none if multi_tensor_applier.available: from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self.multi_tensor_adam = fused_optim.multi_tensor_adam else: - raise RuntimeError('FusedAdam requires cuda extensions') + raise RuntimeError("FusedAdam requires cuda extensions") def zero_grad(self, set_to_none=False): if set_to_none: for group in self.param_groups: - for p in group['params']: + for p in group["params"]: p.grad = None else: super(FusedAdam, self).zero_grad() @@ -99,51 +99,63 @@ class FusedAdam(torch.optim.Optimizer): """ if any(p is not None for p in [grads, output_params, scale, grad_norms]): raise RuntimeError( - 'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.' + "FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments." ) loss = None if closure is not None: loss = closure() for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] # assume same step across group now to simplify things # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 + if "step" in group: + group["step"] += 1 else: - group['step'] = 1 + group["step"] = 1 # create lists for multi-tensor apply g_l, p_l, m_l, v_l = [], [], [], [] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.grad.data.is_sparse: raise RuntimeError( - 'FusedAdam does not support sparse gradients, please consider SparseAdam instead') + "FusedAdam does not support sparse gradients, please consider SparseAdam instead" + ) state = self.state[p] # State initialization if len(state) == 0: # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) - if p.dtype not in [torch.float16, torch.float32]: - raise RuntimeError('FusedAdam only support fp16 and fp32.') + if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]: + raise RuntimeError("FusedAdam only support fp16, fp32 and bf16.") g_l.append(p.grad.data) p_l.append(p.data) - m_l.append(state['exp_avg']) - v_l.append(state['exp_avg_sq']) - - multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], - beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction, - group['weight_decay'], div_scale) + m_l.append(state["exp_avg"]) + v_l.append(state["exp_avg_sq"]) + + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_l, p_l, m_l, v_l], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adamw_mode, + bias_correction, + group["weight_decay"], + div_scale, + ) return loss diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py index 72520064e98ba50954189490e7cbdb2f00685340..3e1d5a7ba539286fa39b9df9bfb4f8bfdfd11684 100644 --- a/colossalai/nn/optimizer/fused_lamb.py +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -1,11 +1,9 @@ # modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_lamb.py import torch -from colossalai.registry import OPTIMIZERS from colossalai.utils import multi_tensor_applier -@OPTIMIZERS.register_module class FusedLAMB(torch.optim.Optimizer): """Implements LAMB algorithm. @@ -51,41 +49,46 @@ class FusedLAMB(torch.optim.Optimizer): https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-6, - weight_decay=0.01, - amsgrad=False, - adam_w_mode=True, - grad_averaging=True, - set_grad_none=True, - max_grad_norm=1.0, - use_nvlamb=False): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + amsgrad=False, + adam_w_mode=True, + grad_averaging=True, + set_grad_none=True, + max_grad_norm=1.0, + use_nvlamb=False, + ): if amsgrad: - raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') - defaults = dict(lr=lr, - bias_correction=bias_correction, - betas=betas, - eps=eps, - weight_decay=weight_decay, - grad_averaging=grad_averaging, - max_grad_norm=max_grad_norm) + raise RuntimeError("FusedLAMB does not support the AMSGrad variant.") + defaults = dict( + lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + max_grad_norm=max_grad_norm, + ) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], - dtype=torch.int, - device=self.param_groups[0]["params"][0].device) + self._dummy_overflow_buf = torch.tensor( + [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device + ) self.multi_tensor_lamb = fused_optim.multi_tensor_lamb else: - raise RuntimeError('FusedLAMB requires cuda extensions') + raise RuntimeError("FusedLAMB requires cuda extensions") self.adam_w_mode = 1 if adam_w_mode else 0 self.set_grad_none = set_grad_none @@ -94,7 +97,7 @@ class FusedLAMB(torch.optim.Optimizer): def zero_grad(self): if self.set_grad_none: for group in self.param_groups: - for p in group['params']: + for p in group["params"]: p.grad = None else: super(FusedLAMB, self).zero_grad() @@ -113,7 +116,7 @@ class FusedLAMB(torch.optim.Optimizer): # create separate grad lists for fp32 and fp16 params g_all_32, g_all_16 = [], [] for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.dtype == torch.float32: @@ -121,7 +124,7 @@ class FusedLAMB(torch.optim.Optimizer): elif p.dtype == torch.float16: g_all_16.append(p.grad.data) else: - raise RuntimeError('FusedLAMB only support fp16 and fp32.') + raise RuntimeError("FusedLAMB only support fp16 and fp32.") device = self.param_groups[0]["params"][0].device g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device) @@ -132,63 +135,91 @@ class FusedLAMB(torch.optim.Optimizer): g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False)[0] # blend two grad norms to get global grad norm - global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, - [[g_norm_32, g_norm_16]], False)[0] - max_grad_norm = self.defaults['max_grad_norm'] + global_grad_norm = multi_tensor_applier( + self.multi_tensor_l2norm, self._dummy_overflow_buf, [[g_norm_32, g_norm_16]], False + )[0] + max_grad_norm = self.defaults["max_grad_norm"] for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] - grad_averaging = 1 if group['grad_averaging'] else 0 + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] + grad_averaging = 1 if group["grad_averaging"] else 0 # assume same step across group now to simplify things # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 + if "step" in group: + group["step"] += 1 else: - group['step'] = 1 + group["step"] = 1 # create lists for multi-tensor apply g_16, p_16, m_16, v_16 = [], [], [], [] g_32, p_32, m_32, v_32 = [], [], [], [] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.grad.data.is_sparse: raise RuntimeError( - 'FusedLAMB does not support sparse gradients, please consider SparseAdam instead') + "FusedLAMB does not support sparse gradients, please consider SparseAdam instead" + ) state = self.state[p] # State initialization if len(state) == 0: # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of gradient values - state['exp_avg_sq'] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) if p.dtype == torch.float16: g_16.append(p.grad.data) p_16.append(p.data) - m_16.append(state['exp_avg']) - v_16.append(state['exp_avg_sq']) + m_16.append(state["exp_avg"]) + v_16.append(state["exp_avg_sq"]) elif p.dtype == torch.float32: g_32.append(p.grad.data) p_32.append(p.data) - m_32.append(state['exp_avg']) - v_32.append(state['exp_avg_sq']) + m_32.append(state["exp_avg"]) + v_32.append(state["exp_avg_sq"]) else: - raise RuntimeError('FusedLAMB only support fp16 and fp32.') - - if (len(g_16) > 0): - multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16], - group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction, - group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm, - max_grad_norm, self.use_nvlamb) - if (len(g_32) > 0): - multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32], - group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction, - group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm, - max_grad_norm, self.use_nvlamb) + raise RuntimeError("FusedLAMB only support fp16 and fp32.") + + if len(g_16) > 0: + multi_tensor_applier( + self.multi_tensor_lamb, + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + bias_correction, + group["weight_decay"], + grad_averaging, + self.adam_w_mode, + global_grad_norm, + max_grad_norm, + self.use_nvlamb, + ) + if len(g_32) > 0: + multi_tensor_applier( + self.multi_tensor_lamb, + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + bias_correction, + group["weight_decay"], + grad_averaging, + self.adam_w_mode, + global_grad_norm, + max_grad_norm, + self.use_nvlamb, + ) return loss diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py index 468713b223c15015ce5e53aa84fda6d7f8afb31c..95a6354208a894e4b44d29f4fa48717d0209308b 100644 --- a/colossalai/nn/optimizer/fused_sgd.py +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -2,11 +2,9 @@ import torch from torch.optim.optimizer import Optimizer, required -from colossalai.registry import OPTIMIZERS from colossalai.utils import multi_tensor_applier -@OPTIMIZERS.register_module class FusedSGD(Optimizer): r"""Implements stochastic gradient descent (optionally with momentum). @@ -56,14 +54,9 @@ class FusedSGD(Optimizer): The Nesterov version is analogously modified. """ - def __init__(self, - params, - lr=required, - momentum=0, - dampening=0, - weight_decay=0, - nesterov=False, - wd_after_momentum=False): + def __init__( + self, params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False, wd_after_momentum=False + ): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: @@ -80,20 +73,21 @@ class FusedSGD(Optimizer): if multi_tensor_applier.available: from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], - dtype=torch.int, - device=self.param_groups[0]["params"][0].device) + self._dummy_overflow_buf = torch.tensor( + [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device + ) self.multi_tensor_sgd = fused_optim.multi_tensor_sgd else: - raise RuntimeError('FusedSGD requires cuda extensions') + raise RuntimeError("FusedSGD requires cuda extensions") def __setstate__(self, state): super(FusedSGD, self).__setstate__(state) for group in self.param_groups: - group.setdefault('nesterov', False) + group.setdefault("nesterov", False) def get_momentums(self, params): momentums = [] @@ -103,13 +97,13 @@ class FusedSGD(Optimizer): # torch.optim.SGD initializes momentum in the main loop, we have # to do it here, and track whether or not we've done so, so that # momentum application can be skipped in the main kernel. - if 'momentum_buffer' not in param_state: + if "momentum_buffer" not in param_state: first_run = True - buf = param_state['momentum_buffer'] = torch.zeros_like(p) + buf = param_state["momentum_buffer"] = torch.zeros_like(p) momentums.append(buf) else: first_run = False - momentums.append(param_state['momentum_buffer']) + momentums.append(param_state["momentum_buffer"]) return momentums, first_run def step(self, closure=None): @@ -124,10 +118,10 @@ class FusedSGD(Optimizer): loss = closure() for group in self.param_groups: - weight_decay = group['weight_decay'] - momentum = group['momentum'] - dampening = group['dampening'] - nesterov = group['nesterov'] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + dampening = group["dampening"] + nesterov = group["nesterov"] # For each group, there are 3 possible combinations we need to consider: # grad_type, param_to_update_type, momentum_type @@ -135,15 +129,26 @@ class FusedSGD(Optimizer): # 2. fp32, fp32, fp32 # 3. fp16, fp32, fp32 g_l, p_l = [], [] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.grad.data.is_sparse: - raise RuntimeError('FusedSGD does not support sparse gradients') + raise RuntimeError("FusedSGD does not support sparse gradients") g_l.append(p.grad) p_l.append(p) m_l, first_run = self.get_momentums(p_l) - multi_tensor_applier(self.multi_tensor_sgd, self._dummy_overflow_buf, [g_l, p_l, m_l], weight_decay, - momentum, dampening, group['lr'], nesterov, first_run, self.wd_after_momentum, 1.0) + multi_tensor_applier( + self.multi_tensor_sgd, + self._dummy_overflow_buf, + [g_l, p_l, m_l], + weight_decay, + momentum, + dampening, + group["lr"], + nesterov, + first_run, + self.wd_after_momentum, + 1.0, + ) return loss diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 1d0fb92de499ba832af472ae5ad190478bc2d302..32fc6136c4e6b2a2c9d3191e3155113534a15fed 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -2,30 +2,28 @@ from typing import Any, Optional import torch -from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder -from colossalai.registry import OPTIMIZERS +from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.utils import multi_tensor_applier -from .nvme_optimizer import NVMeOptimizer +from .cpu_adam import CPUAdam -@OPTIMIZERS.register_module -class HybridAdam(NVMeOptimizer): +class HybridAdam(CPUAdam): """Implements Adam algorithm. - Supports parameters updating on both GPU and CPU, depanding on the device of paramters. + Supports parameters updating on both GPU and CPU, depending on the device of parameters. But the parameters and gradients should on the same device: * Parameters on CPU and gradients on CPU is allowed. * Parameters on GPU and gradients on GPU is allowed. * Parameters on GPU and gradients on CPU is **not** allowed. - `HybriadAdam` requires CUDA extensions which can be built during installation or runtime. + `HybridAdam` requires CUDA extensions which can be built during installation or runtime. This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam. * For parameters updating on CPU, it uses CPUAdam. * For parameters updating on GPU, it uses FusedAdam. - * Hybird precision calculation of fp16 and fp32 is supported, eg fp32 parameters and fp16 gradients. + * Hybrid precision calculation of fp16 and fp32 is supported, eg fp32 parameters and fp16 gradients. :class:`colossalai.nn.optimizer.HybridAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, or ``torch.optim.Adam`` with ``adamw_mode=False`` @@ -62,27 +60,31 @@ class HybridAdam(NVMeOptimizer): # Param weight, grad, momentum and variance num_fp32_shards_per_param = 4 - def __init__(self, - model_params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - adamw_mode=True, - nvme_offload_fraction: float = 0.0, - nvme_offload_dir: Optional[str] = None, - **defaults: Any): - - default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) - super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) - self.adamw_mode = adamw_mode - - # build during runtime if not found - cpu_optim = CPUAdamBuilder().load() + def __init__( + self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + adamw_mode=True, + nvme_offload_fraction: float = 0.0, + nvme_offload_dir: Optional[str] = None, + **defaults: Any, + ): + super().__init__( + model_params, + lr, + bias_correction, + betas, + eps, + weight_decay, + adamw_mode, + nvme_offload_fraction, + nvme_offload_dir, + ) fused_optim = FusedOptimBuilder().load() - self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) - self.gpu_adam_op = fused_optim.multi_tensor_adam self._dummy_overflow_buf = torch.cuda.IntTensor([0]) @@ -93,12 +95,11 @@ class HybridAdam(NVMeOptimizer): with torch.enable_grad(): loss = closure() - self._pre_step('exp_avg', 'exp_avg_sq') + self._pre_step("exp_avg", "exp_avg_sq") for _, group in enumerate(self.param_groups): g_l, p_l, m_l, v_l = [], [], [], [] group_step = 0 - for _, p in enumerate(group['params']): - + for _, p in enumerate(group["params"]): if p.grad is None: continue @@ -106,44 +107,87 @@ class HybridAdam(NVMeOptimizer): target_device = p.device if len(state) == 0: - state['step'] = 0 + state["step"] = 0 + # FIXME(ver217): CPU adam kernel only supports fp32 states now + assert p.dtype is torch.float, "HybridAdam only support fp32 parameters" # gradient momentums - state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state["exp_avg"] = torch.zeros_like(p, device=target_device) # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state["exp_avg_sq"] = torch.zeros_like(p, device=target_device) self._post_state_init(p) - state['step'] += 1 - group_step = state['step'] - beta1, beta2 = group['betas'] - - if target_device.type == 'cpu': - assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" - assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" - self._pre_update(p, 'exp_avg', 'exp_avg_sq') - self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], - group['bias_correction'], p.data, p.grad.data, state['exp_avg'], - state['exp_avg_sq'], div_scale) - self._post_update(p, 'exp_avg', 'exp_avg_sq') - - elif target_device.type == 'cuda': - assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" - assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" - - # record the state by gruop and update at once + state["step"] += 1 + group_step = state["step"] + beta1, beta2 = group["betas"] + + if target_device.type == "cpu": + assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" + assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" + self._pre_update(p, "exp_avg", "exp_avg_sq") + if p.grad.dtype is torch.bfloat16: + # cpu adam kernel does not support bf16 now + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + self.torch_adam_update( + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + bias_correction1, + bias_correction2, + self.adamw_mode, + ) + else: + self.cpu_adam_op.step( + state["step"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + group["bias_correction"], + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + div_scale, + ) + self._post_update(p, "exp_avg", "exp_avg_sq") + + elif target_device.type == "cuda": + assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda" + assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda" + + # record the state by group and update at once g_l.append(p.grad.data) p_l.append(p.data) - m_l.append(state['exp_avg']) - v_l.append(state['exp_avg_sq']) + m_l.append(state["exp_avg"]) + v_l.append(state["exp_avg_sq"]) else: raise RuntimeError if len(g_l) > 0: adamw_mode = 1 if self.adamw_mode else 0 - bias_correction = 1 if group['bias_correction'] else 0 - multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], - group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode, - bias_correction, group['weight_decay'], div_scale) + bias_correction = 1 if group["bias_correction"] else 0 + multi_tensor_applier( + self.gpu_adam_op, + self._dummy_overflow_buf, + [g_l, p_l, m_l, v_l], + group["lr"], + group["betas"][0], + group["betas"][1], + group["eps"], + group_step, + adamw_mode, + bias_correction, + group["weight_decay"], + div_scale, + ) self._post_step() return loss diff --git a/colossalai/nn/optimizer/lamb.py b/colossalai/nn/optimizer/lamb.py index 7ac2109572a443e004c58f91bfeb03f85dbbdc33..0d742487f4734899fc49954013ccac91c39d93ec 100644 --- a/colossalai/nn/optimizer/lamb.py +++ b/colossalai/nn/optimizer/lamb.py @@ -5,10 +5,7 @@ Adapted from the pytorch-lamb library at https://github.com/cybertronai/pytorch- import torch from torch.optim import Optimizer -from colossalai.registry import OPTIMIZERS - -@OPTIMIZERS.register_module class Lamb(Optimizer): r"""Implements Lamb algorithm. It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. @@ -54,27 +51,27 @@ class Lamb(Optimizer): loss = closure() for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad.data if grad.is_sparse: - raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') + raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instead.") state = self.state[p] # State initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] - state['step'] += 1 + state["step"] += 1 # Decay the first and second moment running average coefficient # m_t @@ -87,22 +84,22 @@ class Lamb(Optimizer): # bias_correction2 = 1 - beta2 ** state['step'] # Apply bias to lr to avoid broadcast. # * math.sqrt(bias_correction2) / bias_correction1 - step_size = group['lr'] + step_size = group["lr"] weight_norm = p.data.pow(2).sum().sqrt() - adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) - if group['weight_decay'] != 0: - adam_step.add_(p.data, alpha=group['weight_decay']) + adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"]) + if group["weight_decay"] != 0: + adam_step.add_(p.data, alpha=group["weight_decay"]) adam_norm = adam_step.pow(2).sum().sqrt() if weight_norm == 0 or adam_norm == 0: trust_ratio = 1 else: trust_ratio = weight_norm / adam_norm - state['weight_norm'] = weight_norm - state['adam_norm'] = adam_norm - state['trust_ratio'] = trust_ratio + state["weight_norm"] = weight_norm + state["adam_norm"] = adam_norm + state["trust_ratio"] = trust_ratio if self.adam: trust_ratio = 1 diff --git a/colossalai/nn/optimizer/lars.py b/colossalai/nn/optimizer/lars.py index 212f66671a0db9f580d99758823fdf78e3e54106..b117c00846d130b168614c99572b4c0bf30e844b 100644 --- a/colossalai/nn/optimizer/lars.py +++ b/colossalai/nn/optimizer/lars.py @@ -5,10 +5,7 @@ from typing import Iterable import torch from torch.optim import Optimizer -from colossalai.registry import OPTIMIZERS - -@OPTIMIZERS.register_module class Lars(Optimizer): r"""Implements the LARS optimizer from `"Large batch training of convolutional networks" `_. @@ -23,27 +20,19 @@ class Lars(Optimizer): """ def __init__( - self, - params: Iterable[torch.nn.Parameter], - lr=1e-3, - momentum=0, - eeta=1e-3, - weight_decay=0, - epsilon=0.0 + self, params: Iterable[torch.nn.Parameter], lr=1e-3, momentum=0, eeta=1e-3, weight_decay=0, epsilon=0.0 ) -> None: if not isinstance(lr, float) or lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: raise ValueError("Invalid momentum value: {}".format(momentum)) if weight_decay < 0.0: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if eeta <= 0 or eeta > 1: raise ValueError("Invalid eeta value: {}".format(eeta)) if epsilon < 0: raise ValueError("Invalid epsilon value: {}".format(epsilon)) - defaults = dict(lr=lr, momentum=momentum, - weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True) + defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True) super().__init__(params, defaults) @@ -61,14 +50,14 @@ class Lars(Optimizer): loss = closure() for group in self.param_groups: - weight_decay = group['weight_decay'] - momentum = group['momentum'] - eeta = group['eeta'] - lr = group['lr'] - lars = group['lars'] - eps = group['epsilon'] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + eeta = group["eeta"] + lr = group["lr"] + lars = group["lars"] + eps = group["epsilon"] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue decayed_grad = p.grad @@ -79,7 +68,7 @@ class Lars(Optimizer): trust_ratio = torch.where( w_norm > 0 and g_norm > 0, eeta * w_norm / (g_norm + weight_decay * w_norm + eps), - torch.ones_like(w_norm) + torch.ones_like(w_norm), ) trust_ratio.clamp_(0.0, 50) scaled_lr *= trust_ratio.item() @@ -89,11 +78,10 @@ class Lars(Optimizer): if momentum != 0: param_state = self.state[p] - if 'momentum_buffer' not in param_state: - buf = param_state['momentum_buffer'] = torch.clone( - decayed_grad).detach() + if "momentum_buffer" not in param_state: + buf = param_state["momentum_buffer"] = torch.clone(decayed_grad).detach() else: - buf = param_state['momentum_buffer'] + buf = param_state["momentum_buffer"] buf.mul_(momentum).add_(decayed_grad) decayed_grad = buf diff --git a/colossalai/nn/optimizer/nvme_optimizer.py b/colossalai/nn/optimizer/nvme_optimizer.py index 53e4a46c9741dc8cb5fe3fc2335789beada3e3c8..fd02bfb683e1120e53288d1e17865600cd88e869 100644 --- a/colossalai/nn/optimizer/nvme_optimizer.py +++ b/colossalai/nn/optimizer/nvme_optimizer.py @@ -19,13 +19,11 @@ class NVMeOptimizer(torch.optim.Optimizer): Raises: ImportError: Raise if ``tensornvme`` is not installed. - """ + """ - def __init__(self, - params, - defaults: dict, - nvme_offload_fraction: float = 0.0, - offload_dir: Optional[str] = None) -> None: + def __init__( + self, params, defaults: dict, nvme_offload_fraction: float = 0.0, offload_dir: Optional[str] = None + ) -> None: assert 0.0 <= nvme_offload_fraction <= 1.0 super().__init__(params, defaults) self.nvme_offload_fraction = float(nvme_offload_fraction) @@ -34,16 +32,16 @@ class NVMeOptimizer(torch.optim.Optimizer): from tensornvme import DiskOffloader from tensornvme._C import get_backends except ModuleNotFoundError: - raise ModuleNotFoundError('Please install tensornvme to use NVMeOptimizer') + raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer") self.offload_dir = offload_dir or tempfile.mkdtemp() - backend = 'uring' if 'uring' in get_backends() else 'aio' + backend = "uring" if "uring" in get_backends() else "aio" self.offloader = DiskOffloader(self.offload_dir, 8, backend=backend) else: self.offload_dir = None self.offloader = None self.is_on_nvme: Dict[Parameter, bool] = {} self.offloaded_numel: int = 0 - # As param may be not materialized here, these attributes are initalized when the first step + # As param may be not materialized here, these attributes are initialized when the first step self.total_numel: Optional[int] = None self.can_offload_numel: Optional[int] = None @@ -53,13 +51,17 @@ class NVMeOptimizer(torch.optim.Optimizer): def _get_numel(self) -> int: numel = 0 for group in self.param_groups: - for p in group['params']: + for p in group["params"]: numel += p.storage().size() return numel def _post_state_init(self, param: Parameter) -> None: numel = param.storage().size() - if self.offloader is not None and param.device.type == 'cpu' and numel + self.offloaded_numel <= self.can_offload_numel: + if ( + self.offloader is not None + and param.device.type == "cpu" + and numel + self.offloaded_numel <= self.can_offload_numel + ): self.is_on_nvme[param] = True self.offloaded_numel += numel else: @@ -70,11 +72,11 @@ class NVMeOptimizer(torch.optim.Optimizer): return assert len(self.prefetch_params) == 0 and len(self.param_to_prefetch_idx) == 0 for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if len(self.state[p]) > 0 and self.is_on_nvme[p]: - assert p.device.type == 'cpu' + assert p.device.type == "cpu" self.param_to_prefetch_idx[p] = len(self.prefetch_params) self.prefetch_params.append(p) @@ -156,7 +158,7 @@ class NVMeOptimizer(torch.optim.Optimizer): super().load_state_dict(state_dict) def __del__(self) -> None: - if getattr(self, 'offloader', None) is not None: + if getattr(self, "offloader", None) is not None: del self.offloader if os.path.exists(self.offload_dir): try: diff --git a/colossalai/nn/parallel/__init__.py b/colossalai/nn/parallel/__init__.py deleted file mode 100644 index 17e010f478c92f5c445c32ef2c512ef3007f95b2..0000000000000000000000000000000000000000 --- a/colossalai/nn/parallel/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .data_parallel import ColoDDP - -__all__ = [ - 'ColoDDP', -] diff --git a/colossalai/nn/parallel/layers/__init__.py b/colossalai/nn/parallel/layers/__init__.py deleted file mode 100644 index 29b8353e63c5930950d28d58b0c122c76486c3e1..0000000000000000000000000000000000000000 --- a/colossalai/nn/parallel/layers/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from .colo_module import ColoModule -from .linear import ColoLinear -from .embedding import ColoEmbedding -from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module - -from .cache_embedding import CachedEmbeddingBag, ParallelCachedEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \ - ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewiseSpiltCache - -__all__ = [ - 'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', - 'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr', - 'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', - 'ParallelCachedEmbeddingBagTablewiseSpiltCache' -] diff --git a/colossalai/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/nn/parallel/layers/cache_embedding/__init__.py deleted file mode 100644 index 5bbc931a79dceeffcf827fc3fd3b18ca8bf87dd3..0000000000000000000000000000000000000000 --- a/colossalai/nn/parallel/layers/cache_embedding/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .cache_mgr import CachedParamMgr, EvictionStrategy -from .copyer import LimitBuffIndexCopyer -from .cached_embedding import CachedEmbeddingBag -from .parallel_cached_embedding import ParallelCachedEmbeddingBag -from .embedding_config import TablewiseEmbeddingBagConfig -from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise -from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache - -__all__ = [ - 'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy', - 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', - 'ParallelCachedEmbeddingBagTablewiseSpiltCache' -] diff --git a/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py deleted file mode 100644 index a0c45d8e80c028637a5c964b41062846b7020c7b..0000000000000000000000000000000000000000 --- a/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py +++ /dev/null @@ -1,157 +0,0 @@ -import torch -import torch.nn.functional as F -from typing import List, Optional, Iterator, Tuple, Union - -from .base_embedding import BaseEmbeddingBag -from .cache_mgr import CachedParamMgr, EvictionStrategy -from torch.nn.parameter import Parameter - - -class CachedEmbeddingBag(BaseEmbeddingBag): - """CachedEmbeddingBag - - Cached Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space. - It can leverage the id's frequency statistics of the target dataset, by passing a frequency list to param `ids_freq_mapping`. - You can also apply a navie LFU cache eviction strategy by setting `evict_strategy` as EvictionStrategy.LFU. - - Args: - num_embeddings (int): size of the dictionary of embeddings - embedding_dim (int): the size of each embedding vector - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”. For a newly constructed EmbeddingBag, the embedding vector at padding_idx will default to all zeros, but can be updated to another value to be used as the padding vector. Note that the embedding vector at padding_idx is excluded from the reduction. - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm - norm_type (str, optional): The p of the p-norm to compute for the max_norm option. Defaults to 2.. - scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. Note: this option is not supported when mode="max". Defaults to False. - sparse (bool, optional): if True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. Note: this option is not supported when mode="max".. Defaults to False. - _weight (torch.Tensor, optional): an embedding weight tensor. Concate multiple tables in a embedding bag as a single one. Defaults to None. - mode (str, optional): "sum", "mean" or "max". Specifies the way to reduce the bag. "sum" computes the weighted sum, taking per_sample_weights into consideration. "mean" computes the average of the values in the bag, "max" computes the max value over each bag. Default: "mean". Defaults to 'mean'. - include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format.. Defaults to False. - dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32. - device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu. - cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row - ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occures in dataset. Defaults to None. - warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7. - buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0. - pin_weight (bool, optional): pin the cpu weight. Defaults to False. - evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - max_norm: float = None, - norm_type: float = 2., - scale_grad_by_freq: bool = False, - sparse: bool = False, - _weight: Optional[torch.Tensor] = None, - mode: str = 'mean', - include_last_offset: bool = False, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - cache_ratio: float = 0.01, - ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None, - warmup_ratio: float = 0.7, - buffer_size: int = 0, - pin_weight: bool = False, - evict_strategy: EvictionStrategy = EvictionStrategy.LFU): - super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, - scale_grad_by_freq, sparse, mode, include_last_offset) - - assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0" - self.evict_strategy = evict_strategy - if _weight is None: - _weight = self._weight_alloc(dtype, device) - cuda_row_num = int(num_embeddings * cache_ratio) - # configure weight & cache - self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight) - self.cache_op = True - - def set_cache_mgr_async_copy(self, flag): - self.cache_weight_mgr._async_copy = flag - - def _weight_alloc(self, dtype, device): - weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device) - with torch.no_grad(): - weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) - if self.padding_idx is not None: - weight[self.padding_idx].fill_(0) - return weight - - def _preprocess(self, - weight, - cuda_row_num: int, - ids_freq_mapping: Optional[List[int]] = None, - warmup_ratio=0.7, - buffer_size=50_000, - pin_weight=False): - """ - Called after initialized. - Reorder the weight rows according to the ids_freq_mapping. - Then, let the weights of the Module be managed by a CachedParamMgr. - - Args: - cuda_row_num (int): number of rows can be hosted in CUDA memory - ids_freq_mapping (List[int]): a list, idx is id number, value is freq - warmup_ratio (float): the amount of rows preloaded in cuda cache - """ - self.cache_weight_mgr = CachedParamMgr(weight, - cuda_row_num, - buffer_size, - pin_weight, - evict_strategy=self.evict_strategy) - self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) - - def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None): - if self.cache_op: - with torch.no_grad(): - input = self.cache_weight_mgr.prepare_ids(input) - - embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, - per_sample_weights, self.include_last_offset, self.padding_idx) - if shape_hook is not None: - embeddings = shape_hook(embeddings) - return embeddings - - @property - def weight(self): - return self.cache_weight_mgr.weight - - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: - yield 'weight', self.cache_weight_mgr.cuda_cached_weight - - def parameters(self, recurse: bool = True) -> Iterator[Parameter]: - yield self.cache_weight_mgr.cuda_cached_weight - - def set_cache_op(self, cache_op: bool = True): - self.cache_op = cache_op - - -############################# Perf Log ################################### - - @property - def num_hits_history(self): - return self.cache_weight_mgr.num_hits_history - - @property - def num_miss_history(self): - return self.cache_weight_mgr.num_miss_history - - @property - def num_write_back_history(self): - return self.cache_weight_mgr.num_write_back_history - - @property - def swap_in_bandwidth(self): - if self.cache_weight_mgr._cpu_to_cuda_numel > 0: - return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ - self.cache_weight_mgr._cpu_to_cuda_elpase - else: - return 0 - - @property - def swap_out_bandwidth(self): - if self.cache_weight_mgr._cuda_to_cpu_numel > 0: - return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ - self.cache_weight_mgr._cuda_to_cpu_elapse - return 0 diff --git a/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py b/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py deleted file mode 100644 index 36e04c833feb4203d9033a15951b580207890e8b..0000000000000000000000000000000000000000 --- a/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py +++ /dev/null @@ -1,27 +0,0 @@ -import torch - - -class TablewiseEmbeddingBagConfig: - ''' - example: - def prepare_tablewise_config(args, cache_ratio, ...): - embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] - ... - return embedding_bag_config_list - ''' - - def __init__(self, - num_embeddings: int, - cuda_row_num: int, - assigned_rank: int = 0, - buffer_size=50_000, - ids_freq_mapping=None, - initial_weight: torch.tensor = None, - name: str = ""): - self.num_embeddings = num_embeddings - self.cuda_row_num = cuda_row_num - self.assigned_rank = assigned_rank - self.buffer_size = buffer_size - self.ids_freq_mapping = ids_freq_mapping - self.initial_weight = initial_weight - self.name = name diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py deleted file mode 100644 index d7f77e195f4b480da53e6f1772145d122afb1e69..0000000000000000000000000000000000000000 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py +++ /dev/null @@ -1,141 +0,0 @@ -import torch -import torch.nn.functional as F -from typing import List, Optional, Iterator, Tuple - -from .cached_embedding import CachedEmbeddingBag -from colossalai.nn._ops._utils import dual_all_to_all - -from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor -from .cache_mgr import CachedParamMgr, EvictionStrategy - - -def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: - if world_size == 1: - return 0, embedding_dim, True - - assert embedding_dim >= world_size, \ - f"Embedding dimension {embedding_dim} must be larger than the world size " \ - f"{world_size} of the process group" - chunk_size = embedding_dim // world_size - threshold = embedding_dim % world_size - # if embedding dim is divisible by world size - if threshold == 0: - return rank * chunk_size, (rank + 1) * chunk_size, True - - # align with the split strategy of torch.tensor_split - size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)] - offset = sum(size_list[:rank]) - return offset, offset + size_list[rank], False - - -class ParallelCachedEmbeddingBag(CachedEmbeddingBag): - - def __init__(self, - num_embeddings, - embedding_dim, - padding_idx=None, - max_norm=None, - norm_type=2., - scale_grad_by_freq=False, - sparse=False, - _weight=None, - mode='mean', - include_last_offset=False, - dtype=None, - device=None, - cache_ratio=0.01, - ids_freq_mapping=None, - warmup_ratio=0.7, - buffer_size=50_000, - pin_weight=False, - evict_strategy: EvictionStrategy = EvictionStrategy.DATASET): - self.rank = torch.distributed.get_rank() - self.world_size = torch.distributed.get_world_size() - - self.partition_start_index, self.partition_end_index, divisible = get_partition( - embedding_dim, self.rank, self.world_size) - self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index - - super(ParallelCachedEmbeddingBag, - self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, - sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, - warmup_ratio, buffer_size, pin_weight, evict_strategy) - self.cache_op = True - - def _weight_alloc(self, dtype, device): - weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype) - with torch.no_grad(): - weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) - if self.padding_idx is not None: - weight[self.padding_idx].fill_(0) - colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size), - dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]), - compute_attr=ComputePattern.TP1D) - return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec) - - def forward( - self, - indices, - offsets=None, - per_sample_weights=None, - shape_hook=None, - scatter_dim=0, - gather_dim=-1, - ): - if self.cache_op: - with torch.no_grad(): - indices = self.cache_weight_mgr.prepare_ids(indices) - output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, - per_sample_weights, self.include_last_offset, self.padding_idx) - if shape_hook is not None: - output_shard = shape_hook(output_shard) - output_full = dual_all_to_all(output_shard, - self.weight.get_process_group(), - scatter_dim=scatter_dim, - gather_dim=gather_dim) - return output_full - - def set_cache_op(self, cache_op: bool = True): - self.cache_op = cache_op - - @classmethod - def from_pretrained( - cls, - embedding: torch.Tensor, - freeze: bool = True, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2., - scale_grad_by_freq: bool = False, - sparse: bool = False, - mode: str = 'mean', - include_last_offset: bool = False, - cuda_row_num: int = 100_000, - ids_freq_mapping: Optional[List[int]] = None, - warmup_ratio: float = 0.7, - buffer_size: int = 0, - ) -> 'ParallelCachedEmbeddingBag': - rows, cols = embedding.shape - embedding_bag = cls(rows, - cols, - padding_idx, - max_norm, - norm_type, - scale_grad_by_freq, - sparse, - embedding, - mode, - include_last_offset, - cuda_row_num=cuda_row_num, - ids_freq_mapping=ids_freq_mapping, - warmup_ratio=warmup_ratio, - buffer_size=buffer_size) - embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze - return embedding_bag - - def print_comm_stats_(self): - self.cache_weight_mgr.print_comm_stats() - - def element_size(self): - return self.weight.element_size() diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py deleted file mode 100644 index 949f85ad4baf894d4e53f06594bcbd9080f249ae..0000000000000000000000000000000000000000 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py +++ /dev/null @@ -1,198 +0,0 @@ -import torch -import torch.distributed as dist -import torch.nn.functional as F - -from .cached_embedding import CachedEmbeddingBag -from .cache_mgr import EvictionStrategy -from .embedding_config import TablewiseEmbeddingBagConfig -from colossalai.tensor import ProcessGroup -from colossalai.nn._ops._utils import dual_all_to_all_tablewise - -from typing import List -import time - - -class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag): - """ - all tables assigned to this class instance are managed by a single CachedEmbeddingBag. - Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight. - """ - - def __init__(self, - embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], - embedding_dim: int, - padding_idx=None, - max_norm=None, - norm_type=2., - scale_grad_by_freq=False, - sparse=False, - _weight=None, - mode='mean', - include_last_offset=False, - dtype=None, - device=None, - cache_ratio=0.01, - warmup_ratio=0.7, - buffer_size=50_000, - pin_weight=False, - evict_strategy: EvictionStrategy = EvictionStrategy.LFU): - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] - self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list] - self.global_tables_num = len(embedding_bag_config_list) - self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda() - self.assigned_table_list: List[int] = [] - self.pg = ProcessGroup(tp_degree=self.world_size) - self.num_embeddings = 0 - for i, rank in enumerate(self.rank_of_tables): - if rank == self.rank: - self.assigned_table_list.append(i) - self.num_embeddings += self.global_table_num_embeddings_list[i] - self.include_last_offset = include_last_offset - - ids_freq_mapping = [] - for config in embedding_bag_config_list: - if config.assigned_rank == self.rank: - if config.ids_freq_mapping != None: - ids_freq_mapping.extend(config.ids_freq_mapping) - else: - ids_freq_mapping = None - break - self.cache_ratio = cache_ratio - # table-associate cache - cuda_row_num = int(cache_ratio * self.num_embeddings) - super(ParallelCachedEmbeddingBagTablewise, - self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, - sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, - warmup_ratio, buffer_size, pin_weight, evict_strategy) - - # for assigned tables reconnection: - self.idx_offset_list = [] - offset_cumsum = 0 - for table_i, table_num_embeddings in enumerate(self.global_table_num_embeddings_list): - if self.rank_of_tables[table_i] == self.rank: - self.idx_offset_list.append(offset_cumsum) - else: - offset_cumsum += table_num_embeddings - - # prepare list shape for all_to_all output - self.embedding_dim_per_rank = [0 for i in range(self.world_size)] - for rank in self.rank_of_tables: - self.embedding_dim_per_rank[rank] += embedding_dim - - self.cache_op = True - - def forward( - self, - indices: torch.Tensor, - offsets: torch.Tensor = None, - per_sample_weights=None, - shape_hook=None, - already_split_along_rank=True, - ): - if not already_split_along_rank: - # not recommanded. it takes time. - batch_size = (offsets.shape[0]) // self.global_tables_num - local_indices, local_offsets, local_per_sample_weights = self.split_along_rank( - batch_size, indices, offsets, per_sample_weights) - else: - # recommanded. - batch_size = (offsets.shape[0]) // len(self.assigned_table_list) - local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights - if self.cache_op: - with torch.no_grad(): - indices = self.cache_weight_mgr.prepare_ids(local_indices) - local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets, - self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, - local_per_sample_weights, self.include_last_offset, self.padding_idx) - local_output = torch.cat(local_output.split(batch_size), 1) - remains = batch_size % self.world_size - scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)] - output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank) - if shape_hook is not None: - output_full = shape_hook(output_full) - return output_full - - def split_along_rank(self, - batch_size, - indices: torch.Tensor, - offsets: torch.Tensor = None, - per_sample_weights=None): - ''' - if input indices and offsets haven't been splitted along assigned rank, this function will do it. - it takes time. please consider splitting data during batch loading. - ''' - local_indices_list: List(torch.Tensor) = [] - local_offsets_list: List(torch.Tensor) = [] - if per_sample_weights != None: - local_per_sample_weights_list: List(torch.Tensor) = [] - - offset_pre_end = 0 # local_offsets trick - for i, handle_table in enumerate(self.assigned_table_list): - indices_start_position = offsets[batch_size * handle_table] - if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): - # till-the-end special case - indices_end_position = indices.shape[0] - else: - indices_end_position = offsets[batch_size * (handle_table + 1)] - # alternative approach: reduce malloc - ''' - # 1. local_indices_list: - local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position) - torch.sub(local_indices, self.idx_offset_list[i], out=local_indices) - local_indices_list.append(local_indices) - # 2. local_offsets_list: - if i + 1 == len(self.assigned_table_list): - # till-the-end special case - if not self.include_last_offset: - local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size) - else: - local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1) - torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets) - local_offsets_list.append(local_offsets) - else: - temp_holder = offsets[batch_size * handle_table].item() - local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size) - torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets) - offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder - local_offsets_list.append(local_offsets) - ''' - # 1. local_indices_list: - local_indices_list.append( - indices.narrow(0, indices_start_position, - indices_end_position - indices_start_position).sub(self.idx_offset_list[i])) - # 2. local_offsets_list: - if i + 1 == len(self.assigned_table_list): - # till-the-end special case - if not self.include_last_offset: - local_offsets = offsets.narrow(0, batch_size * handle_table, - batch_size).add(offset_pre_end - offsets[batch_size * - (handle_table)]) - else: - local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + - 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) - local_offsets_list.append(local_offsets) - else: - local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + - 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) - offset_pre_end = local_offsets[-1] - local_offsets_list.append(local_offsets[:-1]) - # 3. local_per_sample_weights_list: - if per_sample_weights != None: - local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position]) - local_indices = torch.cat(local_indices_list, 0) - local_offsets = torch.cat(local_offsets_list, 0) - local_per_sample_weights = None - if per_sample_weights != None: - local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0) - return local_indices, local_offsets, local_per_sample_weights - - def set_cache_op(self, cache_op: bool = True): - self.cache_op = cache_op - - def print_comm_stats_(self): - self.cache_weight_mgr.print_comm_stats() - - def element_size(self): - return self.weight.element_size() diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py deleted file mode 100644 index cb4647028d477d6c89ddfcb0d68d906ba43aa4c8..0000000000000000000000000000000000000000 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py +++ /dev/null @@ -1,138 +0,0 @@ -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.profiler import record_function - -from .cached_embedding import CachedEmbeddingBag - -from colossalai.tensor import ProcessGroup -from colossalai.nn._ops._utils import dual_all_to_all_tablewise -from .embedding_config import TablewiseEmbeddingBagConfig -from .cache_mgr import EvictionStrategy - -from typing import List -import abc - - -class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): - """ - every table assigned to this class instance is managed by a CachedEmbeddingBag. - """ - - def __init__(self, - embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], - embedding_dim: int, - padding_idx=None, - max_norm=None, - norm_type=2., - scale_grad_by_freq=False, - sparse=False, - mode='mean', - include_last_offset=False, - dtype=None, - device=None, - warmup_ratio=0.7, - pin_weight=False, - evict_strategy: EvictionStrategy = EvictionStrategy.LFU): - super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__() - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] - self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list] - self.global_tables_num = len(embedding_bag_config_list) - self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda() - - self.assigned_table_list: List[int] = [] - for i, rank in enumerate(self.rank_of_tables): - if rank == self.rank: - self.assigned_table_list.append(i) - self.include_last_offset = include_last_offset - self.pg = ProcessGroup(tp_degree=self.world_size) - - # prepare CachedEmbeddingBag list - - self.cached_embedding_bag_list: nn.ModuleList = nn.ModuleList() - for config in embedding_bag_config_list: - if config.assigned_rank != self.rank: - continue - self.cached_embedding_bag_list.append( - CachedEmbeddingBag(num_embeddings=config.num_embeddings, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - _weight=config.initial_weight, - mode=mode, - include_last_offset=include_last_offset, - dtype=dtype, - device=device, - cuda_row_num=config.cuda_row_num, - ids_freq_mapping=config.ids_freq_mapping, - warmup_ratio=warmup_ratio, - buffer_size=config.buffer_size, - pin_weight=pin_weight, - evict_strategy=evict_strategy)) - - # prepare list shape for all_to_all output - self.embedding_dim_per_rank = [0 for i in range(self.world_size)] - for rank in self.rank_of_tables: - self.embedding_dim_per_rank[rank] += embedding_dim - - def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None): - # determine indices to handle - batch_size = (offsets.shape[0]) // self.global_tables_num - local_output_list = [] - for i, handle_table in enumerate(self.assigned_table_list): - with record_function("(tablewise) prepare indices and offsets"): - with record_function("part 1"): - indices_start_position = offsets[batch_size * handle_table] - if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): - # till the end special case - indices_end_position = indices.shape[0] - else: - indices_end_position = offsets[batch_size * (handle_table + 1)] - with record_function("part 2"): - # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table] - local_indices = indices.narrow(0, indices_start_position, indices_end_position - - indices_start_position).sub(self.global_tables_offsets[handle_table]) - if self.include_last_offset: - # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)] - local_offsets = offsets.narrow(0, batch_size * handle_table, - batch_size + 1).sub(offsets[batch_size * (handle_table)]) - else: - # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)] - local_offsets = offsets.narrow(0, batch_size * handle_table, - batch_size).sub(offsets[batch_size * (handle_table)]) - local_per_sample_weights = None - if per_sample_weights != None: - local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position] - with record_function("(tablewise) tablewise forward"): - local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets, - local_per_sample_weights)) - - # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim)) - local_output = torch.cat(local_output_list, 1) - # then concatenate those local_output on the second demension. - # use all_to_all - remains = batch_size % self.world_size - scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)] - output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank) - if shape_hook is not None: - output_full = shape_hook(output_full) - return output_full - - def element_size(self): - if len(self.assigned_table_list) == 0: - return 0 - return self.cached_embedding_bag_list[0].cache_weight_mgr.weight.element_size() - - def print_comm_stats_(self): - cuda_to_cpu_elem_num = 0 - cpu_to_cuda_elem_num = 0 - for cached_embedding_bag in self.cached_embedding_bag_list: - cuda_to_cpu_elem_num += cached_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel - cpu_to_cuda_elem_num += cached_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel - print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem") - print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem") diff --git a/colossalai/nn/parallel/layers/colo_module.py b/colossalai/nn/parallel/layers/colo_module.py deleted file mode 100644 index 8f0f5d5f520a17c979a21048903b025dc642296f..0000000000000000000000000000000000000000 --- a/colossalai/nn/parallel/layers/colo_module.py +++ /dev/null @@ -1,46 +0,0 @@ -from colossalai.tensor.distspec import _DistSpec -from colossalai.tensor import ComputePattern -from typing import List, Dict - - -class ColoModule(object): - - def __init__(self): - self._shard_params: List[str] = [] - self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {} - - def _register_shard_params(self, params: List[str]): - self._shard_params = params - - def _register_allowed_patterns(self, - compute_pattern: ComputePattern, - dist_specs: Dict[str, _DistSpec], - mode='default'): - assert list( - dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.' - if not compute_pattern in self._allowed_patterns: - self._allowed_patterns[compute_pattern] = {} - self._allowed_patterns[compute_pattern][mode] = dist_specs - - def _set_default(self, compute_pattern: ComputePattern, target_mode): - self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode] - - def has_compute_pattern(self, compute_pattern: ComputePattern): - return compute_pattern in self._allowed_patterns - - def get_dist_specs(self, compute_pattern: ComputePattern): - assert self.has_compute_pattern(compute_pattern) - return self._allowed_patterns[compute_pattern] - - def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'): - return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern] - - def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'): - assert self.has_compute_pattern_with_mode(compute_pattern, mode) - return self._allowed_patterns[compute_pattern][mode] - - def get_param_names(self): - return self._shard_params - - def register(self, compute_pattern, pg): - raise NotImplementedError diff --git a/colossalai/nn/parallel/layers/embedding.py b/colossalai/nn/parallel/layers/embedding.py deleted file mode 100644 index ccacc1ead297b349c7252aec37a41f106dff7993..0000000000000000000000000000000000000000 --- a/colossalai/nn/parallel/layers/embedding.py +++ /dev/null @@ -1,36 +0,0 @@ -from .colo_module import ColoModule -from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec - - -class ColoEmbedding(ColoModule): - - def __init__(self): - super(ColoEmbedding, self).__init__() - self._register_shard_params(['weight']) - - def register(self, compute_pattern, pg: ProcessGroup): - if not compute_pattern in self._allowed_patterns: - if ComputePattern.TP1D == compute_pattern: - self._set_TP1D(pg) - - def _set_TP1D(self, pg: ProcessGroup): - # TP1D Row Linear - _compute_pattern = ComputePattern.TP1D - self._register_allowed_patterns( - compute_pattern=_compute_pattern, - dist_specs={ - 'weight': ShardSpec([0], [pg.tp_world_size()]), - }, - mode='row', - ) - - # TP1D Col Linear - self._register_allowed_patterns( - compute_pattern=_compute_pattern, - dist_specs={ - 'weight': ShardSpec([-1], [pg.tp_world_size()]), - }, - mode='col', - ) - - self._set_default(compute_pattern=_compute_pattern, target_mode='row') diff --git a/colossalai/nn/parallel/layers/linear.py b/colossalai/nn/parallel/layers/linear.py deleted file mode 100644 index 84a8c042587dfdb279f4bd1f83e07d317bc57d08..0000000000000000000000000000000000000000 --- a/colossalai/nn/parallel/layers/linear.py +++ /dev/null @@ -1,38 +0,0 @@ -from .colo_module import ColoModule -from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec - - -class ColoLinear(ColoModule): - - def __init__(self): - super(ColoLinear, self).__init__() - self._register_shard_params(['weight', 'bias']) - - def register(self, compute_pattern, pg: ProcessGroup): - if not compute_pattern in self._allowed_patterns: - if ComputePattern.TP1D == compute_pattern: - self._set_TP1D(pg) - - def _set_TP1D(self, pg): - # TP1D Row Linear - _compute_pattern = ComputePattern.TP1D - self._register_allowed_patterns( - compute_pattern=_compute_pattern, - dist_specs={ - 'weight': ShardSpec([-1], [pg.tp_world_size()]), - 'bias': None - }, - mode='row', - ) - - # TP1D Col Linear - self._register_allowed_patterns( - compute_pattern=_compute_pattern, - dist_specs={ - 'weight': ShardSpec([0], [pg.tp_world_size()]), - 'bias': ShardSpec([0], [pg.tp_world_size()]) - }, - mode='col', - ) - - self._set_default(compute_pattern=_compute_pattern, target_mode='row') diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py index 0fcde970764688c210d37de51964b5b081834666..4754212c1914297c12e012e8d98d6e5bbc6e22cd 100644 --- a/colossalai/pipeline/__init__.py +++ b/colossalai/pipeline/__init__.py @@ -1,4 +1,11 @@ -from .pipelinable import PipelinableContext, PipelinableModel -from .layer_spec import LayerSpec +from .p2p import PipelineP2PCommunication +from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule +from .stage_manager import PipelineStageManager -__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec'] \ No newline at end of file +__all__ = [ + "PipelineSchedule", + "OneForwardOneBackwardSchedule", + "InterleavedSchedule", + "PipelineP2PCommunication", + "PipelineStageManager", +] diff --git a/colossalai/pipeline/middleware/__init__.py b/colossalai/pipeline/middleware/__init__.py deleted file mode 100644 index 79e19f9eaf771635852d8ffd747c06ea1209e110..0000000000000000000000000000000000000000 --- a/colossalai/pipeline/middleware/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal - -__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal'] \ No newline at end of file diff --git a/colossalai/pipeline/middleware/adaptor/__init__.py b/colossalai/pipeline/middleware/adaptor/__init__.py deleted file mode 100644 index 949700a2c49de505b37c408fd44283ef67f9569b..0000000000000000000000000000000000000000 --- a/colossalai/pipeline/middleware/adaptor/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .fx import get_topology as get_fx_topology - -__all__ = ['get_fx_topology'] \ No newline at end of file diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py new file mode 100644 index 0000000000000000000000000000000000000000..c69bbe6e852122e1a1d2459e573f2a1ce22c29d4 --- /dev/null +++ b/colossalai/pipeline/p2p.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import io +import pickle +import re +from typing import Any, List, Optional, Union + +import torch +import torch.distributed as dist +from packaging.version import Version +from torch.distributed import ProcessGroup +from torch.distributed import distributed_c10d as c10d + +from .stage_manager import PipelineStageManager + +_unpickler = pickle.Unpickler + + +def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object: + """transform tensor to object with unpickle. + Info of the device in bytes stream will be modified into current device before unpickling + + Args: + tensor (:class:`torch.tensor`): tensor to be unpickled + tensor_size (:class:`torch.Size`): Size of the real info in bytes + + Returns: + Any: object after unpickled + """ + buf = tensor.numpy().tobytes()[:tensor_size] + if b"cuda" in buf: + buf_array = bytearray(buf) + device_index = torch.cuda.current_device() + # There might be more than one output tensors during forward + for cuda_str in re.finditer(b"cuda", buf_array): + pos = cuda_str.start() + buf_array[pos + 5] = 48 + device_index + buf = bytes(buf_array) + + io_bytes = io.BytesIO(buf) + byte_pickler = _unpickler(io_bytes) + unpickle = byte_pickler.load() + + return unpickle + + +def _broadcast_object_list( + object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None +): + """This is a modified version of the broadcast_object_list in torch.distribution + The only difference is that object will be move to correct device after unpickled. + If local_rank = src, then object list will be sent to rank src. Otherwise, object list will + be updated with data sent from rank src. + + Args: + object_list (List[Any]): list of object to broadcast + src (int): source rank to broadcast + dst (int): dst rank to broadcast + device (:class:`torch.device`): device to do broadcast. current device in default + + """ + + if c10d._rank_not_in_group(group): + c10d._warn_not_in_group("broadcast_object_list") + return + + is_nccl_backend = c10d._check_for_nccl_backend(group) + current_device = None + + if device is not None: + if is_nccl_backend and device.type != "cuda": + raise ValueError("device type must be cuda for nccl backend") + current_device = device + else: + current_device = torch.device("cpu") + if is_nccl_backend: + current_device = torch.device("cuda", torch.cuda.current_device()) + + my_rank = dist.get_rank() + # Serialize object_list elements to tensors on src rank. + if my_rank == src: + if Version(torch.__version__) >= Version("1.13.0"): + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list]) + else: + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) + + if is_nccl_backend: + object_sizes_tensor = object_sizes_tensor.to(current_device) + + # Broadcast object sizes + c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False) + + # Concatenate and broadcast serialized object tensors + if my_rank == src: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + ) + + if is_nccl_backend: + object_tensor = object_tensor.to(current_device) + + c10d.broadcast(object_tensor, src=src, group=group, async_op=False) + + # Deserialize objects using their stored sizes. + offset = 0 + + if my_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(torch.uint8) + if obj_view.device != torch.device("cpu"): + obj_view = obj_view.cpu() + offset += obj_size + # unpickle + unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) + + # unconsistence in device + if ( + isinstance(unpickle_object, torch.Tensor) + and unpickle_object.device.index != torch.cuda.current_device() + ): + unpickle_object = unpickle_object.cuda() + + object_list[i] = unpickle_object + + +def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: + """send anything to dst rank + + Args: + object (Any): object needed to be sent + dst (int): rank of the destination + + Returns: + None + """ + # then broadcast safely + _broadcast_object_list([object], src, group) + + +def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: + """recv anything from src + + Args: + src (int): source rank of data. local rank will receive data from src rank. + + Returns: + Any: Object received from src. + """ + object_list = [None] + _broadcast_object_list(object_list, src, group) + + return object_list[0] + + +class PipelineP2PCommunication: + def __init__(self, stage_manager: PipelineStageManager) -> None: + self.stage_manager = stage_manager + + def recv_forward(self, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + + Args: + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + cur_rank = self.stage_manager.get_rank() + input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) + + return input_tensor + + def recv_backward(self, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + + Args: + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + output_tensor_grad = _recv_object( + next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank) + ) + + return output_tensor_grad + + def send_forward(self, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + + Args: + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) + + def send_backward(self, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + + Args: + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + cur_rank = self.stage_manager.get_rank() + _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) diff --git a/colossalai/pipeline/rpc/__init__.py b/colossalai/pipeline/rpc/__init__.py deleted file mode 100644 index 9d9e9d44f46c5aea6c0d211d24af21ed834a2503..0000000000000000000000000000000000000000 --- a/colossalai/pipeline/rpc/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine -from .utils import pytree_map - -__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map'] \ No newline at end of file diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py deleted file mode 100644 index 0d572231d37830a89cac07eecf066f7e556ea640..0000000000000000000000000000000000000000 --- a/colossalai/pipeline/rpc/_pipeline_schedule.py +++ /dev/null @@ -1,346 +0,0 @@ -import threading -from typing import Callable, Dict, List - -import torch -import torch.distributed as dist -from torch._C._distributed_rpc import PyRRef -from torch.futures import Future - -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem - -# Implementation of different Pipeline schedule -# Worker defines the worker for each stage -# PipelineEngine is the class for use - - -class FillDrainWorker(WorkerBase): - - def _get_work_item_key(self) -> UniqueKey: - # execute backward first (if backward phase in work_list) - num_microbatches = self.num_microbatches - - if self.forward_times < num_microbatches: - target_phase = Phase.FORWARD - target_microbatch_id = self.forward_times - else: - target_phase = Phase.BACKWARD - target_microbatch_id = self.backward_times - - target_key = UniqueKey(target_microbatch_id, target_phase) - - return target_key - - -class FillDrainPipelineEngine(PipelineEngineBase): - - def __init__(self, - partition_fn: Callable, - stage_num: int, - num_microbatches: int, - device: str, - chunk: int = 1, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: - - if chunk > 1: - assert num_microbatches % stage_num == 0, \ - "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" - use_1F1B = False - - super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint, data_process_func) - - -class OneFOneBWorker(WorkerBase): - - def _get_work_item_key(self) -> UniqueKey: - # execute backward first (if backward phase in work_list) - pp_rank = self.pp_rank - actual_stage_num = self.actual_stage_num - num_microbatches = self.num_microbatches - is_last_stage = pp_rank == actual_stage_num - 1 - - if self.outstanding <= self.outstanding_range[0]: - target_phase = Phase.FORWARD - target_microbatch_id = self.forward_times - elif self.outstanding >= self.outstanding_range[1]: - target_phase = Phase.BACKWARD - target_microbatch_id = self.backward_times - else: - raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]") - - target_key = UniqueKey(target_microbatch_id, target_phase) - - # change outstanding_range at: - # 1. forward times reach actual_stage_num, this is the end of continuous forward - # 2. forward times reach num_microbatches, this is the end of 1F1B mode - if not is_last_stage and \ - target_key.phase == Phase.FORWARD: - if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2: - # Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2 - outstanding_min = actual_stage_num - pp_rank - 1 - outstanding_max = actual_stage_num - pp_rank - self.outstanding_range = (outstanding_min, outstanding_max) - if target_key.microbatch_id == num_microbatches - 1: - self.outstanding_range = (0, 0) - - return target_key - - -class OneFOneBPipelineEngine(PipelineEngineBase): - - def __init__(self, - partition_fn: Callable, - stage_num: int, - num_microbatches: int, - device: str, - chunk: int = 1, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: - - if chunk > 1: - assert num_microbatches % stage_num == 0, \ - "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" - # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk" - use_1F1B = True - - super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint, data_process_func) - - -class ChimeraWorker(WorkerBase): - - def _get_producer_consumer(self) -> None: - rank = self.pp_rank - min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num - max_pp_rank = min_pp_rank + self.actual_stage_num - 1 - - assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed" - assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed" - - # should be aranged in order, the order of the input of current forward - self.producer_stage_ids = [] - self.consumer_stage_ids = [] - - # Just for demo - prev_rank = rank - 1 - next_rank = rank + 1 - if prev_rank >= min_pp_rank: - self.producer_stage_ids.append(prev_rank) - if next_rank <= max_pp_rank: - self.consumer_stage_ids.append(next_rank) - - def _get_work_item_key(self) -> UniqueKey: - pp_rank = self.pp_rank - stage_num = self.actual_stage_num - real_microbatch_num = self.num_microbatches // 2 - - forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num - forward_block_num = self.forward_times // forward_block_size - - if self.forward_times >= real_microbatch_num or \ - ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times): - target_phase = Phase.BACKWARD - target_microbatch_id = self.backward_times - else: # others - target_phase = Phase.FORWARD - target_microbatch_id = self.forward_times - - # In up pipeline, microbatch_id to consume is 0, 2, 4 (2n) - # In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1) - real_target_microbatch_id = target_microbatch_id * 2 - if pp_rank >= stage_num: - real_target_microbatch_id += 1 - target_key = UniqueKey(real_target_microbatch_id, target_phase) - - with self.work_list_condition_lock: - self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list) - return target_key - - def _initialize_partition(self): - # In order to ensure the down pipeline share the same parameter - # with the up pipeline, partition of down partition will be copied - # from corresponding up stage - pp_rank = self.pp_rank - stage_num = self.actual_stage_num - device = self.device - if pp_rank < stage_num: - super()._initialize_partition() - else: - # if it is down pipeline, create partition by origin method - co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num] - # get the coresponding model state dict and wait for its init - state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict() - super()._initialize_partition() - self.module_partition.load_state_dict(state_dict) - - # init group for chimera in ppg - ppg.get_chimera_all_reduce_group(pp_rank) - - # lock for step sync - self.step_sync_lock = threading.Lock() - self.step_sync_lock.acquire() - - self.have_grad_lock = threading.Lock() - self.have_grad_lock.acquire() - - def _get_lock_gradient(self): - self.have_grad_lock.acquire() - grads = self.get_parameter_gradients() - self.step_sync_lock.release() - return grads - - def is_first_stage(self): - return (self.pp_rank % self.actual_stage_num) == 0 - - def is_last_stage(self): - return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1 - - def _is_last_step(self, work_item: WorkItem) -> bool: - if work_item.forward_only: - last_phase = Phase.FORWARD - else: - last_phase = Phase.BACKWARD - is_last_phase = work_item.phase == last_phase - last_microbatch_id = self.num_microbatches - 1 - if self.pp_rank < self.actual_stage_num: - last_microbatch_id -= 1 - is_last_microbatch = work_item.microbatch_id == last_microbatch_id - return is_last_phase and is_last_microbatch - - def _get_step_order(self) -> List[int]: - # TODO : If you want to extend it to multi head chimera, overwrite here - stage_num = self.actual_stage_num - pp_rank = self.pp_rank - # pp_rank in the same device - local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1] - local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2) - return local_device_pp_ranks - - def _hook_before_step(self): - self.have_grad_lock.release() - pp_rank = self.pp_rank - stage_num = self.actual_stage_num - co_pp_rank = (pp_rank + stage_num) % (2 * stage_num) - - # if currrent pp_rank is not the first to do step - # wait its previous pp_rank finish step - grads = self.get_parameter_gradients() - - # send - co_worker = self.pp_rank_to_worker_rref[co_pp_rank] - co_grads = co_worker.rpc_sync()._get_lock_gradient() - # sync - self.step_sync_lock.acquire() - for i in range(len(grads)): - grads[i] += co_grads[i] - - -class ChimeraPipelineEngine(PipelineEngineBase): - - def __init__(self, - partition_fn: Callable, - stage_num: int, - num_microbatches: int, - device: str, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: - - assert num_microbatches % stage_num == 0, \ - "In Chimera, num_microbatches must be the multiply of stage_num!" - use_1F1B = False - chunk = 1 - - super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint, data_process_func) - - def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], - output_pp_ranks: List[int], ret_future): - pass - - def _create_pp_rank_to_rpc_worker_id(self) -> None: - stage_num = self.stage_num - self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2) - for pp_rank in range(stage_num): - self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank - self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1 - - def _create_pp_rank_to_module_partition_id(self) -> None: - stage_num = self.stage_num - self.pp_rank_to_module_partition_id = [0] * (stage_num * 2) - for pp_rank in range(stage_num): - self.pp_rank_to_module_partition_id[pp_rank] = pp_rank - self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank - - def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]: - num_microbatches = self.num_microbatches - stage_num = self.stage_num - up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks} - down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks} - # merge up and down - return {**up_ret_future, **down_ret_future} - - def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool): - # offset is 0 for all the ranks in up pipeline - # offset is stage_num for all the ranks in down pipeline - offset = (microbatch_id % 2) * self.stage_num - for pp_rank in input_pp_ranks: - worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] - worker_rref.remote().set_input(microbatch_id, microbatch, forward_only) - - def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels): - # offset is 0 for all the ranks in up pipeline - # offset is stage_num for all the ranks in down pipeline - offset = (microbatch_id % 2) * self.stage_num - for pp_rank in output_pp_ranks: - worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] - worker_rref.remote().set_labels(microbatch_id, microlabels) - - def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]): - key = UniqueKey(microbatch_id, Phase.FORWARD) - offset = (microbatch_id % 2) * self.stage_num - for pp_rank in output_pp_ranks: - worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] - ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key) - - def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]): - stage_num = self.stage_num - num_microbatches = self.num_microbatches - if not forward_only: - for pp_rank in input_pp_ranks: - up_last_microbatch_id = num_microbatches - 2 - down_last_microbatch_id = num_microbatches - 1 - - up_worker_rref = self.pp_rank_to_worker_rref[pp_rank] - down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num] - - up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD) - down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD) - up_worker_rref.rpc_sync().get_output_by_key(up_key) - down_worker_rref.rpc_sync().get_output_by_key(down_key) - - def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[PyRRef, List[Future]]): - """Logic of collection of forward in Chimera. - Currently, only one input one output model is supported - """ - stage_num = self.stage_num - forward_result = [] - for pp_rank in output_pp_ranks: - worker_forward_result = [None] * self.num_microbatches - for microbatch_id in range(self.num_microbatches): - offset = (microbatch_id % 2) * stage_num - ret = ret_future[pp_rank + offset][microbatch_id].wait() - ret = [ret] if isinstance(ret, torch.Tensor) else ret - worker_forward_result[microbatch_id] = ret - - worker_forward_result = list(zip(*worker_forward_result)) - forward_result.extend(worker_forward_result) - - return forward_result diff --git a/colossalai/pipeline/rpc/utils.py b/colossalai/pipeline/rpc/utils.py deleted file mode 100644 index 06e6d976d7715cd9d6d57f4e196bd0bd1b3116cd..0000000000000000000000000000000000000000 --- a/colossalai/pipeline/rpc/utils.py +++ /dev/null @@ -1,155 +0,0 @@ -import argparse -import os -import warnings -from typing import Any, Callable, Dict, List, Tuple, Type, Union - -import torch -import torch.distributed.rpc as rpc -import torch.multiprocessing as mp -from torch._C._distributed_rpc import _is_current_rpc_agent_set -from torch.futures import Future - -from colossalai.initialize import launch -from colossalai.pipeline.pipeline_process_group import ppg - - -def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any: - if isinstance(obj, process_types): - return fn(obj) - elif type(obj) is dict: - return {k: pyobj_map(obj[k], fn, process_types) for k in obj} - elif type(obj) is tuple: - return tuple(pyobj_map(o, fn, process_types) for o in obj) - elif type(obj) is list: - return list(pyobj_map(o, fn, process_types) for o in obj) - else: - return obj - - -def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: - """process object recursively, like pytree - - Args: - obj (:class:`Any`): object to process - fn (:class:`Callable`): a function to process subobject in obj - process_types (:class: `type | tuple[type]`): types to determine the type to process - map_all (:class: `bool`): if map_all is True, then any type of element will use fn - - Returns: - :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn` - """ - if isinstance(obj, dict): - return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj} - elif isinstance(obj, tuple): - return tuple(pytree_map(o, fn, process_types, map_all) for o in obj) - elif isinstance(obj, list): - return list(pytree_map(o, fn, process_types, map_all) for o in obj) - elif isinstance(obj, process_types): - return fn(obj) - else: - return fn(obj) if map_all else obj - - -def tensor_shape_list(obj): - return pytree_map(obj, fn=lambda x: x.shape, process_types=torch.Tensor) - - -def get_batch_lengths(batch): - lengths = [] - pytree_map(batch, fn=lambda x: lengths.append(len(x)), process_types=torch.Tensor) - return lengths - - -def split_batch(batch: Any, start, stop, device: str): - if device == 'cuda': - fn = lambda x: x[start:stop].cuda() - else: - fn = lambda x: x[start:stop] - return pytree_map(batch, fn=fn, process_types=torch.Tensor) - - -def type_detail(obj): - return pytree_map(obj, lambda x: type(x), map_all=True) - - -def pytree_filter(fn, obj, process_types): - if obj is None: - return None - - filters = [] - - def condition_append(obj): - if fn(obj): - filters.append(obj) - - pytree_map(obj, fn=condition_append, process_types=process_types) - return filters - - -def get_real_args_kwargs(args_or_kwargs): - args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) - # TODO : combine producer and consumer - # by default, merge all args in the output args or kwargs - if args_or_kwargs is not None: - if isinstance(args_or_kwargs, dict): - pass - else: - flatten_args = [] - pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) - args_or_kwargs = flatten_args - - return args_or_kwargs - - -def run_worker(rank, args, master_func): - os.environ['MASTER_ADDR'] = args.master_addr - os.environ['MASTER_PORT'] = args.master_port - - device = args.device - world_size = args.world_size - dp_degree = args.dp_degree - tp_degree = args.tp_degree - num_worker_threads = args.num_worker_threads - host = args.master_addr - port = args.master_port - backend = 'nccl' if device == 'cuda' else 'gloo' - - launch(dict(), rank, world_size, host, int(port), backend, verbose=False) - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=dp_degree, - tp_degree=tp_degree, - num_worker_threads=num_worker_threads, - device=device) - ppg.args = args - # in rpc mode, only rank 0 is needed to be coded - if rank == 0: - master_func(args) - # barrier here - if _is_current_rpc_agent_set(): - rpc.shutdown() - else: - warnings.warn("RPC has not been initialized") - - -def rpc_run(args, master_func): - world_size = args.world_size - mp.spawn(run_worker, args=(args, master_func), nprocs=world_size) - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--epoch', type=int, default=1) - parser.add_argument('--world_size', type=int, default=2) - parser.add_argument('--batch_size', type=int, default=16) - parser.add_argument('--dp_degree', type=int, default=1) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--num_microbatches', type=int, default=2) - parser.add_argument('--chunk', type=int, default=1) - parser.add_argument('--use_checkpoint', action='store_true') - parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29020') - parser.add_argument('--num_worker_threads', type=int, default=128) - return parser.parse_args() diff --git a/colossalai/pipeline/schedule/__init__.py b/colossalai/pipeline/schedule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6845dc23753b2ed1d94c6216fe8f6538cf43f853 --- /dev/null +++ b/colossalai/pipeline/schedule/__init__.py @@ -0,0 +1,9 @@ +from .base import PipelineSchedule +from .interleaved_pp import InterleavedSchedule +from .one_f_one_b import OneForwardOneBackwardSchedule + +__all__ = [ + "PipelineSchedule", + "OneForwardOneBackwardSchedule", + "InterleavedSchedule", +] diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..271b3238f5c4ca2bedd839f34f82ce6397329fc7 --- /dev/null +++ b/colossalai/pipeline/schedule/_utils.py @@ -0,0 +1,175 @@ +from collections import OrderedDict +from typing import Any, List, Optional, Tuple + +import torch +import torch.cuda +from torch.nn import Module +from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten + + +# this register are for torch under version 1.13.1, maybe removed in the future +def _odict_flatten(d: "OrderedDict[Any, Any]") -> Tuple[List[Any], Any]: + return list(d.values()), list(d.keys()) + + +def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]": + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten) + + +def tree_map_hf(fn: Any, pytree: Any): + flat_args, spec = tree_flatten_hf(pytree) + return tree_unflatten([fn(i) for i in flat_args], spec) + + +# use this flatten function to handle the ModelingOutput Class instance. +def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values an a TreeSpec that can be used + to reconstruct the pytree. + """ + if isinstance(pytree, OrderedDict): + node_type = OrderedDict + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(pytree) + + # Recursively flatten the children + result: List[Any] = [] + children_specs: List["TreeSpec"] = [] + for child in child_pytrees: + flat, child_spec = tree_flatten_hf(child) + result += flat + children_specs.append(child_spec) + return result, TreeSpec(node_type, context, children_specs) + else: + result, tree_spec = tree_flatten(pytree) + return result, tree_spec + + +def to_device(x: Any, device: Optional[torch.device] = None) -> Any: + """Move object to device if it is a tensor. + + Args: + x (Any): Object to be moved. + device (Optional[torch.device], optional): Target device. Defaults to None. + + Returns: + Any: Moved object. + """ + if isinstance(x, torch.Tensor): + return x.to(device) + return x + + +def get_batch_size(batch: Any) -> int: + """Get the batch size (size of dimension-0) of the first tensor in the batch. + + Args: + batch (Any): Batch to be inspected. + + Raises: + RuntimeError: If no tensor is found in the batch. + + Returns: + int: Batch size. + """ + data_list, _ = tree_flatten(batch) + for data in data_list: + if isinstance(data, torch.Tensor): + return data.size(0) + raise RuntimeError("No tensor found in the batch") + + +def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any: + """Get a micro batch of the original batch. + + Args: + batch (Any): Batch to be sliced. + start (int): Start index of the micro batch. + micro_batch_size (int): Size of the micro batch. + + Returns: + Any: Target micro batch. + """ + + def _get_tensor_slice(x: Any): + if isinstance(x, torch.Tensor): + return x[start : start + micro_batch_size] + return x + + return tree_map(_get_tensor_slice, batch) + + +def model_forward(model: Module, data: Any, internal_inputs: Optional[dict]) -> Any: + """Call model forward function with data and internal inputs. + + Args: + model (Module): Model to be called. + data (Any): Data loaded from data iterator. + internal_inputs (Optional[dict]): Data from previous stage. It must be a dict or None if it's the first stage. + + Returns: + Any: Outputs of the model. + """ + if internal_inputs is None: + internal_inputs = {} + if isinstance(data, (list, tuple)): + return model(*data, **internal_inputs) + elif isinstance(data, dict): + return model(**data, **internal_inputs) + return model(data, **internal_inputs) + + +def retain_grad(x: Any) -> None: + """Call retain_grad() on a tensor. + + Args: + x (Any): Object to be called. + """ + if isinstance(x, torch.Tensor) and x.requires_grad: + x.retain_grad() + + +def detach(x: Any) -> Any: + """Call detach() on a tensor. + + Args: + x (Any): Object to be called. + + Returns: + Any: The detached object. + """ + if isinstance(x, torch.Tensor): + return x.detach() + return x + + +def merge_batch(data: List[Any], batch_size_dim=0) -> Any: + """Merge micro batches into a batch. + + Args: + data (List[Any]): A list of micro batches. + + Returns: + Any: Merge batch. + """ + if len(data) == 0: + return + flattened_data = [] + tree_spec = None + for d in data: + # elems should be an instance of OrderedDict + elems, tree_spec = tree_flatten_hf(d) + flattened_data.append(elems) + merged_data = [] + + for elem_batch in zip(*flattened_data): + if isinstance(elem_batch[0], torch.Tensor): + if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs + merged_data.append(None) + else: + merged_data.append(torch.cat(elem_batch, dim=batch_size_dim)) + else: + merged_data.append(list(elem_batch)) + return tree_unflatten(merged_data, tree_spec) diff --git a/colossalai/pipeline/schedule/base.py b/colossalai/pipeline/schedule/base.py new file mode 100644 index 0000000000000000000000000000000000000000..1bce297862c830f604f4a1a1dfd45244daa07cc2 --- /dev/null +++ b/colossalai/pipeline/schedule/base.py @@ -0,0 +1,36 @@ +from typing import Any, Callable, Iterable, Optional + +from torch import Tensor +from torch.nn import Module + +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class PipelineSchedule: + def __init__(self, stage_manager: PipelineStageManager) -> None: + self.stage_manager = stage_manager + + def forward_backward_step( + self, + model: Module, + data_iter: Iterable, + criterion: Callable[[Any, Any], Tensor], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> dict: + """Forward and backward step for pipeline training. + + Args: + model (Module): Model to be trained. + data_iter (Iterable): Data iterator. + criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. + optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None. + return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. + return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. + + Returns: + dict: A dict with keys: 'loss' and 'outputs'. + """ + raise NotImplementedError diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py new file mode 100644 index 0000000000000000000000000000000000000000..780437155c617b06b6c142f265453ba8c464bca8 --- /dev/null +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -0,0 +1,380 @@ +from functools import partial +from typing import Any, Callable, Iterable, List, Optional, Union + +import torch +import torch.cuda +from torch.nn import Module +from torch.utils._pytree import tree_map + +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.utils.cuda import get_current_device + +from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device +from .base import PipelineSchedule + + +class InterleavedSchedule(PipelineSchedule): + def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None: + self.num_model_chunks = num_model_chunks + assert ( + num_microbatches % self.num_model_chunks == 0 + ), "Number of microbatches should be an integer multiple of number of model chunks" + super().__init__(stage_manager) + self.comm = PipelineP2PCommunication(stage_manager) + self.num_microbatches = num_microbatches + self.batch: Optional[Any] = None + self.batch_size: Optional[int] = None + self.microbatch_offset: Optional[int] = None + self.microbatch_size: Optional[int] = None + + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: + """Load a batch from data iterator. + + Args: + data_iter (Iterable): Data iterator. + device (Optional[torch.device], optional): Target device. Defaults to None. + """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + self.batch = batch + self.batch_size = get_batch_size(batch) + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + + def load_micro_batch(self, model_chunk_id: int) -> Any: + """Load a micro batch from the current batch. + + Args: + microbatch_id (int): the current model chunk idx. + + Returns: + Any: Micro batch. + """ + micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return tree_map(partial(to_device, device=get_current_device()), micro_batch) + + def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int: + """Helper method to get the model chunk ID given the iteration number. + + Args: + microbatch_id (int): the current microbatch idx + forward (bool): if is the forward process + + Returns: + int: The model chunk idx of the input microbatch_id + """ + microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks) + model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages + if not forward: + model_chunk_id = self.num_model_chunks - model_chunk_id - 1 + return model_chunk_id + + def is_first_stage(self, model_chunk_id: int) -> bool: + """Is the current virtual stage the first stage + + Args: + model_chunk_id (int): The current model chunk idx. + + Returns: + bool: Whether the current virtual stage is the first stage. + """ + if self.stage_manager.is_first_stage() and model_chunk_id == 0: + return True + return False + + def is_last_stage(self, model_chunk_id: int) -> bool: + """Is the current virtual stage the last stage + + Args: + model_chunk_id (int): The current model chunk idx. + + Returns: + bool: Whether the current virtual stage is the last stage. + """ + if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1: + return True + return False + + def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if self.is_first_stage(model_chunk_id): + input_tensor = None + else: + input_tensor = self.comm.recv_forward(prev_rank) + + return input_tensor + + def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if self.is_last_stage(model_chunk_id): + output_tensor_grad = None + else: + output_tensor_grad = self.comm.recv_backward(next_rank) + + return output_tensor_grad + + def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.is_last_stage(model_chunk_id): + self.comm.send_forward(output_object, next_rank) + + def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not self.is_first_stage(model_chunk_id): + self.comm.send_backward(input_object, prev_rank) + + def forward_step( + self, + model_chunk: Module, + model_chunk_id: int, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ) -> Union[torch.Tensor, dict]: + """Forward one step of the pipeline + Args: + model (Module): Model Chunk to be run + input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. + criterion (Callable): Criterion to calculate loss. + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) + + # for the first stage, input_obj is None + # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + + if self.is_last_stage(model_chunk_id): + loss = criterion(output_obj, micro_batch) / self.num_microbatches + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map(detach, output_obj)) + return loss + else: + return output_obj + + def backward_step( + self, + optimizer: OptimizerWrapper, + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ) -> Optional[dict]: + """Backward one step of the pipeline + + Args: + optimizer (OptimizerWrapper): Optimizer to update the model + input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None. + output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor). + output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None. + + Returns: + Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None. + """ + + # Retain the grad on the input_obj. + tree_map(retain_grad, input_obj) + + # Backward pass. + if output_obj_grad is None: + optimizer.backward(output_obj) + else: + if "backward_tensor_keys" not in output_obj: + for k, grad in output_obj_grad.items(): + optimizer.backward_by_grad(output_obj[k], grad) + else: + for k, grad in output_obj_grad.items(): + output_obj[k].grad = grad + for k in output_obj["backward_tensor_keys"]: + tensor_to_backward = output_obj[k] + optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad) + + # Collect the grad of the input_obj. + input_obj_grad = None + if input_obj is not None: + input_obj_grad = {} + for k, v in input_obj.items(): + if isinstance(v, torch.Tensor) and v.grad is not None: + input_obj_grad[k] = v.grad + return input_obj_grad + + def forward_backward_step( + self, + model_chunk: Module, + data_iter: Iterable, + criterion: Callable[..., Any], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> dict: + """Runs interleaved 1F1B schedule, with communication between pipeline stages. + + Args: + model_chunk (List[Module]): Model Chunk to be trained. + data_iter (Iterable): Data iterator. + criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. + optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None. + return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. + return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. + + Returns: + dict: A dict with keys: 'loss' and 'outputs'. + """ + forward_only = not torch.is_grad_enabled() + if optimizer is None: + assert forward_only, "Optimizer should be passed when doing backward." + + self.load_batch(data_iter) + num_model_chunks = len(model_chunk) + + # num_warmup_microbatches is the step when not all the processes are working + num_microbatches = self.num_microbatches * num_model_chunks + if forward_only: + num_warmup_microbatches = num_microbatches + else: + num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 + num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + + num_microbatches_remaining = num_microbatches - num_warmup_microbatches + + # Input, output tensors only need to be saved when doing backward passes + input_objs = None + output_objs = None + + if not forward_only: + input_objs = [[] for _ in range(num_model_chunks)] + output_objs = [[] for _ in range(num_model_chunks)] + + outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None + + if return_loss and self.stage_manager.is_last_stage(): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + + # for ranks except the first one, get into recv state + # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining) + input_obj = self.recv_forward(0) + input_objs[0].append(input_obj) + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + model_chunk_id = self.get_model_chunk_id(i, forward=True) + + # recv first on first rank to avoid sending or recving at the same time + if self.stage_manager.is_first_stage(): + input_obj = self.recv_forward(model_chunk_id) + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + self.send_forward(model_chunk_id, output_obj) + if not forward_only: + input_objs[model_chunk_id].append(input_obj) + output_objs[model_chunk_id].append(output_obj) + else: + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + if not forward_only: + output_objs[model_chunk_id].append(output_obj) + self.send_forward(model_chunk_id, output_obj) + if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches: + break + else: + model_chunk_id = self.get_model_chunk_id(i + 1, forward=True) + + input_obj = self.recv_forward(model_chunk_id) + if not forward_only: + input_objs[model_chunk_id].append(input_obj) + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) + last_iteration = i == (num_microbatches_remaining - 1) + + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + if forward_only: + self.send_forward(model_chunk_id, output_obj) + + if not last_iteration: + input_obj = self.recv_forward(model_chunk_id) + + else: + self.send_forward(model_chunk_id, output_obj) + # Add input_obj and output_obj to end of list. + input_objs[model_chunk_id].append(input_obj) + output_objs[model_chunk_id].append(output_obj) + + model_chunk_id = self.get_model_chunk_id(i, forward=False) + output_obj_grad = self.recv_backward(model_chunk_id) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + + # backward + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + else: + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True) + input_obj = self.recv_forward(model_chunk_id) + model_chunk_id = self.get_model_chunk_id(i, forward=False) + self.send_backward(model_chunk_id, input_obj_grad) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_microbatches_remaining, num_microbatches): + model_chunk_id = self.get_model_chunk_id(i, forward=False) + # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}") + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + + output_obj_grad = self.recv_backward(model_chunk_id) + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + self.send_backward(model_chunk_id, input_obj_grad) + + if outputs is not None: + outputs = merge_batch(outputs) + return {"loss": accum_loss, "outputs": outputs} diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py new file mode 100644 index 0000000000000000000000000000000000000000..4eaf135fd5db4026755c9dca6baa6c5a34678fad --- /dev/null +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -0,0 +1,330 @@ +from functools import partial +from typing import Any, Callable, Iterable, List, Optional, Union + +import torch +import torch.cuda +from torch.nn import Module +from torch.utils._pytree import tree_map + +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.utils.cuda import get_current_device + +from ._utils import ( + detach, + get_batch_size, + get_micro_batch, + merge_batch, + model_forward, + retain_grad, + to_device, + tree_map_hf, +) +from .base import PipelineSchedule + + +class OneForwardOneBackwardSchedule(PipelineSchedule): + def __init__( + self, + stage_manager: PipelineStageManager, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + ) -> None: + """1F1B pipeline schedule. + + Args: + stage_manager (PipelineStageManager): Pipeline stage manager + num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None. + microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None. + """ + super().__init__(stage_manager) + assert ( + num_microbatches is not None or microbatch_size is not None + ), "Either num_microbatches or microbatch_size should be provided" + self.comm = PipelineP2PCommunication(stage_manager) + self.num_microbatches = num_microbatches + self.microbatch_size = microbatch_size + self.batch: Optional[Any] = None + self.batch_size: Optional[int] = None + self.microbatch_offset: Optional[int] = None + self._use_microbatch_size = num_microbatches is None + + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: + """Load a batch from data iterator. + + Args: + data_iter (Iterable): Data iterator. + device (Optional[torch.device], optional): Target device. Defaults to None. + """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + self.batch = batch + self.batch_size = get_batch_size(batch) + self.microbatch_offset = 0 + if not self._use_microbatch_size: + assert ( + self.batch_size % self.num_microbatches == 0 + ), "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + else: + assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" + self.num_microbatches = self.batch_size // self.microbatch_size + + def load_micro_batch(self) -> Any: + """Load a micro batch from the current batch. + + Returns: + Any: Micro batch. + """ + micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) + self.microbatch_offset += self.microbatch_size + return tree_map(partial(to_device, device=get_current_device()), micro_batch) + + def recv_forward(self, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For 1F1B. + + Args: + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if self.stage_manager.is_first_stage(): + input_tensor = None + else: + input_tensor = self.comm.recv_forward(prev_rank) + + return input_tensor + + def recv_backward(self, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For 1F1B. + + Args: + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if self.stage_manager.is_last_stage(): + output_tensor_grad = None + else: + output_tensor_grad = self.comm.recv_backward(next_rank) + + return output_tensor_grad + + def send_forward(self, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + For 1F1B. + + Args: + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.stage_manager.is_last_stage(): + self.comm.send_forward(output_object, next_rank) + + def send_backward(self, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + For 1F1B. + + Args: + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not self.stage_manager.is_first_stage(): + self.comm.send_backward(input_object, prev_rank) + + def forward_step( + self, + model: Module, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ) -> Union[torch.Tensor, dict]: + """Forward one step of the pipeline + + Args: + model (Module): Model to be run + input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. + criterion (Callable): Criterion to calculate loss. + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + micro_batch = self.load_micro_batch() + # for the first stage, input_obj is None + # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict + output_obj = model_forward(model, micro_batch, input_obj) + if self.stage_manager.is_last_stage(): + loss = criterion(output_obj, micro_batch) / self.num_microbatches + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map_hf(detach, output_obj)) + return loss + else: + return output_obj + + def backward_step( + self, + optimizer: OptimizerWrapper, + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ) -> Optional[dict]: + """Backward one step of the pipeline + + Args: + optimizer (OptimizerWrapper): Optimizer to update the model + input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None. + output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor). + output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None. + + Returns: + Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None. + """ + + # Retain the grad on the input_obj. + tree_map(retain_grad, input_obj) + # Backward pass. + if output_obj_grad is None: + optimizer.backward(output_obj) + else: + if "backward_tensor_keys" not in output_obj: + for k, grad in output_obj_grad.items(): + optimizer.backward_by_grad(output_obj[k], grad) + else: + for k, grad in output_obj_grad.items(): + output_obj[k].grad = grad + for k in output_obj["backward_tensor_keys"]: + tensor_to_backward = output_obj[k] + optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad) + + # Collect the grad of the input_obj. + input_obj_grad = None + if input_obj is not None: + input_obj_grad = {} + for k, v in input_obj.items(): + if isinstance(v, torch.Tensor) and v.grad is not None: + input_obj_grad[k] = v.grad + return input_obj_grad + + def forward_backward_step( + self, + model: Module, + data_iter: Iterable, + criterion: Callable[..., Any], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> dict: + """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. + + Args: + model (Module): Model to be trained. + data_iter (Iterable): Data iterator. + criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. + optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None. + return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. + return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. + + Returns: + dict: A dict with keys: 'loss' and 'outputs'. + """ + forward_only = not torch.is_grad_enabled() + if optimizer is None: + assert forward_only, "Optimizer should be passed when doing backward." + + self.load_batch(data_iter) + + # num_warmup_microbatches is the step when not all the processes are working + num_warmup_microbatches = self.stage_manager.num_stages - self.stage_manager.stage - 1 + num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) + num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches + + # Input, output tensors only need to be saved when doing backward passes + input_objs = None + output_objs = None + + if not forward_only: + input_objs = [] + output_objs = [] + + outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None + if return_loss and self.stage_manager.is_last_stage(): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + input_obj = self.recv_forward() + + output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) + + self.send_forward(output_obj) + + if not forward_only: + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Before running 1F1B, need to receive first forward tensor. + # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_microbatches_remaining > 0: + input_obj = self.recv_forward() + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + last_iteration = i == (num_microbatches_remaining - 1) + + output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) + if forward_only: + self.send_forward(output_obj) + + if not last_iteration: + input_obj = self.recv_forward() + + else: + # TODO adjust here + self.send_forward(output_obj) + output_obj_grad = self.recv_backward() + + # Add input_obj and output_obj to end of list. + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + else: + input_obj = self.recv_forward() + self.send_backward(input_obj_grad) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_warmup_microbatches): + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + output_obj_grad = self.recv_backward() + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + self.send_backward(input_obj_grad) + + if outputs is not None: + if isinstance(model, ModelWrapper): + model = model.unwrap() + outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0)) + return {"loss": accum_loss, "outputs": outputs} diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..b79867a2c6519524ea5ec56ece6f55706609e758 --- /dev/null +++ b/colossalai/pipeline/stage_manager.py @@ -0,0 +1,133 @@ +from typing import Dict, List, Optional, Tuple + +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.cluster import ProcessGroupMesh + + +class PipelineStageManager: + """PipelineStageManager is a helper class to manage pipeline stages. + + Args: + pg_mesh (ProcessGroupMesh): Process group mesh. + pipeline_axis (int): The axis along which the pipeline is constructed. + + Attributes: + num_stages (int): Number of stages in the pipeline. + stage (int): The current stage. + """ + + def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None: + self.pg_mesh = pg_mesh + self.pipeline_axis = pipeline_axis + self.prev_rank: Optional[Tuple[int, ...]] = None + self.next_rank: Optional[Tuple[int, ...]] = None + self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} + # init prev and next coord + coord = self.pg_mesh.coordinate() + # the prev rank of rank0 is the last rank + prev_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1 :] + self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode="wrap") + # the next rank of the last rank is rank0 + next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] + self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") + + # init p2p process groups + stages = list(range(self.num_stages)) + for prev, cur in zip(stages[:-1], stages[1:]): + group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur]) + if self.stage in [prev, cur]: + ranks_in_group = self.pg_mesh.get_ranks_in_group(group) + self.p2p_groups[tuple(ranks_in_group)] = group + + if is_virtual: + # add the process group of the first rank and the last rank + # only used in interleaved pipeline for now + group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]]) + if self.stage in [stages[0], stages[-1]]: + ranks_in_group = self.pg_mesh.get_ranks_in_group(group) + self.p2p_groups[tuple(ranks_in_group)] = group + + def is_first_stage(self) -> bool: + """Is the current stage the first stage. + + Returns: + bool: Whether the current stage is the first stage. + """ + return self.stage == 0 + + def is_last_stage(self) -> bool: + """Is the current stage the last stage. + + Returns: + bool: Whether the current stage is the last stage. + """ + return self.stage == self.num_stages - 1 + + @property + def num_stages(self) -> int: + """Number of stages in the pipeline. + + Returns: + int: Number of stages in the pipeline. + """ + return self.pg_mesh.size(self.pipeline_axis) + + @property + def stage(self) -> int: + """Current stage. + + Returns: + int: Current stage. + """ + return self.pg_mesh.coordinate(self.pipeline_axis) + + def get_rank(self) -> int: + """Get the rank of the current process. + + Returns: + int: Rank of the current process. + """ + return dist.get_rank() + + def get_prev_rank(self) -> int: + """Get the rank of the previous stage. + + Returns: + int: Rank of the previous stage. + """ + return self.prev_rank + + def get_next_rank(self) -> int: + """Get the rank of the next stage. + + Returns: + int: Rank of the next stage. + """ + return self.next_rank + + def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: + """Get the p2p process group between two ranks. The order of the two ranks does not matter. + + Args: + first_rank (int): The first rank. + second_rank (int): The second rank. + + Returns: + ProcessGroup: P2P process group between the two ranks. + """ + if first_rank > second_rank: + first_rank, second_rank = second_rank, first_rank + return self.p2p_groups[(first_rank, second_rank)] + + def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup: + """Get the process group of the given stages. + + Args: + stages (List[int]): List of stages. + + Returns: + ProcessGroup: Process group of the given stages. + """ + return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) diff --git a/colossalai/pipeline/utils.py b/colossalai/pipeline/utils.py deleted file mode 100644 index df7226644a7a9931ecd82c09a39653fc5e0cfdeb..0000000000000000000000000000000000000000 --- a/colossalai/pipeline/utils.py +++ /dev/null @@ -1,275 +0,0 @@ -import heapq -import inspect -import torch - -from colossalai.logging import get_dist_logger -from colossalai.nn.layer.utils import CheckpointModule -from typing import List - -from collections import OrderedDict - -def _binary_partition(weights: List, start: int, end: int): - """Returns the binary partition position of `weights`, given the start - position `st` and the end position `ed`. - - Args: - weights (list): A python list to be binary partitioned - start (int): the start position of the binary partition - end (int): the end position of the binary partition - - Returns: - int: the binary partition position of `weights` - """ - w_sum = weights[end - 1] - prefix = 0 - if start > 0: - w_sum -= weights[start - 1] - prefix = weights[start - 1] - minimum = float("inf") - for idx in range(start + 1, end): - front = weights[idx - 1] - prefix - diff = abs(w_sum - 2 * front) - if diff < minimum: - pos = idx - minimum = diff - - return start, pos, end - - -def _heap_addition(weights: List, intervals: int, add_cnt: int): - """ - """ - - def _heap_push(heap, st, ed): - value = weights[ed - 1] - if st > 0: - value -= weights[st - 1] - heapq.heappush(heap, (-value, st, ed)) - - ret_intervals = [] - heap = [] - - for st, ed in intervals: - _heap_push(heap, st, ed) - - while add_cnt > 0: - _, st, ed = heapq.heappop(heap) - if ed - st == 1: - ret_intervals.append((st, ed)) - else: - l, m, r = _binary_partition(weights, st, ed) - _heap_push(heap, l, m) - _heap_push(heap, m, r) - add_cnt -= 1 - - while heap: - _, st, ed = heapq.heappop(heap) - ret_intervals.append((st, ed)) - - ret_intervals.sort() - return ret_intervals - - -def _calc_partitions(weights, value): - prev = 0 - prefix = 0 - num_block = 0 - intervals = [] - - for idx, w in enumerate(weights): - if weights[idx] - prefix > value: - intervals.append((prev, idx)) - prev = idx - prefix = weights[idx - 1] - num_block += 1 - - intervals.append((prev, len(weights))) - return num_block + 1, intervals - - -def _binary_search(weights, num): - length = len(weights) - prefix = [1 if w == 0 else w for w in weights] - for i in range(1, length): - prefix[i] += prefix[i - 1] - - lower_bound = max(weights) - upper_bound = prefix[length - 1] - - while upper_bound > lower_bound: - mid = (upper_bound + lower_bound) // 2 - number, _ = _calc_partitions(prefix, mid) - if number <= num: - upper_bound = mid - else: - lower_bound = mid + 1 - - num_block, intervals = _calc_partitions(prefix, upper_bound) - if num_block < num: - intervals = _heap_addition(prefix, intervals, num - num_block) - - return intervals - - -def partition_uniform(num_items, pipeline_parallel_size, num_chunks): - assert num_items % num_chunks == 0, \ - "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" - - logger = get_dist_logger() - parts = [[] for _ in range(pipeline_parallel_size)] - partition_items = num_items // num_chunks - for idx in range(num_chunks): - base_idx = idx * partition_items - chunk_size = partition_items // pipeline_parallel_size - left = pipeline_parallel_size - partition_items % pipeline_parallel_size - if chunk_size == 0: - logger.warning("Some nodes in Pipeline have no requests") - - for p in range(pipeline_parallel_size): - st = base_idx - base_idx += chunk_size + (p >= left) - parts[p].append((st, base_idx)) - - return parts - - -def partition_balanced(weights, pipeline_parallel_size, num_chunks): - num_total = pipeline_parallel_size * num_chunks - num_items = len(weights) - if num_items <= num_total: - return partition_uniform(num_items, pipeline_parallel_size, num_chunks) - - intervals = _binary_search(weights, num_total) - - current = 0 - parts = [[] for _ in range(pipeline_parallel_size)] - for inter in intervals: - parts[current].append(inter) - current = (current + 1) % pipeline_parallel_size - - return parts - - -def build_kwargs_for_module(function, input_tensor, kw_dict): - """ - Generally, the first argument of module.forward is an input tensor come from the previous layer. - Therefore, we just filter the kwargs from second element of the dictionary. - """ - sig = inspect.signature(function) - if input_tensor is None: - kwargs_offset = 0 - elif isinstance(input_tensor, torch.Tensor): - kwargs_offset = 1 - elif isinstance(input_tensor, (tuple, OrderedDict)): - #assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' - # Huggingface will take their own structures based on OrderedDict as the output - # between layers so we've to close this check. - kwargs_offset = len(input_tensor) - args_name_list = list(sig.parameters.keys()) - kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]} - if len(kw_dict) == 0: - return None - return kw_dict - - -def build_kwargs_for_function(function, kw_dict): - sig = inspect.signature(function) - kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters} - if len(kw_dict) == 0: - return None - return kw_dict - - -def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs): - """ - We suppose the callable object passed to to_layer_list method in two purpose: - a. use the callable object to modify input tensor, such as \ - lambda x: torch.flatten(x, 1) - b. use the callable object to modify kwargs value, such as \ - def foo(attention_mask=None): - if attention_mask is not None: - batch_size = input_ids.shape[0] - attention_mask = attention_mask.view(batch_size, -1) - return attention_mask - """ - - if kw_dict is not None: - rst = func(**kw_dict) - if isinstance(rst, tuple): - for i, k in enumerate(kw_dict.keys()): - kwargs[k] = rst[i] - else: - for k in kw_dict.keys(): - kwargs[k] = rst - return input_tensor - if isinstance(input_tensor, tuple): - assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.' - sig = inspect.signature(func) - func_args_num = len(sig.parameters) - assert func_args_num <= len( - input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.' - if func_args_num < len(input_tensor): - return func(*input_tensor[:func_args_num]) - else: - return func(*input_tensor) - assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.' - return func(input_tensor) - - -def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs): - - assert func_key in func_dict, f"{func_key} is not in the function_dict." - funcs_to_exec = func_dict[func_key] - if isinstance(funcs_to_exec, list): - for f in funcs_to_exec: - f_kwargs = build_kwargs_for_function(f, kwargs) - input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs) - else: - f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs) - input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs) - - return input_tensor - - -def call_module(module, args=None, kwargs=None): - if args is None: - args = () - if kwargs is None: - kwargs = {} - if isinstance(module, CheckpointModule): - forward_func = module._forward - else: - forward_func = module.forward - sig = inspect.signature(forward_func) - param_nums = len(sig.parameters) - feed_nums = len(args) + len(kwargs) - args_needed_nums = param_nums - len(kwargs) - args_needed = args[:args_needed_nums] - if isinstance(module, CheckpointModule): - convert_kwargs_to_args = [] - for v in kwargs.values(): - convert_kwargs_to_args.append(v) - return module(*args_needed, *convert_kwargs_to_args) - else: - return module(*args_needed, **kwargs) - - -def customized_partition(exec_seq): - ''' - This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an - annotation to note the partition point. - ''' - customized_parts = {} - start = 0 - stop = 0 - rank = 0 - for element in exec_seq: - if isinstance(element, str): - if element == 'SPLIT_NODE': - customized_parts[rank] = [(start, stop)] - start = stop - rank += 1 - else: - stop += 1 - customized_parts[rank] = [(start, stop)] - return customized_parts diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4bd7d5208a64c5600e0c9d2ade2734bb6c27a5ff --- /dev/null +++ b/colossalai/shardformer/README.md @@ -0,0 +1,464 @@ +# ⚡️ ShardFormer + +## 📚 Table of Contents + +- [⚡️ ShardFormer](#️-shardformer) + - [📚 Table of Contents](#-table-of-contents) + - [🔗 Introduction](#-introduction) + - [🔨 Usage](#-usage) + - [Quick Start](#quick-start) + - [Write your own policy](#write-your-own-policy) + - [🗺 Roadmap](#-roadmap) + - [💡 API Design](#-api-design) + - [Distributed Modules](#distributed-modules) + - [Shard Config](#shard-config) + - [Policy](#policy) + - [Model Sharder](#model-sharder) + - [User-facing API](#user-facing-api) + - [⌨️ Development Notes](#️-development-notes) + - [Add New Policy to Shardformer](#add-new-policy-to-shardformer) + - [Write Your Unit Testing](#write-your-unit-testing) + - [📊 Benchmarking](#-benchmarking) + - [System Performance](#system-performance) + - [Convergence](#convergence) + +## 🔗 Introduction + +**Shardformer** is a module that automatically parallelizes the mainstream models in libraries such as HuggingFace and TIMM. This module aims to make parallelization hassle-free for users who are not from the system background. + +## 🔨 Usage + +### Quick Start + +The sample API usage is given below(If you enable the use of flash attention, please install `flash_attn`. In addition, xformers's `cutlass_op` provide a supplementary optimization): + +```python +from colossalai.shardformer import ShardConfig, ShardFormer +from transformers import BertForMaskedLM +import colossalai + +# launch colossalai +colossalai.launch_from_torch(config={}) + +# create model +config = BertConfig.from_pretrained('bert-base-uncased') +model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config) + +# create huggingface model as normal +shard_config = ShardConfig(tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, + enable_tensor_parallelism=True, + enable_fused_normalization=True, + enable_flash_attention=True, + enable_jit_fused=True, + enable_sequence_parallelism=True, + enable_sequence_overlap=True) + +shard_former = ShardFormer(shard_config=shard_config) +sharded_model, shared_params = shard_former.optimize(model).to('cuda') + +# do everything like normal +... +``` + +Following are the description `ShardConfig`'s arguments: + +- `tensor_parallel_process_group`: The process group of tensor parallelism, it's necessary when using tensor parallel. Defaults to None, which is the global process group. + +- `pipeline_stage_manager`: If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism. + +- `enable_tensor_parallelism`: Whether to use tensor parallelism. Defaults to True. + +- `enable_fused_normalization`: Whether to use fused layernorm. Defaults to False. + +- `enable_flash_attention`: Whether to switch on flash attention. Defaults to False. + +- `enable_jit_fused`: Whether to switch on JIT fused operators. Defaults to False. + +- `enable_sequence_parallelism`: Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. + +- `enable_sequence_overlap`: Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False. + +- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False. + +- `inference_only`: Whether only doing forward passing. Defaults to False. + +### Write your own policy + +If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design). + +```python +from colossalai.shardformer import Policy + +class MyPolicy(Policy): + # implement your own policy + ... + +# init model and shard former +... + +# use customized policy to shard model +my_policy = MyPolicy() +shard_former.optimize(model, my_policy) + + + +``` + +## 🗺 Roadmap + +We will follow this roadmap to develop Shardformer: + +- [x] API Design +- [x] API Implementation +- [x] Unit Testing +- [ ] Policy Implementation + +| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap | +| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: | +| bert | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| t5 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| llama V1/V2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| gpt2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| opt | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] | +| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | + + +## 💡 API Design + +We will discuss the major components of `ShardFormer` below to help you better understand how things work. +This section serves as the design doc for Shardformer and the function signature might differ from the actual implementation. +Please refer to the code for more details. + +

        + +
        +

        + +### Distributed Modules + +`ShardFormer` replaces the original PyTorch module with a distributed module. +The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation. +Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module. + +````python +class ParallelModule(torch.nn.Module): + + @abstractmethod + def from_native_module(module: torch.nn.Module, process_group: Union[ProcessGroup, Tuple[ProcessGroup]]) -> ParallelModule + """ + Convert a native module to a parallelized + + Examples: + + ```python + # replace module + my_linear = Linear1D_Col.from_native_module(my_linear, process_group) + ``` + """ +```` + +### Shard Config + +`ShardConfig` is a simple data class to tell `ShardFormer` how sharding will be performed. + +```python +@dataclass +class ShardConfig: + tensor_parallel_process_group: ProcessGroup = None + enable_fused_normalization: bool = False + ... + + # Some possible future config fields + tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode + inference_only: bool # only inject inference-suitable sharding policy + use_flash_attention: bool # whether to use flash attention to speed up attention +``` + +### Policy + +The `Policy` class describes how to handle the model sharding. +It is merely a description, the actual sharding will be performed by `ModelSharder`. +We abstract the policy into four stages: + +1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding +2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` to tell `ModelSharder` how the submodules's attributes, child parameters, and deeper submodules will be substituted. +3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model. + +```python +@dataclass +class ModulePolicyDescription: + r""" + Describe how the attributes and parameters will be transformed in a policy. + + Args: + attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding + param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive only one arguments: module. + sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription + object which specifies the module to be replaced and the target module used to replacement. + method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement + """ + attribute_replacement: Dict[str, Any] = None + param_replacement: List[Callable] = None + sub_module_replacement: List[SubModuleReplacementDescription] = None + method_replacement: Dict[str, Callable] = None + +@dataclass +class SubModuleReplacementDescription: + r""" + Describe how a submodule will be replaced + + Args: + suffix (str): used to get the submodule object + target_module (ParallelModule): specifies the module class used to replace to submodule + kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. + ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception + """ + suffix: str + target_module: ParallelModule + kwargs: Dict[str, Any] = None + ignore_if_not_exist: bool = False + + +class Policy(ABC): + + def __init__(self) + self.model = None + + def set_model(self, model: nn.Module) -> None: + """ + Set model as an attribute of the Policy object so that we can access the model's attributes. + """ + self.model = model + + @abstractmethod + def preprocess(self) -> nn.Module: + """ + Perform some preprocessing on the model, such as resizing the embedding size + """ + ... + + @abstractmethod + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + """ + Return the dict for the modify policy, the key is the original layer class and the value is the + argument for the modify layer + """ + ... + + @abstractmethods + def postprocess(self) -> nn.Module: + """ + Perform some postprocessing on the model, such as binding the embedding with the weight of the classifier head + """ + ... +``` + +### Model Sharder + +`ModelSharder` is the class in charge of sharding the model based on the given policy. + +```python +class ModelSharder: + + def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None): + #TODO: input is a cls or a obj + ... + + def shard(self) -> None: + """ + Shard model with parallelism with the help of pre-processing, replace_model_class, replace_module, and post-processing. + """ + ... + + def replace_module(self) -> None: + """ + Replace the layer according to the policy. Call Policy.module_policy() to get the module. Call _replace_module recursively. + """ + ... +``` + +### User-facing API + +We only expose a limited number of APIs to the user to keep their user experience simple and clean. + +```python +class ShardFormer: + """ + Parallelize model based on the given config and policy + + Example: + + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') + shard_config = ShardConfig() + shard_former = ShardFormer(shard_config=shard_config) + model, shared_params = shard_former.optimize(org_model) + + """ + + def __init__(self, shard_config: ShardConfig): + """ + Do two things: + 1. Create a distribute coordinator + 2. serve as a store for shard config + """ + self.shard_config = shard_config + self.coordinator = DistCoordinator() + + def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]: + r""" + This method will optimize the model based on the given policy. + + Args: + model (`torch.nn.Model`): the origin huggingface model + shard_config (`ShardConfig`): the config for distribute information + policy (`Policy`): the custom policy for sharding + + Returns: the sharded model and the shared parameters + """ + sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy) + shared_params = sharder.shard() + return model, shared_params +``` + +## ⌨️ Development Notes + +### Add New Policy to Shardformer + +This section serves as the guideline for writing new policies and register them into `shardformer`. + +- Step 1. Write your own model policy + +You can create a new file in the `colossalai/shardformer/policies` folder and name the file with the model name. You can implement your policy in this file. You should not import the any model zoo library at the header section of the file because we do not want to import the library when we do not use the policy. Libraries such as `transformers` should be imported only in the function body when needed. + +Please follow the following protocols when writing your policy: + +- You have to make a clear decision what you want to replace exactly in the original PyTorch module + - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes + - Use `ModulePolicyDescription.param_replacement` to replace the module parameters + - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement the `from_native_module` for the replacement. + - Use `ModulePolicyDescription.method_replacement` to replace the module methods. **These replacement methods should be put in the `shardformer/modeling/.py`**. +- You can implement the `ParallelModule` for primitive modules in the `shardformer/layer/.py` file. Primitive modules refer to modules which are not composed of other modules. For example, the `torch.nn.Linear` module is a primitive module while modules such as `BertEncoder` module in the `transformers` library is a composite module. Primitive modules do not nested inner `nn.Module` members. For composite modules, you should consider using `ModulePolicyDescription` to implement your replacement. +- `ParallelModule` is meant to be used in two ways: `ParallelModule.from_native_module` to convert native PyTorch module to the `ParallelModule` and `ParallelModule(...)` to instantiate the module directly just like a normal PyTorch module. `ParallelModule` should be only implemented for modules whose weights are sharded. If you want to make your module compatible with the `ModulePolicyDescription.sub_module_replacement` and there is no weight sharding in your module, you can just implement the `from_native_module` method without inheriting the `ParallelModule` like `colossalai/shardformer/layer/normalization.py`. +- **Do not import any file in the `colossalai/shardformer/policies` and `colossalai/shardformer/modeling` to avoid unwanted import error**. For example, a file in these folders accidentally imports `transformers` library at the top of the file, then the user will have to install `transformers` library even if they do not use this file. Any file in the `modeling` folder should be only imported by the policy file. A policy implementation should be only imported dynamically via the autopolicy or manually via the `ShardFormer` module. +- Try to keep your import statement on third-party libraries such as `transformers` within the function body instead of the header section of the file. This is because we do not want to import the library when we do not use the policy. + +- Step 2. Register your policy to the autopolicy + +Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file. + +For example, if we register the policy for the BERT model, we just add a key-value in the `_POLICY_LIST` dictionary. The key if the `qualname` of the model object (you can get it by model.\_\_class\_\_.\_\_qualname\_\_). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when we do not use the policy. + +```python +_POLICY_LIST = { + # BERT + "transformers.models.bert.modeling_bert.BertModel": + PolicyLocation(file_name="bert", class_name="BertModelPolicy"), +} +``` + +### Write Your Unit Testing + +This section serves as the guideline for testing the `shardformer` module. + +- Step 1. Add your model to the model zoo in the test kits. + +Add your model to the `tests/kit/model_zoo` file. This allows you to define test-related components for this model. You can take `tests/kit/model_zoo/transformers/llama.py` as an example for reference. + +- Step 2. Write your unit testing for the model + +Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency. + +- Step 3. Execute your test + +When you run tests locally, you should run tests for both your newly-added test file and the whole `shardformer` module tests. + +```bash +# test for your own test file +pytest tests/test_shardformer/test_model/.py + +# test for the whole shardformer module +pytest tests/test_shardformer +``` + +## 📊 Benchmarking + +### System Performance + +We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate the performance improvement of Shardformer. We compared the training time between the original model and the shard model. + +We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length. + +In the case of using 2 GPUs, the training times are as follows. +| N_CTX | org_model | shard_model | +| :------: | :-----: | :-----: | +| 256 | 11.2ms | 17.2ms | +| 512 | 9.8ms | 19.5ms | +| 1024 | 19.6ms | 18.9ms | +| 2048 | 46.6ms | 30.8ms | +| 4096 | 160.5ms | 90.4ms | + + +

        + +
        +

        + +In the case of using 4 GPUs, the training times are as follows. + +| N_CTX | org_model | shard_model | +| :------: | :-----: | :-----: | +| 256 | 10.0ms | 21.1ms | +| 512 | 11.5ms | 20.2ms | +| 1024 | 22.1ms | 20.6ms | +| 2048 | 46.9ms | 24.8ms | +| 4096 | 160.4ms | 68.0ms | + + + +

        + +
        +

        + + +As shown in the figures above, when the sequence length is around 1000 or greater, the parallel optimization of Shardformer for long sequences starts to become evident. + +### Convergence + + +To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results. + +the configurations are as follows: +```python +batch_size = 2 +epoch = 3 +lr = 2.4e-5 +accumulation_steps = 8 +warmup_fraction = 0.03 +``` + + + +| accuracy | f1 | loss | GPU number | model sharded | +| :------: | :-----: | :-----: | :--------: | :---------: | +| 0.82971 | 0.87713 | 0.23194 | 4 | True | +| 0.83797 | 0.88006 | 0.22683 | 2 | True | +| 0.84521 | 0.88700 | 0.21822 | 1 | False | + + +Overall, the results demonstrate that using shardformers during model training does not affect the convergence. diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77c2af8d18f7db8641e13030913887af098e1722 --- /dev/null +++ b/colossalai/shardformer/__init__.py @@ -0,0 +1 @@ +from .shard import ShardConfig, ShardFormer diff --git a/colossalai/shardformer/_utils.py b/colossalai/shardformer/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..96d6cea21075a79d3aca4217875a71c31b0c55f3 --- /dev/null +++ b/colossalai/shardformer/_utils.py @@ -0,0 +1,112 @@ +import re + + +def get_obj_list_element(obj, attr: str): + r""" + Get the element of the list in the object + + If the attr is a normal attribute, return the attribute of the object. + If the attr is a index type, return the element of the index in the list, like `layers[0]`. + + Args: + obj (Object): The object to get + attr (str): The suffix of the attribute to get + + """ + re_pattern = r"\[\d+\]" + prog = re.compile(re_pattern) + result = prog.search(attr) + if result: + matched_brackets = result.group() + matched_index = matched_brackets.replace("[", "") + matched_index = matched_index.replace("]", "") + attr_ = attr.replace(matched_brackets, "") + container_obj = getattr(obj, attr_) + obj = container_obj[int(matched_index)] + else: + obj = getattr(obj, attr) + return obj + + +def set_obj_list_element(obj, attr: str, value): + r""" + Set the element to value of a list object + + It used like set_obj_list_element(obj, 'lyaers[0]', new_layer), it will set obj.layers[0] to value + + Args: + obj (object): The object to set + attr (str): the string including a list index like `layers[0]` + """ + re_pattern = r"\[\d+\]" + prog = re.compile(re_pattern) + result = prog.search(attr) + if result: + matched_brackets = result.group() + matched_index = matched_brackets.replace("[", "") + matched_index = matched_index.replace("]", "") + attr_ = attr.replace(matched_brackets, "") + container_obj = getattr(obj, attr_) + container_obj[int(matched_index)] = value + else: + setattr(obj, attr, value) + + +def hasattr_(obj, attr: str): + r""" + Check whether the object has the multi sublevel attr + + Args: + obj (object): The object to check + attr (str): The multi level attr to check + """ + attrs = attr.split(".") + for a in attrs: + try: + obj = get_obj_list_element(obj, a) + except AttributeError: + return False + return True + + +def setattr_(obj, attr: str, value, ignore: bool = False): + r""" + Set the object's multi sublevel attr to value, if ignore, ignore when it doesn't exist + + Args: + obj (object): The object to set + attr (str): The multi level attr to set + value (Any): The value to set + ignore (bool): Whether to ignore when the attr doesn't exist + """ + + attrs = attr.split(".") + for a in attrs[:-1]: + try: + obj = get_obj_list_element(obj, a) + except AttributeError: + if ignore: + return + raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}") + set_obj_list_element(obj, attrs[-1], value) + + +def getattr_(obj, attr: str, ignore: bool = False): + r""" + Get the object's multi sublevel attr + + Args: + obj (object): The object to set + attr (str): The multi level attr to set + ignore (bool): Whether to ignore when the attr doesn't exist + """ + + attrs = attr.split(".") + for a in attrs: + try: + obj = get_obj_list_element(obj, a) + except AttributeError: + if ignore: + return None + raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}") + return obj diff --git a/colossalai/shardformer/examples/convergence_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..b03e6201dce8e64f677032dd9e07e691d82ea62c --- /dev/null +++ b/colossalai/shardformer/examples/convergence_benchmark.py @@ -0,0 +1,180 @@ +import argparse +import math +from typing import Any, List, Union + +import evaluate +import torch +import torch.distributed as dist +from data import GLUEDataBuilder +from torch import nn +from torch.optim import Adam, Optimizer +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import BertConfig, BertForSequenceClassification, get_linear_schedule_with_warmup + +import colossalai +from colossalai.cluster import DistCoordinator +from colossalai.shardformer import ShardConfig, ShardFormer + + +def to_device(x: Any, device: torch.device) -> Any: + def _to(t: Any): + if isinstance(t, torch.Tensor): + return t.to(device) + return t + + return tree_map(_to, x) + + +def train(args): + colossalai.launch_from_torch(config={}, seed=42) + coordinator = DistCoordinator() + + # prepare for data and dataset + data_builder = GLUEDataBuilder( + model_name_or_path=args.pretrain, + task_name=args.task, + train_batch_size=args.batch_size, + eval_batch_size=args.batch_size, + ) + train_dataloader = data_builder.train_dataloader() + test_dataloader = data_builder.test_dataloader() + + if args.model == "bert": + cfg = BertConfig.from_pretrained(args.pretrain, num_labels=data_builder.num_labels) + model = BertForSequenceClassification.from_pretrained(args.pretrain, config=cfg) + + model.to(torch.cuda.current_device()) + + # if multiple GPUs, shard the model + if dist.get_world_size() > 1: + tp_group = dist.new_group(backend="nccl") + shard_config = ShardConfig( + tensor_parallel_process_group=tp_group, enable_tensor_parallelism=True, enable_all_optimization=True + ) + shard_former = ShardFormer(shard_config=shard_config) + model, _ = shard_former.optimize(model) + + optim = Adam(model.parameters(), lr=args.lr) + num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + lr_scheduler = get_linear_schedule_with_warmup( + optim, + num_warmup_steps=math.ceil(max_steps * args.warmup_fraction), + num_training_steps=max_steps, + ) + fit( + model, + optim, + lr_scheduler, + train_dataloader, + args.max_epochs, + args.accumulation_steps, + args.batch_size, + coordinator, + ) + results = evaluate_model( + model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, coordinator + ) + if coordinator.is_master(): + print(results) + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +def fit( + model: nn.Module, + optimizer: Optimizer, + scheduler, + train_dataloader, + max_epochs, + accumulation_steps, + batch_size, + coordinator, +): + step_bar = tqdm( + range(len(train_dataloader) // accumulation_steps * max_epochs), + desc=f"steps", + disable=not coordinator.is_master(), + ) + total_loss = 0 + for epoch in range(max_epochs): + model.train() + for batch_id, batch in enumerate(train_dataloader): + batch = to_device(batch, torch.cuda.current_device()) + outputs = model(**batch) + loss = outputs.loss + loss = loss / accumulation_steps + loss.backward() + total_loss += loss.item() + if (batch_id + 1) % accumulation_steps == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + step_bar.set_postfix( + {"epoch": epoch, "loss": total_loss / batch_size, "lr": scheduler.get_last_lr()[0]} + ) + total_loss = 0 + step_bar.update() + + +# evaluate +@torch.no_grad() +def evaluate_model( + model: nn.Module, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + coordinator: DistCoordinator, +): + metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) + model.eval() + + def evaluate_subset(dataloader: DataLoader): + accum_loss = torch.zeros(1, device=torch.cuda.current_device()) + for batch in dataloader: + batch = to_device(batch, torch.cuda.current_device()) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + labels = batch["labels"] + metric.add_batch(predictions=preds, references=labels) + + results = metric.compute() + if coordinator.is_master(): + results["loss"] = accum_loss.item() / (len(dataloader) * dataloader.batch_size) + return results + + if isinstance(test_dataloader, DataLoader): + return evaluate_subset(test_dataloader) + else: + assert len(test_dataloader) == len(eval_splits) + final_results = {} + for split, sub_loader in zip(eval_splits, test_dataloader): + results = evaluate_subset(sub_loader) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) + return final_results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument("--model", type=str, default="bert") + parser.add_argument("--pretrain", type=str, default="bert-base-uncased") + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--lr", type=float, default=2.4e-5) + parser.add_argument("--fused_layernorm", type=bool, default=False) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--warmup_fraction", type=float, default=0.03) + parser.add_argument("--target_f1", type=float, default=None) + args = parser.parse_args() + train(args) diff --git a/colossalai/shardformer/examples/convergence_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..22f13a7cf827f3780267dbe8d357465b7aa421f9 --- /dev/null +++ b/colossalai/shardformer/examples/convergence_benchmark.sh @@ -0,0 +1,9 @@ +torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \ + --model "bert" \ + --pretrain "bert-base-uncased" \ + --max_epochs 3 \ + --batch_size 2 \ + --lr 2.4e-5 \ + --fused_layernorm False \ + --accumulation_steps 8 \ + --warmup_fraction 0.03 diff --git a/colossalai/shardformer/examples/data.py b/colossalai/shardformer/examples/data.py new file mode 100644 index 0000000000000000000000000000000000000000..ddf44a874659a3cfffb768490a98a0ccf73db6ce --- /dev/null +++ b/colossalai/shardformer/examples/data.py @@ -0,0 +1,137 @@ +import datasets +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, PreTrainedTokenizer + +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase + + +class GLUEDataBuilder: + task_text_field_map = { + "cola": ["sentence"], + "sst2": ["sentence"], + "mrpc": ["sentence1", "sentence2"], + "qqp": ["question1", "question2"], + "stsb": ["sentence1", "sentence2"], + "mnli": ["premise", "hypothesis"], + "qnli": ["question", "sentence"], + "rte": ["sentence1", "sentence2"], + "wnli": ["sentence1", "sentence2"], + "ax": ["premise", "hypothesis"], + } + + glue_task_num_labels = { + "cola": 2, + "sst2": 2, + "mrpc": 2, + "qqp": 2, + "stsb": 1, + "mnli": 3, + "qnli": 2, + "rte": 2, + "wnli": 2, + "ax": 3, + } + + loader_columns = [ + "datasets_idx", + "input_ids", + "token_type_ids", + "attention_mask", + "start_positions", + "end_positions", + "labels", + ] + + def __init__( + self, + model_name_or_path: str, + plugin: DPPluginBase = None, + task_name: str = "mrpc", + max_seq_length: int = 128, + train_batch_size: int = 32, + eval_batch_size: int = 32, + **kwargs, + ): + super().__init__() + self.model_name_or_path = model_name_or_path + self.task_name = task_name + self.max_seq_length = max_seq_length + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size + self.plugin = plugin + + self.text_fields = self.task_text_field_map[task_name] + self.num_labels = self.glue_task_num_labels[task_name] + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + self.setup() + + def setup(self): + self.dataset = datasets.load_dataset("glue", self.task_name) + + for split in self.dataset.keys(): + self.dataset[split] = self.dataset[split].map( + self.convert_to_features, + batched=True, + remove_columns=["label"], + ) + self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] + self.dataset[split].set_format(type="torch", columns=self.columns) + + self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] + + def prepare_data(self): + datasets.load_dataset("glue", self.task_name) + AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + + def train_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) + + def val_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def test_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader(self.dataset["test"], batch_size=self.train_batch_size) + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def convert_to_features(self, example_batch): + # Either encode single sentence or sentence pairs + if len(self.text_fields) > 1: + texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) + else: + texts_or_text_pairs = example_batch[self.text_fields[0]] + + # Tokenize the text/text pairs + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) + + # Rename label to labels to make it easier to pass to model forward + features["labels"] = example_batch["label"] + + return features + + def native_prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False, pin_memory=False): + return DataLoader( + dataset, batch_size=batch_size, sampler=None, shuffle=shuffle, drop_last=drop_last, pin_memory=pin_memory + ) diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..81215dcdf5d442c65272c42a6f55ab9d4dbfdef3 --- /dev/null +++ b/colossalai/shardformer/examples/performance_benchmark.py @@ -0,0 +1,88 @@ +""" +Shardformer Benchmark +""" +import torch +import torch.distributed as dist +import transformers +import triton + +import colossalai +from colossalai.shardformer import ShardConfig, ShardFormer + + +def data_gen(batch_size, seq_length): + input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long) + attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_sequence_classification(batch_size, seq_length): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen(batch_size, seq_length) + data["labels"] = torch.ones((batch_size), dtype=torch.long) + return data + + +MODEL_CONFIG = transformers.LlamaConfig( + num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4, + max_position_embeddings=128, + num_labels=16, + pad_token_id=2, +) +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64 +model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG) + +# vary seq length for fixed head and batch=4 +configs = [ + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(8, 13)], + line_arg="provider", + line_vals=["org_model", "shard_model"], + line_names=["org_model", "shard_model"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"lama_for_sequence_classification-batch-{BATCH}", + args={"BATCH": BATCH, "dtype": torch.float16, "model_func": model_func}, + ) +] + + +def train(model, data): + output = model(**data) + loss = output.logits.mean() + loss.backward() + + +@triton.testing.perf_report(configs) +def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"): + warmup = 10 + rep = 100 + # prepare data + data = data_gen_for_sequence_classification(BATCH, N_CTX) + data = {k: v.cuda() for k, v in data.items()} + model = model_func().to(device) + model.train() + if provider == "org_model": + fn = lambda: train(model, data) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + if provider == "shard_model": + shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True) + shard_former = ShardFormer(shard_config=shard_config) + sharded_model, _ = shard_former.optimize(model) + sharded_model = sharded_model.cuda() + fn = lambda: train(sharded_model, data) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +# start benchmark, command: +# torchrun --standalone --nproc_per_node=2 performance_benchmark.py +if __name__ == "__main__": + colossalai.launch_from_torch({}) + bench_shardformer.run(save_path=".", print_data=dist.get_rank() == 0) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a134a2cbd21c815f82befaf7103464cc7228fe9a --- /dev/null +++ b/colossalai/shardformer/layer/__init__.py @@ -0,0 +1,23 @@ +from .dropout import DropoutForParallelInput, DropoutForReplicatedInput +from .embedding import Embedding1D, VocabParallelEmbedding1D +from .linear import Linear1D_Col, Linear1D_Row +from .loss import cross_entropy_1d +from .normalization import FusedLayerNorm, FusedRMSNorm +from .parallel_module import ParallelModule +from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row + +__all__ = [ + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + "GPT2FusedLinearConv1D_Col", + "GPT2FusedLinearConv1D_Row", + "DropoutForParallelInput", + "DropoutForReplicatedInput", + "cross_entropy_1d", + "FusedLayerNorm", + "FusedRMSNorm", + "FusedLinear1D_Col", + "ParallelModule", +] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec48096183b0428f07bb5b6a5ca0af9800643f7 --- /dev/null +++ b/colossalai/shardformer/layer/_operation.py @@ -0,0 +1,570 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F + +try: + import fused_mix_prec_layer_norm_cuda +except: + fused_mix_prec_layer_norm_cuda = None + + +class FusedLayerNormAffineFunction1D(torch.autograd.Function): + r"""Layernorm + + Args: + input: input matrix. + weight: weight matrix. + bias: bias matrix. + normalized_shape: input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability + """ + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) + + return grad_input, grad_weight, grad_bias, None, None + + +class MatmulWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_allreduce = async_grad_allreduce + + output = torch.matmul(input_, weight) + + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +class LinearWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_allreduce = async_grad_allreduce + + if bias is not None: + output = F.linear(input_, weight, bias) + else: + output = F.linear(input_, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + ctx.overlap = overlap + + input_parallel = _gather(input_, dim, process_group) + + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + overlap = ctx.overlap + + if not overlap: + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty( + input_.shape, dtype=input_parallel.dtype, device=input_parallel.device + ).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + else: + input_ = input_.contiguous() + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + + # do all gather in is async way + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + # calculate gradient and prepare data asynchronously with all-gather + # calculate + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + # prepare data + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + # wait until all-gather finished + gather_handle.wait() + + # do reduce-scatter in async way + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + # calculate gradient + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + grad_weight = grad_output.t().matmul(input_parallel) + # wait until reduce-scatter finished + reducescatter_handle.wait() + + return output, grad_weight, grad_bias, None, None, None, None + + +class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.dim = dim + ctx.process_group = process_group + + # do reduce-scatter + new_shape = list(input_.shape) + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + return _gather(grad_output, dim, process_group), None, None + + +class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """ + This class is designed for matmul operation with gather forward and reduce-scatter backward. + + Args: + input_ (`torch.Tensor`): input matrix. + dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + ctx.overlap = overlap + + input_parallel = _gather(input_, dim, process_group) + + output = torch.matmul(input_parallel, weight) + + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + overlap = ctx.overlap + + if not overlap: + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty( + input_.shape, dtype=input_parallel.dtype, device=input_parallel.device + ).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + else: + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + + # do all gather in is async way + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + # calculate gradient and prepare data asynchronously with all-gather + # calculate + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + # prepare data + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + # wait until all-gather finished + gather_handle.wait() + + # do reduce-scatter in async way + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + # calculate gradient + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + grad_weight = input_parallel.t().matmul(grad_output) + # wait until reduce-scatter finished + reducescatter_handle.wait() + + return output, grad_weight, grad_bias, None, None, None, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_ (`torch.Tensor`): input matrix. + dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _split(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.dim, ctx.process_group), None, None + + +class _ReduceForward(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def forward(ctx, input_, process_group): + return _reduce(input_, process_group) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _ReduceBackward(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def forward(ctx, input_, process_group): + ctx.process_group = process_group + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output, ctx.process_group), None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.dim, ctx.process_group), None, None + + +def _reduce(input_, process_group): + # skip if only one rank involved + if dist.get_world_size(process_group) == 1: + return input_ + else: + dist.all_reduce(input_, group=process_group) + return input_ + + +def _split(input_, dim=-1, process_group=None): + # skip if only one rank involved + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = dist.get_rank(process_group) + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, dim=-1, process_group=None): + # skip if only one rank involved + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # all gather + input_ = input_.contiguous() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, input_, group=process_group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +def _reduce_scatter(input_, dim=1, process_group=None): + """Do reduce-scatter operation. + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + dim (int): The dimension to perform reduce-scatter. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + """ + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # reduce-scatter + new_shape = list(input_.shape) + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + new_shape[dim] = new_shape[dim] // world_size + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_, group=process_group) + + return output + + +def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): + return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) + + +def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) + + +def linear_gather_forward_reducescatter_backward( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap +): + return _LinearWithGatherForwardReduceScatterBackward.apply( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + ) + + +def linear_reducescatter_forward_gather_backward(input_, process_group, dim): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) + + +def matmul_gather_forward_reducescatter_backward( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap +): + return _MatmulWithGatherForwardReduceScatterBackward.apply( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + ) + + +def gather_forward_split_backward(input_, dim, process_group): + return _GatherForwardSplitBackward.apply(input_, dim, process_group) + + +def split_forward_gather_backward(input_, dim, process_group): + return _SplitForwardGatherBackward.apply(input_, dim, process_group) + + +def reduce_forward(input_, process_group): + return _ReduceForward.apply(input_, process_group) + + +def reduce_backward(input_, process_group): + return _ReduceBackward.apply(input_, process_group) diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..8771913ee62fea0bd50cfe61d03542e2c13860d2 --- /dev/null +++ b/colossalai/shardformer/layer/dropout.py @@ -0,0 +1,84 @@ +from typing import List, Union + +import torch +import torch.nn as nn +from torch.distributed import ProcessGroup + +from .parallel_module import ParallelModule +from .utils import create_randomizer_with_offset + +__all__ = ["DropoutForParallelInput", "DropoutForReplicatedInput"] + + +class DropoutForParallelInput(ParallelModule, nn.Dropout): + """ + The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with + randomness on different ranks of the given process group. This can avoid the same dropout mask is generated + and applied on the same position of different ranks, leading to poor convergence performance. + + Args: + p (float): probability of an element to be zeroed. Defaults to 0.5. + inplace (bool): If set to True, will do this operation in-place. Defaults to False. + process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None): + # init with nn.Dropout + super(nn.Dropout, self).__init__(p=p, inplace=inplace) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=process_group) + + @staticmethod + def from_native_module( + module: nn.Dropout, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "DropoutForParallelInput": + """ + Create a DropoutForParallelInput layer from a native dropout layer. + """ + p = module.p + inplace = module.inplace + return DropoutForParallelInput(p=p, inplace=inplace, process_group=process_group) + + def forward(self, input): + with self.randomizer.fork_rng(): + input = super().forward(input) + return input + + +class DropoutForReplicatedInput(ParallelModule, nn.Dropout): + """ + The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with + randomness on different ranks of the given process group. This can avoid the same dropout mask is generated + and applied on the same position of different ranks, leading to poor convergence performance. + + Args: + p (float): probability of an element to be zeroed. Defaults to 0.5. + inplace (bool): If set to True, will do this operation in-place. Defaults to False. + process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None): + # init with nn.Dropout + super(nn.Dropout, self).__init__(p=p, inplace=inplace) + + # offset the seed with randomizer index only + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=process_group, offset_by_rank=False) + + @staticmethod + def from_native_module( + module: nn.Dropout, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "DropoutForReplicatedInput": + """ + Create a Dropout1D layer from a native dropout layer. + """ + p = module.p + inplace = module.inplace + return DropoutForReplicatedInput(p=p, inplace=inplace, process_group=process_group) + + def forward(self, input): + with self.randomizer.fork_rng(): + input = super().forward(input) + return input diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..62163cb009aaa397b49677a1cc5ef305835ca1a3 --- /dev/null +++ b/colossalai/shardformer/layer/embedding.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Callable, List, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup + +from colossalai.lazy import LazyInitContext +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import ( + is_distributed_tensor, + shard_colwise, + shard_rowwise, + sharded_tensor_to_existing_param, +) + +from ._operation import gather_forward_split_backward, reduce_forward +from .parallel_module import ParallelModule +from .utils import create_randomizer_with_offset + +__all__ = ["Embedding1D", "VocabParallelEmbedding1D"] + + +class Embedding1D(ParallelModule): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = True, + weight: Optional[nn.Parameter] = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.process_group = process_group + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + self.gather_output = gather_output + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # Parameters. + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_colwise(self.weight.data, process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if weight is None: + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module( + module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]] = None, *args, **kwargs + ) -> "Embedding1D": + r""" + Build a 1D parallelized Embedding from a native nn.Embedding module. + """ + LazyInitContext.materialize(module) + # get the attributes + num_embedding = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + max_norm = module.max_norm + norm_type = module.norm_type + scale_grad_by_freq = module.scale_grad_by_freq + sparse = module.sparse + dtype = module.weight.dtype + device = module.weight.device + + # sparse is not support yet + if sparse: + raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") + + embedding = Embedding1D( + num_embeddings=num_embedding, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + process_group=process_group, + dtype=dtype, + device=device, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + weight=module.weight, + *args, + **kwargs, + ) + + return embedding + + def reset_parameters(self, weight_initializer) -> None: + fan_in, fan_out = self.num_embeddings, self.embedding_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + if self.gather_output: + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + return output + else: + return output_parallel + + +class VocabParallelEmbedding1D(ParallelModule): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight: Optional[nn.Parameter] = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.embed_args = args + self.embed_kwargs = kwargs + self.process_group = process_group + + tensor_parallel_size = dist.get_world_size(group=process_group) + tensor_parallel_rank = dist.get_rank(group=process_group) + + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.num_embeddings = self.num_embeddings_per_partition + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + # padding index + self.padding_idx = self._select_padding_idx(padding_idx) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # parameter + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if weight is None: + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module( + module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + r""" + Convert a native pytorch embedding module to a parallel module. + """ + LazyInitContext.materialize(module) + # get the origin attributes + num_embeddings = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + device = module.weight.device + + # ensure only one process group is used + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + # create the parallel module + vocab_embedding_1d = VocabParallelEmbedding1D( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + device=device, + process_group=process_group, + weight=module.weight, + *args, + **kwargs, + ) + + return vocab_embedding_1d + + def reset_parameters(self, weight_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.num_embeddings, self.embedding_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _select_padding_idx(self, padding_idx: int): + # select padding index according to the rank + if padding_idx is None: + return None + elif padding_idx < self.vocab_end_index and padding_idx >= self.vocab_start_index: + return padding_idx - self.vocab_start_index + else: + return None + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = reduce_forward(output_parallel, self.process_group) + return output diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..cf2003877d3c7e3c73cde0aa6516e6730872b3ff --- /dev/null +++ b/colossalai/shardformer/layer/linear.py @@ -0,0 +1,425 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.lazy import LazyInitContext +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import ( + is_distributed_tensor, + shard_colwise, + shard_rowwise, + sharded_tensor_to_existing_param, +) + +from ._operation import ( + gather_forward_split_backward, + linear_gather_forward_reducescatter_backward, + linear_reducescatter_forward_gather_backward, + linear_with_async_comm, + reduce_forward, + split_forward_gather_backward, +) +from .parallel_module import ParallelModule +from .utils import create_randomizer_with_offset + +__all__ = ["Linear1D_Col", "Linear1D_Row"] + + +class Linear1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + seq_parallel: bool = False, + seq_parallel_dim: int = 1, + overlap: torch.cuda.Stream = None, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.seq_parallel = seq_parallel + self.seq_parallel_dim = seq_parallel_dim + self.overlap = overlap + self.skip_bias_add = skip_bias_add + self.device = device + self.process_group = process_group + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + if not is_distributed_tensor(self.bias): + sharded_bias = shard_colwise(self.bias.data, self.process_group) + sharded_tensor_to_existing_param(sharded_bias, self.bias) + else: + self.bias = None + + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + if out_features < tp_size: + return module + + if out_features % tp_size != 0: + raise ValueError( + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = Linear1D_Col( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + + # Set up backprop all-reduce. + input_parallel = input_ + + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + if self.seq_parallel: + output_parallel = linear_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap + ) + else: + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class Linear1D_Row(ParallelModule): + r"""Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + seq_parallel: bool = False, + seq_parallel_dim: int = 1, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1, + ): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.seq_parallel = seq_parallel + self.seq_parallel_dim = seq_parallel_dim + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + # Initialize weight. + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_colwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + else: + self.bias = None + + if weight is None: + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = Linear1D_Row( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + @torch.no_grad() + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + bias = self.bias.cuda() + dist.broadcast(bias, src=src_rank, group=self.process_group) + bias = bias.to(origin_device) + self.bias.copy_(bias) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + input_ = input_ + else: + assert ( + divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions + ) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce( + output_parallel_list[i], group=self.process_group, async_op=True + ) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward( + output_parallel, self.process_group, self.seq_parallel_dim + ) + else: + output = reduce_forward(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..848e4a3a1f7d767dcd796b525b84018a34cdbb6d --- /dev/null +++ b/colossalai/shardformer/layer/loss.py @@ -0,0 +1,109 @@ +import torch +import torch.distributed as dist +from torch.autograd import Function +from torch.distributed import ProcessGroup + +__all__ = ["DistCrossEntropy", "cross_entropy_1d"] + + +class DistCrossEntropy(Function): + r""" + Overwrite the forward and backward function to calculate the cross entropy loss before gather + + Args: + Function (:class:`torch.autograd.Function`): default + """ + + @staticmethod + def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup): + r""" + Calculate the cross entropy loss before gather, the origin loss function is as follows: + loss = -log(exp(x[class])/sum(exp(x[i])) + and can be rewrite as: + loss = log(sum(exp(x[i])) - x[class] + + To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i] + + Args: + vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is + [batch_size, seq_len, vocab_size] + labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is + [batch_size, seq_len] + + Returns: + :class:`torch.Tensor`: The cross entropy loss + """ + # get the max + logits_max = torch.max(vocab_logits, dim=-1)[0] + dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group) + + # minus the max to avoid the result of sum of exp is too large and the log is nan + vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) + + # mask the target in the local device + partition_vocab_size = vocab_logits.size()[-1] + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + global_vocab_size = partition_vocab_size * world_size + + # [down, up) => false, other device and -100 => true + delta = (global_vocab_size + world_size - 1) // world_size + down_threshold = rank * delta + up_threshold = down_threshold + delta + mask = (target < down_threshold) | (target >= up_threshold) + masked_target = target.clone() - down_threshold + masked_target[mask] = 0 + + # reshape the logits and target + # reshape the vocab_logits to [bath_size * seq_len, vocab_size] + # reshape the labels to [bath_size * seq_len] + logits_2d = vocab_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + + # extract the x[class] and set the x[other device] to zero + pred_logits_1d = logits_2d[ + torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), masked_target_1d + ] + pred_logits_1d = pred_logits_1d.clone().contiguous() + pred_logits = pred_logits_1d.view_as(target) + pred_logits[mask] = 0.0 + + # allreduce the get all x(i,y) + dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group) + exp_logits = vocab_logits + torch.exp(vocab_logits, out=exp_logits) + sum_exp_logits = torch.sum(exp_logits, dim=-1) + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) + + # calculate the loss + # loss = log(sum(exp(x[i]))) - x[class] + loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits) + loss = torch.sum(loss).div_(torch.sum(loss != 0.0)) + + # calculate the softmax + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + # retrieve the saved tensors + exp_logits, mask, masked_target_1d = ctx.saved_tensors + + # use exp logits as the input grad + grad_logits = exp_logits + partion_vocab_size = grad_logits.shape[-1] + grad_logits_2d = grad_logits.view(-1, partion_vocab_size) + + update = 1.0 - mask.view(-1).float() + grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update + + grad_logits.mul_(grad_output.unsqueeze(dim=-1)) + return grad_logits, None, None + + +def cross_entropy_1d( + vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None +) -> torch.Tensor: + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..19b973be867940e18025354fdc8b4b92dc11257e --- /dev/null +++ b/colossalai/shardformer/layer/normalization.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.nn as nn + +from colossalai.lazy import LazyInitContext + +__all__ = ["FusedLayerNorm", "FusedRMSNorm"] + +FAST_LAYERNORM_SUPPORTED_SIZE = [ + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 12288, + 12800, + 15360, + 16384, + 18432, + 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, +] + + +class FusedLayerNorm: + r""" + This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + "FusedLayerNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex." + ) + + @staticmethod + def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: + r""" + Convert a native pytorch layer norm module to colossalai layer norm module + """ + # check if apex is installed + try: + pass + except ImportError: + raise ImportError( + "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel" + ) + + LazyInitContext.materialize(module) + # get the attributes of the module + normalized_shape = module.normalized_shape + eps = module.eps + elementwise_affine = module.elementwise_affine + dtype = module.weight.dtype + device = module.weight.device + + # pick the suitable layernorm implementation + use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE + + if use_fast_ln: + try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm + except ImportError: + # fall back to the normal fused layernorm is not built + from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + else: + from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + + layernorm = ( + ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) + ) + + layernorm.weight = module.weight + layernorm.bias = module.bias + return layernorm + + +class FusedRMSNorm: + """ + This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + "FusedRMSNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex." + ) + + @staticmethod + def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: + try: + from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm + except ImportError: + raise ImportError( + "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" + ) + + LazyInitContext.materialize(module) + # to check if it is huggingface LlamaRMSNorm + if module.__class__.__name__ == "LlamaRMSNorm": + normalized_shape = module.weight.shape[0] + eps = module.variance_epsilon + elementwise_affine = True + else: + # get the attributes of the module + normalized_shape = module.normalized_shape + eps = module.eps + elementwise_affine = module.elementwise_affine + + rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) + + rmsnorm.weight = module.weight + + return rmsnorm diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py new file mode 100644 index 0000000000000000000000000000000000000000..6c0d83cc7a20ab119bd938e81828edb9eb498ccc --- /dev/null +++ b/colossalai/shardformer/layer/parallel_module.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import itertools +from abc import ABC, abstractmethod +from typing import List, Union + +import torch +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module + +from colossalai.checkpoint_io.utils import gather_distributed_param +from colossalai.tensor.d_tensor import ( + distribute_tensor, + distribute_tensor_with_customization, + get_device_mesh, + get_sharding_spec, + is_customized_distributed_tensor, + is_distributed_tensor, + sharded_tensor_to_param, +) + +__all__ = ["ParallelModule"] + + +class ParallelModule(nn.Module, ABC): + @abstractmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "ParallelModule": + """ + Convert a native PyTorch module to a parallelized module. + + Args: + module (nn.Module): the module to be converted. + process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. + If this is a list, the process group at the ith index of the list will correspond to the process group + in the ith axis of the device mesh. Defaults to None, which means the global process group. + """ + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars) + + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + + if key in state_dict: + input_param = state_dict[key] + if not torch.overrides.is_tensor_like(input_param): + error_msgs.append( + 'While copying the parameter named "{}", ' + "expected torch.Tensor or Tensor-like object from checkpoint but " + "received {}".format(key, type(input_param)) + ) + continue + + if is_distributed_tensor(param): + # shard the input param + device_mesh = get_device_mesh(param) + sharding_spec = get_sharding_spec(param) + sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec) + input_param = sharded_tensor_to_param(sharded_tensor) + elif is_customized_distributed_tensor(param): + input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn) + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(key, input_param.shape, param.shape) + ) + continue + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix) :] + input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..12476d050600739fdf86d3796d0163ccb04da5a2 --- /dev/null +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -0,0 +1,739 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.lazy import LazyInitContext +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import ( + customized_distributed_tensor_to_existing_param, + distribute_tensor_with_customization, + is_customized_distributed_tensor, + is_distributed_tensor, + shard_rowwise, + sharded_tensor_to_existing_param, +) + +from ._operation import ( + gather_forward_split_backward, + linear_reducescatter_forward_gather_backward, + linear_with_async_comm, + matmul_gather_forward_reducescatter_backward, + matmul_with_async_comm, + reduce_backward, + reduce_forward, + split_forward_gather_backward, +) +from .parallel_module import ParallelModule +from .utils import create_randomizer_with_offset + +__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"] + +# ==================================== +# For GPT Only +# ==================================== + + +def split_fused_qkv_in_gpt2_style( + qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False +): + """ + The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2]. + + Args: + qkv (torch.Tensor): The fused qkv tensor. + n_fused (int): The number items fused together, defaults to 3 (query, key and value). + process_group (ProcessGroup): The process group for distributed communication. + is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). + """ + # get the number of slice for the fused qkv + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + order = torch.arange(world_size * n_fused) + + # split the fused qkv + # from + # [Q, K, V] + # to + # [Q1, Q2, K1, K2, V1, V2] + if is_transposed: + weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1) + else: + weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0) + + # rearrange the slice into the final order + # from + # [Q1, Q2, K1, K2, V1, V2] + # to + # [Q1, K1, V1], [Q2, K2, V2] + weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]] + + if is_transposed: + weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1) + else: + weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=0) + return weight_of_current_rank + + +def gather_fused_qkv_in_gpt2_style( + qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False +): + """ + The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2]. + + Args: + qkv (torch.Tensor): The fused qkv tensor. + n_fused (int): The number items fused together, defaults to 3 (query, key and value). + process_group (ProcessGroup): The process group for distributed communication. + is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). + """ + world_size = dist.get_world_size(group=process_group) + + # gather the tensors + # from + # [Q1, K1, V1], [Q2, K2, V2] + # to + # [Q1, K1, V1, Q2, K2, V2] + origin_device = qkv.device + qkv = qkv.cuda() + gather_list = [torch.zeros_like(qkv) for _ in range(world_size)] + dist.all_gather(gather_list, qkv, group=process_group) + + if is_transposed: + gather_weight = torch.cat(gather_list, dim=-1) + else: + gather_weight = torch.cat(gather_list, dim=0) + gather_weight = gather_weight.to(origin_device) + qkv = qkv.to(origin_device) + + # rearrange the tensor slices + # from + # [Q1, K1, V1, Q2, K2, V2] + # to + # [Q1, Q2, K1, K2, V1, V2] + if is_transposed: + weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1) + else: + weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0) + + reordered_chunk_list = [] + for i in range(n_fused): + reordered_chunk_list.extend(weight_chunks[i::n_fused]) + + if is_transposed: + reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1) + else: + reordered_gather_weight = torch.cat(reordered_chunk_list, dim=0) + return reordered_gather_weight + + +class GPT2FusedLinearConv1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + n_fused (int): The number items fused, defaults to 3 (QKV). + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + async_communication: bool = False, + gather_output: bool = False, + seq_parallel: bool = False, + overlap: bool = False, + skip_bias_add: bool = False, + n_fused: int = 3, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.seq_parallel = seq_parallel + self.overlap = overlap + self.skip_bias_add = skip_bias_add + self.device = device + self.n_fused = n_fused + self.process_group = process_group + self.async_communication = async_communication + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + # Initialize weight. + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + + def shard_fn(tensor): + return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) + + def gather_fn(tensor): + return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) + + if not is_customized_distributed_tensor(self.weight): + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_weight, self.weight) + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + if not is_customized_distributed_tensor(self.bias): + with torch.no_grad(): + sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_bias, self.bias) + else: + self.bias = None + + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + r""" + Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. + + Args: + module (`nn.Linear`): The module to be converted. + process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. + n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + if out_features < tp_size: + return module + + if out_features % tp_size != 0: + raise ValueError( + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = GPT2FusedLinearConv1D_Col( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert ( + input_.shape[-1] == self.weight.shape[0] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + + if self.seq_parallel: + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap + ) + else: + # Set up backprop all-reduce. + input_parallel = reduce_backward(input_, self.process_group) + output_parallel = matmul_with_async_comm( + input_parallel, self.weight, bias, self.process_group, self.async_communication + ) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class GPT2FusedLinearConv1D_Row(ParallelModule): + r"""Linear layer with row parallelism. + This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + seq_parallel: bool = False, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1, + ): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.seq_parallel = seq_parallel + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, self.num_partitions) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + # Initialize weight. + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + else: + self.bias = None + + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = GPT2FusedLinearConv1D_Row( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + self.bias.data = self.bias.cuda() + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + self.bias.data = self.bias.to(origin_device) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert ( + input_.shape[-1] == self.weight.shape[0] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[0] + ) + input_ = input_ + else: + assert ( + divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions + ) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = torch.matmul(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce( + output_parallel_list[i], group=self.process_group, async_op=True + ) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = torch.matmul(input_, self.weight) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + output = reduce_forward(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias + + +# ==================================== +# For Fused torch.nn.Linear +# ==================================== + + +class FusedLinear1D_Col(ParallelModule): + r"""Fused Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `torch.nn.Linear` layer (Fused QKV) in normal torch layer of huggingface, like SAM. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + n_fused (int): The number items fused, defaults to 3 (QKV). + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + async_communication: bool = False, + gather_output: bool = False, + skip_bias_add: bool = False, + n_fused: int = 3, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + super().__init__() + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.n_fused = n_fused + self.process_group = process_group + self.async_communication = async_communication + + if skip_bias_add and not bias: + raise ValueError("cannot skip bias addition if bias is None") + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" + else: + assert bias_ is None, "bias_ must be None if weight is None" + + # Parameters. + if weight is None: + # Initialize weight. + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + + def shard_fn(tensor): + return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False) + + def gather_fn(tensor): + return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False) + + if not is_customized_distributed_tensor(self.weight): + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_weight, self.weight) + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + if not is_customized_distributed_tensor(self.bias): + with torch.no_grad(): + sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_bias, self.bias) + else: + self.bias = None + + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs + ) -> ParallelModule: + r""" + Convert a fused `torch.nn.linear` layer to a parallelized linear layer. + + Args: + module (`nn.Linear`): The module to be converted. + process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. + n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + linear_1d = FusedLinear1D_Col( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) + + # # TODO: copy the sharded weights + # with torch.no_grad(): + # sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, + # n_fused=n_fused, + # process_group=process_group, + # is_transposed=False) + # linear_1d.weight.data.copy_(sharded_weight.data) + + # if bias: + # sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, + # n_fused=n_fused, + # process_group=process_group, + # is_transposed=False) + # linear_1d.bias.data.copy_(sharded_bias.data) + print(linear_1d.weight.shape) + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) + # Set up backprop all-reduce. + # input_parallel = reduce_backward(input_, self.process_group) + input_parallel = input_ + + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d8501cdeae74999009edc6b18e3285fc25093a --- /dev/null +++ b/colossalai/shardformer/layer/utils.py @@ -0,0 +1,207 @@ +from contextlib import contextmanager + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +class Randomizer: + """ + Randomizer enables the program to be executed under a different seed within the context. + + Example: + + ```python + randomizer = Randomizer(seed=1024) + + with randomizer.fork(): + # do something here with seed 1024 + do_something() + ``` + + Args: + seed (int): The random seed to set. + enable_cpu (bool): fork the CPU RNG state as well. + with_index (bool): whether to use the index of the randomizer. + """ + + _INDEX = 0 + + def __init__(self, seed: int): + self.seed = seed + + # Handle CUDA rng state + # 1. get the current rng state + # 2. set the seed and store the rng state + # 3. recover the original rng state + cuda_original_rng_state = torch.cuda.get_rng_state() + torch.cuda.manual_seed(seed) + self.cuda_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(cuda_original_rng_state) + + # to the same for cpu rng state + cpu_original_rng_state = torch.get_rng_state() + torch.manual_seed(seed) + self.cpu_rng_state = torch.get_rng_state() + torch.set_rng_state(cpu_original_rng_state) + + def _set_cuda_rng_state(self, rng_state): + torch.cuda.set_rng_state(rng_state) + + def _get_cuda_rng_state(self): + current_state = torch.cuda.get_rng_state() + return current_state + + def _set_cpu_rng_state(self, rng_state): + torch.set_rng_state(rng_state) + + def _get_cpu_rng_state(self): + current_state = torch.get_rng_state() + return current_state + + @contextmanager + def fork_rng(self, enable_cpu: bool = False): + """ + This is a context manager to change the dropout state and recover the original state. + + Usage: + :: + >>> with _seed_manager.dropout_mode(): + >>> input = super().forward(input) + """ + try: + current_cuda_rng_state = self._get_cuda_rng_state() + self._set_cuda_rng_state(self.cuda_rng_state) + + if enable_cpu: + current_cpu_rng_state = self._get_cpu_rng_state() + self._set_cpu_rng_state(self.cpu_rng_state) + yield + finally: + self.cuda_rng_state = self._get_cuda_rng_state() + self._set_cuda_rng_state(current_cuda_rng_state) + + if enable_cpu: + self.cpu_rng_state = self._get_cpu_rng_state() + self._set_cpu_rng_state(current_cpu_rng_state) + + @staticmethod + def index(): + """ + Return the index of the randomizer. The index is useful when the user wants + to introduce some randomness in the program. + + Note: + The index will increment by one each time this method is called. + + Example: + + ```python + # assume we need a randomizer to init the weight of different layers + # we can use the index of the randomizer to do so that + # each layer has its own randomizer with a different seed + base_seed = torch.random.initial_seed() + seed = base_seed + Randomizer.index() + randomizer = Randomizer(seed) + + with randomizer.fork(): + init_weights() + ``` + + """ + idx = Randomizer._INDEX + return idx + + @staticmethod + def increment_index(): + """ + Increment the index of the randomizer by one. + """ + Randomizer._INDEX += 1 + + @staticmethod + def reset_index(): + """ + Reset the index to zero. + """ + Randomizer._INDEX = 0 + + @staticmethod + def is_randomizer_index_synchronized(process_group: ProcessGroup = None): + """ + Return whether the randomizer index is synchronized across processes. + """ + index = Randomizer.index() + if dist.is_initialized(): + # convert the index to tensor + index_tensor = torch.tensor(index, dtype=torch.int32).cuda() + + # all gather the index + gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] + dist.all_gather(gathered_index, index_tensor, process_group) + + # make sure all the gathered index are the same + for i in range(1, dist.get_world_size(process_group)): + if gathered_index[i] != gathered_index[0]: + return False + + return True + + @staticmethod + def synchronize_index(process_group: ProcessGroup = None): + """ + All gather the index and pick the largest value. + """ + index = Randomizer.index() + + if dist.is_initialized(): + # convert the index to tensor + index_tensor = torch.tensor(index, dtype=torch.int32).cuda() + + # all gather the index + gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] + dist.all_gather(gathered_index, index_tensor, process_group) + + # pick the largest index + for i in range(1, dist.get_world_size(process_group)): + if gathered_index[i] > index_tensor: + index_tensor = gathered_index[i] + + # set the index + Randomizer._INDEX = index_tensor.item() + + +def create_randomizer_with_offset( + seed: int, process_group: ProcessGroup = None, offset_by_rank: bool = True, offset_by_index: bool = True +): + """ + Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer. + + Args: + seed (int): The base random seed to set. + process_group (ProcessGroup): the process group to get the rank from. + offset_by_rank (bool): whether to offset by the rank of the process, i.e., the rank of the process will be added to the seed. Default: True. + offset_by_index (bool): whether to offset by the index of the randomizer, i.e., the index of the randomizer will be added to the seed. Default: True. + + Returns: + Randomizer: the randomizer with offset. + """ + base_seed = seed + + if offset_by_rank and dist.is_initialized(): + rank = dist.get_rank(process_group) + base_seed += rank + + if offset_by_index: + # check if the randomizer index is synchronized + is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group) + assert is_synchronized, ( + "We detect that the randomizer index is not synchronized across processes." + "This is not allowed when we want to create a randomizer with offset by index." + "Please call Randomizer.synchronize_index() first." + ) + + base_seed += Randomizer.index() + Randomizer.increment_index() + + return Randomizer(seed=base_seed) diff --git a/tests/test_layers/test_3d/checks_3d/__init__.py b/colossalai/shardformer/modeling/__init__.py similarity index 100% rename from tests/test_layers/test_3d/checks_3d/__init__.py rename to colossalai/shardformer/modeling/__init__.py diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..7411e1d0ec46e5510b3f36999d10d6c8fc2aba60 --- /dev/null +++ b/colossalai/shardformer/modeling/bert.py @@ -0,0 +1,1287 @@ +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.models.bert.modeling_bert import ( + BertForMaskedLM, + BertForMultipleChoice, + BertForNextSentencePrediction, + BertForPreTraining, + BertForPreTrainingOutput, + BertForQuestionAnswering, + BertForSequenceClassification, + BertForTokenClassification, + BertLMHeadModel, + BertModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward + + +class BertPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Bert models + under pipeline setting. + """ + + @staticmethod + def bert_model_forward( + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + # TODO(jianghai): add explaination of the output here. + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + else: + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + hidden_states = hidden_states if hidden_states is not None else None + + if stage_manager.is_first_stage(): + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + # inherit from bert_layer,this should be changed when we add the feature to record hidden_states + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + # layer_outputs + layer_outputs = hidden_states if hidden_states is not None else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config is not None and shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): + if stage_manager.is_first_stage() and idx == 0: + encoder_attention_mask = encoder_extended_attention_mask + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[idx] if head_mask is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config is not None and shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # end of a stage loop + sequence_output = hidden_states if hidden_states is not None else None + + if stage_manager.is_last_stage(): + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + if not return_dict: + return (sequence_output, pooled_output) + layer_outputs[1:] + # return dict is not supported at this moment + else: + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + # output of non-first and non-last stages: must be a dict + else: + # intermediate stage always return dict + return { + "hidden_states": hidden_states, + } + + @staticmethod + def bert_for_pretraining_forward( + self: BertForPreTraining, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO(jianghai) left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None, + stage_index=stage_index, + shard_config=shard_config, + ) + + if stage_manager.is_last_stage(): + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + # the last stage for pretraining model + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + + # intermediate stage always return dict + return { + "hidden_states": hidden_states, + } + + @staticmethod + def bert_lm_head_model_forward( + self: BertLMHeadModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + logger = logging.get_logger(__name__) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + use_cache = False + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None, + stage_index=stage_index, + shard_config=shard_config, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + # intermediate stage always return dict + return {"hidden_states": hidden_states} + + @staticmethod + def bert_for_masked_lm_forward( + self: BertForMaskedLM, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def bert_for_next_sentence_prediction_forward( + self: BertForNextSentencePrediction, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + **kwargs, + ): + # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BertForNextSentencePrediction + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + logger = logging.get_logger(__name__) + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) + + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + # intermediate stage always return dict + return {"hidden_states": hidden_states} + + @staticmethod + def bert_for_sequence_classification_forward( + self: BertForSequenceClassification, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) + + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def bert_for_token_classification_forward( + self: BertForTokenClassification, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def bert_for_multiple_choice_forward( + self: BertForMultipleChoice, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # in our pipeline design,input ids are copied for every stage and shouldn't be none + # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length] + if stage_manager.is_last_stage(): + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def bert_for_question_answering_forward( + self: BertForQuestionAnswering, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + # NOTE: the arg start_position and end_position are used only for the last stage + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + +def get_bert_flash_attention_forward(): + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + from transformers.models.bert.modeling_bert import BertAttention + + def forward( + self: BertAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + final_attention_mask = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + final_attention_mask = relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + final_attention_mask = relative_position_scores_query + relative_position_scores_key + + scale = 1 / math.sqrt(self.attention_head_size) + if attention_mask is not None: + if final_attention_mask != None: + final_attention_mask = final_attention_mask * scale + attention_mask + else: + final_attention_mask = attention_mask + + if final_attention_mask is not None: + batch_size, src_len = query_layer.size()[0], query_layer.size()[2] + tgt_len = key_layer.size()[2] + final_attention_mask = final_attention_mask.expand( + batch_size, self.num_attention_heads, src_len, tgt_len + ).contiguous() + + query_layer = query_layer.permute(0, 2, 1, 3).contiguous() + key_layer = key_layer.permute(0, 2, 1, 3).contiguous() + value_layer = value_layer.permute(0, 2, 1, 3).contiguous() + + context_layer = me_attention( + query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale + ) + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, None) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + return forward + + +def get_jit_fused_bert_self_output_forward(): + from transformers.models.bert.modeling_bert import BertSelfOutput + + def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + return forward + + +def get_jit_fused_bert_output_forward(): + from transformers.models.bert.modeling_bert import BertOutput + + def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + return forward + + +def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + embedding_output = split_forward_gather_backward( + embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + sequence_output = gather_forward_split_backward( + sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py new file mode 100644 index 0000000000000000000000000000000000000000..00b2037fbdc872c9cceceabaae219e65c760cdfe --- /dev/null +++ b/colossalai/shardformer/modeling/blip2.py @@ -0,0 +1,114 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn + + +def forward_fn(): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + mixed_qkv = self.qkv(hidden_states) + + # modified from original code, which is: + # mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute( + # 2, 0, 3, 1, 4 + # ) + # to: + mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + query_states, key_states, value_states = ( + mixed_qkv[0], + mixed_qkv[1], + mixed_qkv[2], + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + + attention_scores = attention_scores * self.scale + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) + + new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,) + context_layer = context_layer.reshape(new_context_layer_shape) + + output = self.projection(context_layer) + + outputs = (output, attention_probs) if output_attentions else (output, None) + + return outputs + + return forward + + +def get_blip2_flash_attention_forward(): + from transformers.models.blip_2.modeling_blip_2 import Blip2Attention + + from colossalai.kernel.cuda_native import ColoAttention + + def forward( + self: Blip2Attention, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + mixed_qkv = self.qkv(hidden_states) + mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) + query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] + + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout.p, scale=self.scale + ) + context_layer = attention(query_states, key_states, value_states) + + output = self.projection(context_layer) + outputs = (output, None) + + return outputs + + return forward + + +def get_jit_fused_blip2_QFormer_self_output_forward(): + from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput + + def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + return forward + + +def get_jit_fused_blip2_QFormer_output_forward(): + from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput + + def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + return forward diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf87e80a4612a33abb560f0d02bcce4780fd382 --- /dev/null +++ b/colossalai/shardformer/modeling/bloom.py @@ -0,0 +1,1072 @@ +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import functional as F +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.models.bloom.modeling_bloom import ( + BloomForCausalLM, + BloomForQuestionAnswering, + BloomForSequenceClassification, + BloomForTokenClassification, + BloomModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) + + +def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: + def build_bloom_alibi_tensor( + self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype + ) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). + num_heads (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + """ + import math + + if dist.is_initialized(): + world_size = dist.get_world_size(process_group) + num_heads = num_heads * world_size + + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 + ) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange( + 1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + if dist.is_initialized(): + num_heads_per_rank = int(num_heads / dist.get_world_size(process_group)) + offset = dist.get_rank(process_group) * num_heads_per_rank + alibi = alibi.view(batch_size, num_heads, 1, seq_length) + alibi = alibi[:, offset : num_heads_per_rank + offset, :, :] + return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) + else: + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) + + return build_bloom_alibi_tensor + + +class BloomPipelineForwards: + """ + This class serves as a micro library for bloom pipeline forwards. + """ + + @staticmethod + def bloom_model_forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], "BaseModelOutputWithPastAndCrossAttentions"]: + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # add warnings here + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + # case: First stage of training + if stage_manager.is_first_stage(): + # check input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + # initialize in the first stage and then pass to the next stage + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + + # extra recording tensor should be generated in the first stage + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] # source_len + + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + # causal_mask is constructed every stage and its input is passed through different stages + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + start_idx, end_idx = stage_index[0], stage_index[1] + for i, (block, layer_past) in enumerate( + zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx + ): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + + if use_cache is True: + presents = presents + (outputs[1],) + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + if stage_manager.is_last_stage(): + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + # TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + ) + + # attention_mask is not returned ; presents = past_key_values + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + else: + # always return dict for imediate stage + return {"hidden_states": hidden_states} + + @staticmethod + def bloom_for_causal_lm_forward( + self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + **deprecated_arguments, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = BloomPipelineForwards.bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + past_key_values = None + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def bloom_for_sequence_classification_forward( + self: BloomForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + **deprecated_arguments, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = BloomPipelineForwards.bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + past_key_values = None + if stage_manager.is_last_stage(): + batch_size = hidden_states.shape[0] + # update batch size + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def bloom_for_token_classification_forward( + self: BloomForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + **deprecated_arguments, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = BloomPipelineForwards.bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def bloom_for_question_answering_forward( + self: BloomForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + outputs = BloomPipelineForwards.bloom_model_forward( + self.transformer, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + +def get_bloom_flash_attention_forward(enabel_jit_fused=False): + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + from transformers.models.bloom.modeling_bloom import BloomAttention + + def forward( + self: BloomAttention, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + fused_qkv = self.query_key_value(hidden_states) + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + batch_size, tgt_len, _ = query_layer.size() + + _, kv_length, _, _ = key_layer.size() + + proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim) + query_layer = query_layer.contiguous().view(*proj_shape) + key_layer = key_layer.contiguous().view(*proj_shape) + value_layer = value_layer.contiguous().view(*proj_shape) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, head_dim, kv_length] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + if use_cache is True: + present = (key_layer, value_layer) + else: + present = None + + tgt_len = key_layer.size()[1] + + attention_numerical_mask = torch.zeros( + (batch_size, self.num_heads, tgt_len, kv_length), + dtype=torch.float32, + device=query_layer.device, + requires_grad=True, + ) + attention_numerical_mask = ( + attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta + ) + attention_numerical_mask = torch.masked_fill( + attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min + ) + + context_layer = me_attention( + query_layer, + key_layer, + value_layer, + attn_bias=attention_numerical_mask, + scale=self.inv_norm_factor, + p=self.attention_dropout.p, + ) + context_layer = context_layer.reshape(-1, kv_length, self.hidden_size) + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + # TODO to replace with the bias_dropout_add function in jit + output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + outputs = (output_tensor, present, None) + + return outputs + + return forward + + +def get_jit_fused_bloom_attention_forward(): + from transformers.models.bloom.modeling_bloom import BloomAttention + + def forward( + self: BloomAttention, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, q_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, head_dim, kv_length] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=2) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, _, kv_length = key_layer.shape + + if use_cache is True: + present = (key_layer, value_layer) + else: + present = None + + # [batch_size * num_heads, q_length, kv_length] + # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 + matmul_result = alibi.baddbmm( + batch1=query_layer, + batch2=key_layer, + beta=self.beta, + alpha=self.inv_norm_factor, + ) + + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16: + attention_scores = attention_scores.to(torch.float) + attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min) + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) + + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # change view [batch_size x num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = torch.bmm(attention_probs_reshaped, value_layer) + + # change view [batch_size, num_heads, q_length, head_dim] + context_layer = self._merge_heads(context_layer) + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + + outputs = (output_tensor, present) + if output_attentions: + outputs += (attention_probs,) + + return outputs + + return forward + + +def get_jit_fused_bloom_mlp_forward(): + from transformers.models.bloom.modeling_bloom import BloomMLP + + def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) + + if self.pretraining_tp > 1 and self.slow_but_exact: + intermediate_output = torch.zeros_like(residual) + slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp + for i in range(self.pretraining_tp): + intermediate_output = intermediate_output + F.linear( + hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + intermediate_output = self.dense_4h_to_h(hidden_states) + output = self.dropout_add(intermediate_output, residual, self.hidden_dropout, self.training) + return output + + return forward + + +def get_jit_fused_bloom_gelu_forward(): + from transformers.models.bloom.modeling_bloom import BloomGelu + + from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction + + def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor: + bias = torch.zeros_like(x) + if self.training: + return JitGeLUFunction.apply(x, bias) + else: + return self.bloom_gelu_forward(x, bias) + + return forward + + +def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): + from transformers import BloomModel + + def forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py new file mode 100644 index 0000000000000000000000000000000000000000..8934068d609c88d5a5d9f1884cc11986fdfea689 --- /dev/null +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -0,0 +1,404 @@ +""" PyTorch ChatGLM model. """ +from typing import List, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + + +def get_flash_core_attention_forward(): + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + from .chatglm2_6b.modeling_chatglm import CoreAttention + + def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split(".")[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, is_causal=True + ) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask + ) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + query_layer = query_layer.permute(1, 0, 2, 3).contiguous() + key_layer = key_layer.permute(1, 0, 2, 3).contiguous() + value_layer = value_layer.permute(1, 0, 2, 3).contiguous() + + scale = 1.0 / self.norm_factor + if self.coeff is not None: + scale = scale * self.coeff + + flash_attention_mask = None + attn_mask_type = None + if attention_mask is None: + attn_mask_type = AttnMaskType.causal + else: + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + attn_mask_type = AttnMaskType.paddedcausal + + attention = ColoAttention( + embed_dim=self.hidden_size_per_partition, + num_heads=self.num_attention_heads_per_partition, + dropout=self.attention_dropout.p, + scale=scale, + ) + context_layer = attention( + query_layer, key_layer, value_layer, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + ) + + context_layer = context_layer.permute(1, 0, -1).contiguous() + + return context_layer + + return forward + + +def get_jit_fused_glm_block_forward(): + from .chatglm2_6b.modeling_chatglm import GLMBlock + + def forward( + self: GLMBlock, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [s, b, h] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = self.dropout_add(attention_output, residual, self.hidden_dropout, self.training) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = self.dropout_add(mlp_output, residual, self.hidden_dropout, self.training) + + return output, kv_cache + + return forward + + +class ChatGLMPipelineForwards: + """ + This class serves as a micro library for ChatGLM model forwards under pipeline parallelism. + """ + + @staticmethod + def chatglm_model_forward( + self: ChatGLMModel, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + logger = logging.get_logger(__name__) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") + past_key_values = None + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + if stage_manager.is_first_stage(): + batch_size, seq_length = input_ids.shape + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + hidden_states = inputs_embeds + else: + seq_length, batch_size = hidden_states.shape[:2] + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype + ) + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1 + ) + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + if not past_key_values: + past_key_values = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.encoder.gradient_checkpointing and self.encoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + start_idx, end_idx = stage_index[0], stage_index[1] + + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) + for idx in range(start_idx, end_idx): + layer = self.encoder._get_layer(idx) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if self.encoder.gradient_checkpointing and self.encoder.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache + ) + else: + layer_ret = layer( + hidden_states, + full_attention_mask, + rotary_pos_emb, + kv_cache=past_key_values[idx], + use_cache=use_cache, + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if stage_manager.is_last_stage(): + # final layer_norm + if self.encoder.post_layer_norm: + hidden_states = self.encoder.final_layernorm(hidden_states) + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + else: + return {"hidden_states": hidden_states} + + @staticmethod + def chatglm_for_conditional_generation_forward( + self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + logging.get_logger(__name__) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + transformer_outputs = ChatGLMPipelineForwards.chatglm_model_forward( + self.transformer, + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + return transformer_outputs + + +def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] + inputs_embeds = split_forward_gather_backward( + inputs_embeds, dim=0, process_group=shard_config.tensor_parallel_process_group + ) + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + hidden_states = gather_forward_split_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py new file mode 100644 index 0000000000000000000000000000000000000000..bb774676a4d47956be38f362f1202dad1a2ed1b9 --- /dev/null +++ b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py @@ -0,0 +1,60 @@ +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs, + ): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd49ecfeae55872db37c60088e0ef942cca2ab0 --- /dev/null +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -0,0 +1,1393 @@ +""" +The ChatGLM2-6B License + +1. Definitions + +“Licensor” means the ChatGLM2-6B Model Team that distributes its Software. + +“Software” means the ChatGLM2-6B model parameters made available under this license. + +2. License Grant + +Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +3. Restriction + +You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. + +You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. +""" +""" PyTorch ChatGLM model. """ + +import copy +import math +import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != "darwin": + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm2-6b", + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size), + ) + else: + self.embedding = torch.nn.Embedding( + config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2, + ) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl( + self, + seq_len: int, + n_elem: int, + dtype: torch.dtype, + device: torch.device, + base: int = 10000, + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=dtype, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, + self.dim, + dtype=self.inv_freq.dtype, + device=self.inv_freq.device, + ) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.elementwise_affine = True + self.normalized_shape = normalized_shape + self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split(".")[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, is_causal=True + ) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask + ) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=query_layer.device, + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: + attention_mask = torch.ones( + output_size[0], + 1, + output_size[2], + output_size[3], + device=attention_scores.device, + dtype=torch.bool, + ) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + self.projection_size = config.kv_channels * config.num_attention_heads + # Per attention head and per partition values. + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + self.query_key_value = nn.Linear( + config.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, + **_config_to_kwargs(config), + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear( + self.projection_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + **_config_to_kwargs(config), + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config), + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config), + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache, + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache, + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat( + ( + torch.ones(batch_size, seq_length, past_length, device=input_ids.device), + full_attention_mask, + ), + dim=-1, + ) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GLMTransformer): + module.gradient_checkpointing = value + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device, + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) + + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim // 2, + # original_impl=config.original_rope, # config has no attribute original_rope + device=device, + dtype=config.torch_dtype, + ) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method( + nn.Linear, + config.hidden_size, + config.padded_vocab_size, + bias=False, + dtype=config.torch_dtype, + **init_kwargs, + ) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels, + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + + quantize(self.encoder, weight_bit_width) + return self + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], + dim=-1, + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1) + + model_kwargs["is_first_forward"] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + is_first_forward: bool = True, + **kwargs, + ) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=input_ids.device) + if not is_first_forward: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "position_ids": position_ids, + "attention_mask": attention_mask, + "return_last_logit": True, + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + return response + + def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + prompt = tokenizer.build_prompt(query, history=history) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + if history: + prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + input_ids = tokenizer.encode(prompt, add_special_tokens=False) + input_ids = input_ids[1:] + inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) + else: + prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + @torch.no_grad() + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 8192, + num_beams=1, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + **kwargs, + ): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + "max_length": max_length, + "num_beams": num_beams, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } + inputs = self.build_inputs(tokenizer, query, history=history) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + past_key_values=None, + max_length: int = 8192, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + return_past_key_values=False, + **kwargs, + ): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + "max_length": max_length, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } + if past_key_values is None and not return_past_key_values: + inputs = self.build_inputs(tokenizer, query, history=history) + else: + inputs = self.build_stream_inputs(tokenizer, query, history=history) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + if self.transformer.pre_seq_len is not None: + past_length -= self.transformer.pre_seq_len + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) + inputs["attention_mask"] = attention_mask + for outputs in self.stream_generate( + **inputs, + past_key_values=past_key_values, + return_past_key_values=return_past_key_values, + **gen_kwargs, + ): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] + response = tokenizer.decode(outputs) + if response and response[-1] != "�": + response = self.process_response(response) + new_history = history + [(query, response)] + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = ( + generation_config.bos_token_id, + generation_config.eos_token_id, + ) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize( + self.transformer.encoder, + bits, + empty_init=empty_init, + device=device, + **kwargs, + ) + return self diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py new file mode 100644 index 0000000000000000000000000000000000000000..21f06393071dfc694747050ded34dd16586d6de8 --- /dev/null +++ b/colossalai/shardformer/modeling/gpt2.py @@ -0,0 +1,1007 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.models.gpt2.modeling_gpt2 import ( + GPT2DoubleHeadsModel, + GPT2DoubleHeadsModelOutput, + GPT2ForQuestionAnswering, + GPT2ForSequenceClassification, + GPT2ForTokenClassification, + GPT2LMHeadModel, + GPT2Model, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig + + +class GPT2PipelineForwards: + """ + This class serves as a micro library for forward function substitution of GPT2 models + under pipeline setting. + """ + + @staticmethod + def gpt2_model_forward( + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. + # Please refer to original code of transformers for more details. + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") + past_key_values = None + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + else: + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") + input_shape = hidden_states.size()[:-1] + device = hidden_states.device + hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) + batch_size = hidden_states.shape[0] + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if stage_manager.is_first_stage(): + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + else: + position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + # Going through held blocks. + start_idx, end_idx = stage_index[0], stage_index[1] + for i in range(start_idx, end_idx): + block = self.h[i] + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=None, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + if stage_manager.is_last_stage(): + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + else: + # always return dict for intermediate stage + return {"hidden_states": hidden_states} + + @staticmethod + def gpt2_lmhead_model_forward( + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + hidden_states = outputs[0] + lm_logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @staticmethod + def gpt2_double_heads_model_forward( + self: GPT2DoubleHeadsModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward. + Please refer to original code of transformers for more details. + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + hidden_states = outputs[0] + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def gpt2_for_question_answering_forward( + self: GPT2ForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward. + # Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def gpt2_for_token_classification_forward( + self: GPT2ForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward. + # Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def gpt2_for_sequence_classification_forward( + self: GPT2ForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward. + # Please refer to original code of transformers for more details. + """ + logger = logging.get_logger(__name__) + + if input_ids is not None: + batch_size, _ = input_ids.shape[:2] + else: + batch_size, _ = hidden_states.shape[:2] + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + hidden_states = outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def get_gpt2_flash_attention_forward(): + from transformers.models.gpt2.modeling_gpt2 import GPT2Attention + + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + def split_heads(tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor + + def forward( + self: GPT2Attention, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = split_heads(query, self.num_heads, self.head_dim) + key = split_heads(key, self.num_heads, self.head_dim) + value = split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if use_cache is True: + present = (key, value) + else: + present = None + + if not self.is_cross_attention: + attn_mask_type = AttnMaskType.causal + flash_attention_mask = None + if attention_mask != None: + if attn_mask_type == AttnMaskType.causal: + attn_mask_type == AttnMaskType.paddedcausal + else: + attn_mask_type = AttnMaskType.padding + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + + scale = value.size(-1) ** -0.5 + if self.scale_attn_by_inverse_layer_idx: + scale = scale * (1 / float(self.layer_idx + 1)) + + # use coloattention + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale + ) + + attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) + + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + outputs = (attn_output, present, None) + + return outputs + + return forward + + +def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger = logging.get_logger(__name__) + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/jit.py b/colossalai/shardformer/modeling/jit.py new file mode 100644 index 0000000000000000000000000000000000000000..c92847a3fbcccc32b466610dbac8a29722b025f9 --- /dev/null +++ b/colossalai/shardformer/modeling/jit.py @@ -0,0 +1,31 @@ +import torch + + +def get_dropout_add_func(): + from transformers.models.bloom.modeling_bloom import dropout_add + + def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + return dropout_add(x, residual, prob, training) + + return self_dropout_add + + +def get_jit_fused_dropout_add_func(): + from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train + + def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + bias = torch.zeros_like(x) + if training: + return bias_dropout_add_fused_train(x, bias, residual, prob) + return bias_dropout_add_fused_inference(x, bias, residual, prob) + + return self_dropout_add + + +def get_jit_fused_gelu_forward_func(): + from colossalai.kernel.jit.bias_gelu import bias_gelu + + def bloom_gelu_forward(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: + return bias_gelu(bias, x) + + return bloom_gelu_forward diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..4b6c8342534a84f45a2ba7e76e3ea65be3d0c982 --- /dev/null +++ b/colossalai/shardformer/modeling/llama.py @@ -0,0 +1,468 @@ +import warnings +from typing import List, Optional, Tuple + +import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class LlamaPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + """ + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + # always return dict for imediate stage + return {"hidden_states": hidden_states} + + @staticmethod + def llama_for_causal_lm_forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = LlamaPipelineForwards.llama_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def llama_for_sequence_classification_forward( + self: LlamaForSequenceClassification, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = LlamaPipelineForwards.llama_model_forward( + self.model, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + batch_size = hidden_states.shape[0] + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + +def get_llama_flash_attention_forward(): + from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb + + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + llama_version = 2 + try: + from transformers.models.llama.modeling_llama import repeat_kv + except: + warnings.warn("using llamav1, llamav1 hasn't repeat_kv function") + llama_version = 1 + + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + def forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + if llama_version == 2: + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) + query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) + key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) + value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) + + flash_attention_mask = None + attn_mask_type = AttnMaskType.causal + if attention_mask != None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + attn_mask_type = AttnMaskType.paddedcausal + + attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) + attn_output = attention( + query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + ) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + return forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py new file mode 100644 index 0000000000000000000000000000000000000000..e0978d38e110de9d94f2040213da35ee74ec5d2e --- /dev/null +++ b/colossalai/shardformer/modeling/opt.py @@ -0,0 +1,678 @@ +import random +from typing import List, Optional, Tuple, Union + +import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +from transformers.models.opt.modeling_opt import ( + OPTForCausalLM, + OPTForQuestionAnswering, + OPTForSequenceClassification, + OPTModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class OPTPipelineForwards: + """ + This class serves as a micro library for forward function substitution of OPT models + under pipeline setting. + """ + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + from transformers.models.opt.modeling_opt import _make_causal_mask + + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + _dtype, + device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to( + device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + @staticmethod + def opt_model_forward( + self: OPTModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward + """ + + from transformers.modeling_outputs import BaseModelOutputWithPast + from transformers.utils import logging + + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + decoder = self.decoder + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + batch_size, seq_length = input_shape + + if inputs_embeds is None: + inputs_embeds = decoder.embed_tokens(input_ids) + + if decoder.project_in is not None: + inputs_embeds = decoder.project_in(inputs_embeds) + device = input_ids.device if input_ids is not None else inputs_embeds.device + _dtype = inputs_embeds.dtype + + else: + if hidden_states is None: + raise ValueError("hidden_states shouln't be None for intermediate stages.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + _dtype = hidden_states.dtype + + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + # embed positions + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + + causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask( + attention_mask, input_shape, _dtype, device, past_key_values_length + ) + + if stage_manager.is_first_stage(): + pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) + hidden_states = inputs_embeds + pos_embeds + + if decoder.gradient_checkpointing and decoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") + past_key_values = None + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(decoder.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + start_idx, end_idx = stage_index[0], stage_index[1] + + torch.cuda.set_device(device) + + for idx in range(start_idx, end_idx): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + decoder_layer = decoder.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + dropout_probability = random.uniform(0, 1) + if decoder.training and (dropout_probability < decoder.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if decoder.gradient_checkpointing and decoder.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + if decoder.final_layer_norm is not None: + hidden_states = decoder.final_layer_norm(hidden_states) + if decoder.project_out is not None: + hidden_states = decoder.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + else: + return {"hidden_states": hidden_states} + + @staticmethod + def opt_for_causal_lm_forward( + self: OPTForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward. + Please refer to original code of transformers for more details. + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = OPTPipelineForwards.opt_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + logits = self.lm_head(outputs[0]).contiguous() + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def opt_for_sequence_classification_forward( + self: OPTForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward. + Please refer to original code of transformers for more details. + """ + + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = OPTPipelineForwards.opt_model_forward( + self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0] + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def opt_for_question_answering_forward( + self: OPTForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward. + Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = OPTPipelineForwards.opt_model_forward( + self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + +def get_opt_flash_attention_forward(): + from transformers.models.opt.modeling_opt import OPTAttention + + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + def forward( + self: OPTAttention, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k, v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + + flash_attention_mask = None + attn_mask_type = AttnMaskType.causal + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + attn_mask_type = AttnMaskType.paddedcausal + + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling + ) + attn_output = attention( + query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + ) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value + + return forward + + +def get_jit_fused_opt_decoder_layer_forward(): + from transformers.models.opt.modeling_opt import OPTDecoderLayer + + def forward( + self: OPTDecoderLayer, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + return forward diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..26e0b224d3ab2f31939e08dd55f3f34f56ac2736 --- /dev/null +++ b/colossalai/shardformer/modeling/sam.py @@ -0,0 +1,207 @@ +import math +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor + + +def forward_fn(): + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + attn_weights = self.add_decomposed_rel_pos( + attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + # replace dropout process with added DropoutForParallelInput layer + # origin code: + # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_probs = self.dropout_layer(attn_weights) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + return forward + + +def get_sam_flash_attention_forward(): + from transformers.models.sam.modeling_sam import SamAttention + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + + def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor: + batch, point_batch_size, n_tokens, channel = hidden_states.shape + c_per_head = channel // num_attention_heads + hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) + return hidden_states + + def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor: + batch, n_tokens, n_heads, c_per_head = hidden_states.shape + return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) + + def forward( + self: SamAttention, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None + ) -> Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = query.shape[1] + # Separate into heads + query = _separate_heads(query, self.num_attention_heads) + key = _separate_heads(key, self.num_attention_heads) + value = _separate_heads(value, self.num_attention_heads) + + # SamAttention + _, _, _, c_per_head = query.shape + bias = None + if attention_similarity is not None: + bias = attention_similarity + + scale = 1.0 / math.sqrt(c_per_head) + out = me_attention(query, key, value, attn_bias=bias, scale=scale) + + out = _recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out + + return forward + + +def get_sam_vision_flash_attention_forward(): + from transformers.models.sam.modeling_sam import SamVisionAttention + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + + def add_decomposed_rel_pos( + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (`torch.Tensor`): + attention map. + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + attn (`torch.Tensor`): + attention map with added relative positional embeddings. + """ + + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, nHead, dim = query.shape + reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width) + return rel_pos + + def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 1, 3, 4) + ) + + query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0) + + rel_pos = None + if self.use_rel_pos: + rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)) + + attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale) + + attn_output = attn_output.reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + outputs = (attn_output, None) + + return outputs + + return forward diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..f67aa84e4e72fcbc3f65b882733d2e7268356d01 --- /dev/null +++ b/colossalai/shardformer/modeling/t5.py @@ -0,0 +1,793 @@ +import warnings +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class T5PipelineForwards: + """ + This class serves as a micro library for forward function substitution of + T5 models under pipeline setting. + """ + + @staticmethod + def t5_stack_forward( + self: T5Stack, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + position_bias: Optional[torch.Tensor] = None, + encoder_decoder_position_bias: Optional[torch.Tensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Stack.forward. + # Please refer to original code of transformers for more details. + + logger = logging.get_logger(__name__) + + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") + past_key_values = None + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + if use_cache is True: + if not in_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + stage = stage_manager.stage + in_decoder = self.is_decoder + if in_decoder != (stage >= decoder_starting_stage): + raise ValueError("Config in T5Stack is not aligned with pipeline setting.") + + # at_first_stage: current stage is the first stage of encoder/decoder, taking input_ids/input_embedds + # at_last_stage: current stage is the last stage of encoder/decoder, making outputs the same form as huggingface + at_first_stage = (stage == 0) or (stage == decoder_starting_stage) + at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1) + + # Process inputs if at the first stage of encoder/decoder. + if at_first_stage: + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if in_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if in_decoder else "" + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + batch_size, seq_length = input_shape + device = inputs_embeds.device + hidden_states = self.dropout(inputs_embeds) + else: + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder." + ) + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=device) + if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + + # Going through held blocks. + start_idx, end_idx = stage_index[0], stage_index[1] + + for i in range(start_idx, end_idx): + past_key_value = past_key_values[i] + layer_module = self.block[i] + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + torch.cuda.set_device(hidden_states.device) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + if use_cache is False or use_cache is None: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + + if in_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + # last layer + if at_last_stage: + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + else: + return { + "hidden_states": hidden_states, + "position_bias": position_bias, + "encoder_decoder_position_bias": encoder_decoder_position_bias, + "backward_tensor_keys": ["hidden_states"], + } + + @staticmethod + def t5_model_forward( + self: T5Model, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + position_bias: Optional[torch.Tensor] = None, + encoder_decoder_position_bias: Optional[torch.Tensor] = None, + backward_tensor_keys: Optional[List[str]] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Model.forward. + # Please refer to original code of transformers for more details. + + __HEAD_MASK_WARNING_MSG = """ + The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, + `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. + If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, + num_heads)`. + """ + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + logger = logging.get_logger(__name__) + + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") + past_key_values = None + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + in_decoder = stage_manager.stage >= decoder_starting_stage + # Stage is in encoder, directly return the output of t5_stack_forward + if not in_decoder: + encoder_outputs = T5PipelineForwards.t5_stack_forward( + self.encoder, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + if stage_manager.stage == decoder_starting_stage - 1: + # last stage of encoder + return {"encoder_hidden_states": encoder_outputs[0]} + else: + return encoder_outputs + + at_last_decoder_stage = stage_manager.is_last_stage() + at_first_decoder_stage = stage_manager.stage == decoder_starting_stage + + if encoder_outputs is not None: + encoder_hidden_states = encoder_outputs[0] + elif encoder_hidden_states is None: + raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.") + + if not at_first_decoder_stage and hidden_states is None: + raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") + + # Decode + decoder_outputs = T5PipelineForwards.t5_stack_forward( + self.decoder, + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + + # Directly return outputs of overloaded T5Stack forward if not at last stage. + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + decoder_outputs["encoder_hidden_states"] = encoder_hidden_states + return decoder_outputs + + if not return_dict: + return decoder_outputs + encoder_hidden_states + else: + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + ) + + @staticmethod + def t5_for_conditional_generation_forward( + self: T5ForConditionalGeneration, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + position_bias: Optional[torch.Tensor] = None, + encoder_decoder_position_bias: Optional[torch.Tensor] = None, + backward_tensor_keys: Optional[List[str]] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + # This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward. + # Please refer to original code of transformers for more details. + + __HEAD_MASK_WARNING_MSG = """ + The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, + `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. + If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, + num_heads)`. + """ + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + logger = logging.get_logger(__name__) + + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") + past_key_values = None + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + in_decoder = stage_manager.stage >= decoder_starting_stage + + # Stage is in encoder, directly return the output of t5_stack_forward + if not in_decoder: + encoder_outputs = T5PipelineForwards.t5_stack_forward( + self.encoder, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + if stage_manager.stage == decoder_starting_stage - 1: + # last stage of encoder + return {"encoder_hidden_states": encoder_outputs[0]} + else: + return encoder_outputs + + at_last_decoder_stage = stage_manager.is_last_stage() + at_first_decoder_stage = stage_manager.stage == decoder_starting_stage + + if encoder_outputs is not None: + encoder_hidden_states = encoder_outputs[0] + elif encoder_hidden_states is None: + raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.") + + if not at_first_decoder_stage and hidden_states is None: + raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Decode + decoder_outputs = T5PipelineForwards.t5_stack_forward( + self.decoder, + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + + # Directly return outputs of overloaded T5Stack forward if not at last stage. + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + decoder_outputs["encoder_hidden_states"] = encoder_hidden_states + return decoder_outputs + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_hidden_states + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + ) + + @staticmethod + def t5_encoder_model_forward( + self: T5EncoderModel, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + position_bias: Optional[torch.Tensor] = None, + encoder_decoder_position_bias: Optional[torch.Tensor] = None, + backward_tensor_keys: Optional[List[str]] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + This function is modified on the basis of transformers.models.t5.modeling_gpt2.T5EncoderModel.forward. + Please refer to original code of transformers for more details. + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = T5PipelineForwards.t5_stack_forward( + self.encoder, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + + return outputs + + +def get_t5_flash_attention_forward(): + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + from transformers.models.t5.modeling_t5 import T5Attention + + def forward( + self: T5Attention, + hidden_states: torch.Tensor, + mask: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + layer_head_mask: Optional[torch.Tensor] = None, + query_length: Optional[int] = None, + use_cache: bool = False, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + if len(past_key_value) != 2: + raise ValueError( + f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) + + def unshape(states): + """reshape""" + return states.view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=1) + elif past_key_value.shape[1] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=query_states.device, dtype=query_states.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + position_bias_masked = position_bias_masked.contiguous() + attn_output = me_attention( + query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0 + ) + attn_output = unshape(attn_output) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + return outputs + + return forward + + +def get_jit_fused_T5_layer_ff_forward(): + from transformers.models.t5.modeling_t5 import T5LayerFF + + def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor: + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = self.dropout_add(forwarded_states, hidden_states, self.dropout.p, self.dropout.training) + return hidden_states + + return forward + + +def get_T5_layer_self_attention_forward(): + from transformers.models.t5.modeling_t5 import T5LayerSelfAttention + + def forward( + self: T5LayerSelfAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + return forward + + +def get_T5_layer_cross_attention_forward(): + from transformers.models.t5.modeling_t5 import T5LayerCrossAttention + + def forward( + self: T5LayerCrossAttention, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + query_length: Optional[int] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + return forward diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..2db83b912112624cbc3199aa51cb7ceca443dbc6 --- /dev/null +++ b/colossalai/shardformer/modeling/vit.py @@ -0,0 +1,392 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch +from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +def _encoder_forward( + encoder: ViTEncoder, + start_idx: int, + end_idx: int, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + stage_manager: PipelineStageManager = None, +) -> Union[tuple, BaseModelOutput]: + for i in range(start_idx, end_idx): + layer_module = encoder.layer[i] + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if encoder.gradient_checkpointing and encoder.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, False) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, False) + + hidden_states = layer_outputs[0] + if not stage_manager.is_last_stage(): + return hidden_states + else: + if not return_dict: + return tuple(hidden_states) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=None, + attentions=None, + ) + + +def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): + from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling + + def pp_forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if stage_manager.is_first_stage(): + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # TODO(FoolPlayer): maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + else: + assert ( + hidden_states is not None + ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + # Go through encoder + if not stage_manager.is_last_stage(): + hidden_states = _encoder_forward( + encoder=self.encoder, + start_idx=stage_index[0], + end_idx=stage_index[1], + hidden_states=embedding_output, + head_mask=head_mask, + return_dict=return_dict, + stage_manager=stage_manager, + ) + return {"hidden_states": hidden_states} + else: + encoder_outputs = _encoder_forward( + encoder=self.encoder, + start_idx=stage_index[0], + end_idx=stage_index[1], + hidden_states=hidden_states, + head_mask=head_mask, + return_dict=return_dict, + stage_manager=stage_manager, + ) + + # Go through rest layers + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + return pp_forward + + +def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): + from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + from transformers.models.vit.modeling_vit import ImageClassifierOutput + + def pp_forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if not stage_manager.is_first_stage(): + assert ( + hidden_states is not None + ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + outputs = self.vit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + hidden_states=hidden_states, + ) + + # not last stage, return hidden_states + if not stage_manager.is_last_stage(): + return outputs + else: + sequence_output = outputs[0] + + # last stage + logits = self.classifier(sequence_output[:, 0, :]) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return pp_forward + + +def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): + import math + + import torch.nn as nn + from transformers.models.vit.modeling_vit import ImageClassifierOutput, MaskedImageModelingOutput + + def pp_forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") + >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction + >>> list(reconstructed_pixel_values.shape) + [1, 3, 224, 224] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input." + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + if not stage_manager.is_first_stage(): + assert ( + hidden_states is not None + ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + outputs = self.vit( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + hidden_states=hidden_states, + ) + if not stage_manager.is_last_stage(): + return outputs + else: + sequence_output = outputs[0] + + # Reshape to (batch_size, num_channels, height, width) + sequence_output = sequence_output[:, 1:] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return pp_forward + + +def get_vit_flash_self_attention_forward(): + from transformers.models.vit.modeling_vit import ViTSelfAttention + + from colossalai.kernel.cuda_native import ColoAttention + + def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) + x = x.view(new_x_shape) + return x + + def forward( + self: ViTSelfAttention, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size) + value_layer = transpose_for_scores( + self.value(hidden_states), self.num_attention_heads, self.attention_head_size + ) + query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size) + + scale = 1.0 / math.sqrt(self.attention_head_size) + attention = ColoAttention( + embed_dim=self.all_head_size, num_heads=self.num_attention_heads, dropout=self.dropout.p, scale=scale + ) + context_layer = attention(query_layer, key_layer, value_layer) + + outputs = (context_layer,) + + return outputs + + return forward + + +def get_jit_fused_vit_output_forward(): + from transformers.models.vit.modeling_vit import ViTOutput + + def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + return hidden_states + + return forward diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..ef59dbcee68043092edc8430ca905b1f339a069f --- /dev/null +++ b/colossalai/shardformer/modeling/whisper.py @@ -0,0 +1,977 @@ +import logging +import random +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, +) +from transformers.models.whisper.modeling_whisper import ( + WhisperEncoder, + WhisperForAudioClassification, + WhisperForConditionalGeneration, + WhisperModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +def get_whisper_flash_attention_forward(): + from transformers.models.whisper.modeling_whisper import WhisperAttention + + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): + return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() + + def forward( + self: WhisperAttention, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[1] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) + value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + # get query proj + query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim) + + src_len = key_states.size(1) + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + + attn_type = None + flash_attention_mask = None + + if self.is_decoder: + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous()) + attn_type = AttnMaskType.paddedcausal + + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling + ) + attn_output = attention( + query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_type + ) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + return forward + + +def get_jit_fused_whisper_encoder_layer_forward(): + from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer + + def forward( + self: WhisperEncoderLayer, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + return forward + + +def get_jit_fused_whisper_decoder_layer_forward(): + from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer + + def forward( + self: WhisperDecoderLayer, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + return forward + + +class WhisperPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + """ + + @staticmethod + def whisper_encoder_forward( + self: WhisperEncoder, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_states=None, + all_attentions=None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Args: + input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a + `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding + and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.Tensor`)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, + but it is not used. By default the silence in the input log mel spectrogram are ignored. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + logging.get_logger(__name__) + + stage = stage_manager.stage + at_first_stage = stage == 0 + at_last_stage = stage == decoder_starting_stage - 1 + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Process inputs if at the first stage of encoder. + if at_first_stage: + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + else: + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder." + ) + + start_idx, end_idx = stage_index[0], stage_index[1] + + for idx in range(start_idx, end_idx): + encoder_layer = self.layers[idx] + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + None, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if at_last_stage: + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + else: + return {"hidden_states": hidden_states, "head_mask": head_mask} + + @staticmethod + def whisper_decoder_forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + logger = logging.get_logger(__name__) + stage = stage_manager.stage + at_first_stage = stage == decoder_starting_stage + at_last_stage = stage == stage_manager.num_stages - 1 + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if at_first_stage: + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if input_ids is not None: + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + else: + positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + + else: + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder." + ) + input_shape = hidden_states.size()[:-1] + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, hidden_states, past_key_values_length + ) + + start_idx, end_idx = stage_index[0], stage_index[1] + + for idx in range(start_idx, end_idx): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + decoder_layer = self.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + None, # encoder attention mask + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, # past_key_value + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if at_last_stage: + hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + else: + return { + "head_mask": head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "hidden_states": hidden_states, + } + + @staticmethod + def whisper_model_forward( + self: WhisperModel, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Returns: + + Example: + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, WhisperModel + >>> from datasets import load_dataset + + >>> model = WhisperModel.from_pretrained("openai/whisper-base") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 2, 512] + ```""" + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") + past_key_values = None + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + in_decoder = stage_manager.stage >= decoder_starting_stage + if not in_decoder: + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward( + self.encoder, + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + + if stage_manager.stage == decoder_starting_stage - 1: + # last stage of encoder + return {"encoder_hidden_states": encoder_outputs[0]} + else: + return encoder_outputs + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + at_last_decoder_stage = stage_manager.is_last_stage() + at_first_decoder_stage = stage_manager.stage == decoder_starting_stage + if encoder_outputs is not None: + encoder_hidden_states = encoder_outputs[0] + elif encoder_hidden_states is None: + raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.") + + if not at_first_decoder_stage and hidden_states is None: + raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward( + self.decoder, + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + + # Directly return outputs of overloaded Whisper forward if not at last stage. + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + decoder_outputs["encoder_hidden_states"] = encoder_hidden_states + return decoder_outputs + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + ) + + @staticmethod + def whisper_for_conditional_generation_forward( + self: WhisperForConditionalGeneration, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + + >>> generated_ids = model.generate(inputs=input_features) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + in_decoder = stage_manager.stage >= decoder_starting_stage + at_last_decoder_stage = stage_manager.is_last_stage() + outputs = WhisperPipelineForwards.whisper_model_forward( + self.model, + input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + if not in_decoder: + return outputs + + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + outputs["encoder_hidden_states"] = encoder_hidden_states + return outputs + + lm_logits = self.proj_out(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @staticmethod + def whisper_for_audio_classification_forward( + self: WhisperForAudioClassification, + input_features: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_states=None, + all_attentions=None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward. + Please refer to original code of transformers for more details. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # audio_classification only holds encoder + encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward( + self.encoder, + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + + if not stage_manager.is_last_stage(): + return encoder_outputs + + if self.config.use_weighted_layer_sum: + hidden_states = torch.stack(encoder_outputs, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = encoder_outputs[0] + + hidden_states = self.projector(hidden_states) + pooled_output = hidden_states.mean(dim=1) + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + encoder_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/tests/test_layers/test_sequence/checks_seq/__init__.py b/colossalai/shardformer/policies/__init__.py similarity index 100% rename from tests/test_layers/test_sequence/checks_seq/__init__.py rename to colossalai/shardformer/policies/__init__.py diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..f3587de15f86f330070f0636722a932073a87065 --- /dev/null +++ b/colossalai/shardformer/policies/auto_policy.py @@ -0,0 +1,222 @@ +import importlib +from dataclasses import dataclass +from typing import Optional + +import torch.nn as nn + +from .base_policy import Policy + +__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"] + + +@dataclass +class PolicyLocation: + """ + PolicyLocation describes the location of a policy class. + + Args: + file_name (str): The file name of the policy under colossalai.shardformer.policies + class_name (str): The class name of the policy class + """ + + file_name: str + class_name: str + + +# we don't want to import all policies here +# as each policy file imports its own model zoo library +# we will allow the user to only import the policy file needed +_POLICY_LIST = { + # BERT + "transformers.models.bert.modeling_bert.BertModel": PolicyLocation(file_name="bert", class_name="BertModelPolicy"), + "transformers.models.bert.modeling_bert.BertForPreTraining": PolicyLocation( + file_name="bert", class_name="BertForPreTrainingPolicy" + ), + "transformers.models.bert.modeling_bert.BertLMHeadModel": PolicyLocation( + file_name="bert", class_name="BertLMHeadModelPolicy" + ), + "transformers.models.bert.modeling_bert.BertForMaskedLM": PolicyLocation( + file_name="bert", class_name="BertForMaskedLMPolicy" + ), + "transformers.models.bert.modeling_bert.BertForSequenceClassification": PolicyLocation( + file_name="bert", class_name="BertForSequenceClassificationPolicy" + ), + "transformers.models.bert.modeling_bert.BertForTokenClassification": PolicyLocation( + file_name="bert", class_name="BertForTokenClassificationPolicy" + ), + "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": PolicyLocation( + file_name="bert", class_name="BertForNextSentencePredictionPolicy" + ), + "transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation( + file_name="bert", class_name="BertForMultipleChoicePolicy" + ), + "transformers.models.bert.modeling_bert.BertForQuestionAnswering": PolicyLocation( + file_name="bert", class_name="BertForQuestionAnsweringPolicy" + ), + # LLaMA + "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation( + file_name="llama", class_name="LlamaModelPolicy" + ), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation( + file_name="llama", class_name="LlamaForCausalLMPolicy" + ), + "transformers.models.llama.modeling_llama.LlamaForSequenceClassification": PolicyLocation( + file_name="llama", class_name="LlamaForSequenceClassificationPolicy" + ), + # T5 + "transformers.models.t5.modeling_t5.T5Model": PolicyLocation(file_name="t5", class_name="T5ModelPolicy"), + "transformers.models.t5.modeling_t5.T5ForConditionalGeneration": PolicyLocation( + file_name="t5", class_name="T5ForConditionalGenerationPolicy" + ), + "transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), + # GPT2 + "transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation( + file_name="gpt2", class_name="GPT2LMHeadModelPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": PolicyLocation( + file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering": PolicyLocation( + file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": PolicyLocation( + file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation( + file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy" + ), + # ViT + "transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), + "transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation( + file_name="vit", class_name="ViTForImageClassificationPolicy" + ), + "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": PolicyLocation( + file_name="vit", class_name="ViTForMaskedImageModelingPolicy" + ), + # OPT + "transformers.models.opt.modeling_opt.OPTModel": PolicyLocation(file_name="opt", class_name="OPTModelPolicy"), + "transformers.models.opt.modeling_opt.OPTForCausalLM": PolicyLocation( + file_name="opt", class_name="OPTForCausalLMPolicy" + ), + "transformers.models.opt.modeling_opt.OPTForSequenceClassification": PolicyLocation( + file_name="opt", class_name="OPTForSequenceClassificationPolicy" + ), + "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": PolicyLocation( + file_name="opt", class_name="OPTForQuestionAnsweringPolicy" + ), + # Bloom + "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation( + file_name="bloom", class_name="BloomModelPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation( + file_name="bloom", class_name="BloomForCausalLMPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": PolicyLocation( + file_name="bloom", class_name="BloomForSequenceClassificationPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForTokenClassification": PolicyLocation( + file_name="bloom", class_name="BloomForTokenClassificationPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": PolicyLocation( + file_name="bloom", class_name="BloomForQuestionAnsweringPolicy" + ), + # Whisper + "transformers.models.whisper.modeling_whisper.WhisperModel": PolicyLocation( + file_name="whisper", class_name="WhisperModelPolicy" + ), + "transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration": PolicyLocation( + file_name="whisper", class_name="WhisperForConditionalGenerationPolicy" + ), + "transformers.models.whisper.modeling_whisper.WhisperForAudioClassification": PolicyLocation( + file_name="whisper", class_name="WhisperForAudioClassificationPolicy" + ), + # Sam + "transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"), + # Blip2 + "transformers.models.blip_2.modeling_blip_2.Blip2Model": PolicyLocation( + file_name="blip2", class_name="Blip2ModelPolicy" + ), + "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": PolicyLocation( + file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy" + ), + # ChatGLM + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation( + file_name="chatglm2", class_name="ChatGLMModelPolicy" + ), + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( + file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" + ), +} + +_INFER_POLICY_LIST = { + # LlaMa + "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation( + file_name="llama", class_name="LlamaModelInferPolicy" + ), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation( + file_name="llama", class_name="LlamaModelInferPolicy" + ), + # Bloom + "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation( + file_name="bloom", class_name="BloomModelInferPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation( + file_name="bloom", class_name="BloomModelInferPolicy" + ), + # ChatGLM2 + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation( + file_name="chatglm2", class_name="ChatGLM2InferPolicy" + ), + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( + file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy" + ), +} + + +def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy: + """ + Dynamically import a Policy class based on the policy location. + """ + if inference_only: + module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}" + else: + module_name = f"colossalai.shardformer.policies.{policy_location.file_name}" + module = importlib.import_module(module_name) + return getattr(module, policy_location.class_name) + + +def _fullname(obj): + """ + Return the full name of an object, including the module name. + """ + klass = obj.__class__ + module = klass.__module__ + if module == "builtins": + return klass.__qualname__ # avoid outputs like 'builtins.str' + return module + "." + klass.__qualname__ + + +def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy: + r""" + Return the auto policy for the model + + Args: + model (:class:`nn.Module`): The model to get the auto policy + + Return: + :class:`Policy`: The auto policy for the model + """ + full_name = _fullname(model) + if inference_only: + policy_location = _INFER_POLICY_LIST.get(full_name, None) + else: + policy_location = _POLICY_LIST.get(full_name, None) + + if policy_location is None: + raise NotImplementedError( + f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}" + ) + else: + policy = import_policy(policy_location, inference_only) + return policy() diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..eb03500531bce8f63cbd82d6415467a23dd677d7 --- /dev/null +++ b/colossalai/shardformer/policies/base_policy.py @@ -0,0 +1,226 @@ +# part of code modified from https://github.com/tunib-ai/parallelformers + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from ..layer.parallel_module import ParallelModule +from ..shard.shard_config import ShardConfig + +__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"] + + +@dataclass +class SubModuleReplacementDescription: + r""" + Describe how a submodule will be replaced + + Args: + suffix (str): used to get the submodule object + target_module (ParallelModule): specifies the module class used to replace to submodule + kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. + ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception + """ + suffix: str + target_module: ParallelModule + kwargs: Dict[str, Any] = None + ignore_if_not_exist: bool = False + + +@dataclass +class ModulePolicyDescription: + r""" + Describe how the attributes and parameters will be transformed in a policy. + + Args: + attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding + param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function + must receive only one arguments: module. One example is + + ```python + def example_replace_weight(module: torch.nn.Module): + weight = module.weight + new_weight = shard_rowwise(weight, process_group) + module.weight = torch.nn.Parameter(new_weight) + ``` + sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription + object which specifies the module to be replaced and the target module used to replacement. + method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement + """ + attribute_replacement: Dict[str, Any] = None + param_replacement: List[Callable] = None + sub_module_replacement: List[SubModuleReplacementDescription] = None + method_replacement: Dict[str, Callable] = None + + +class Policy(ABC): + r""" + The base class for all the policies. For each different model, it should have a different policy class, + like BertPolicy for Bert Model or OPTPolicy for OPT model. + + Shardformer has provided many built-in sharding policies for the mainstream models. You can use the + built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`. + If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify. + """ + + def __init__(self) -> None: + self.shard_config: Optional[ShardConfig] = None + self.model: Optional[Module] = None + + def set_model(self, model: nn.Module) -> None: + r""" + Set model as an attribute of the Policy object so that we can access the model's attributes. + + Args: + model (:class:`nn.Module`): The model to be perform + """ + self.model = model + + def set_shard_config(self, shard_config: ShardConfig) -> None: + r""" + Set shard config as an attribute of the Policy object. + + Args: + shard_config (:class:`ShardConfig`): The shard config to be perform + """ + self.shard_config = shard_config + self.config_sanity_check() + + @property + def pipeline_stage_manager(self) -> Optional[PipelineStageManager]: + if self.shard_config is not None: + return self.shard_config.pipeline_stage_manager + return None + + @abstractmethod + def config_sanity_check(self): + """ + Check if the shard config is valid for the model. Raise an exception if the config is invalid. + This method is made abstractmethod with no default implementation because we want to the policy writer + to take note of the feature supported by his/her model and policy. + """ + + @abstractmethod + def preprocess(self) -> nn.Module: + r""" + Perform some preprocessing of the model, like reshaping the embedding layer. + """ + + @abstractmethod + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + r""" + This method returns the module policy, which is a dictionary. The key is the module name or the module object, + and the value is the ModulePolicyDescription object. The ModulePolicyDescription object describes how the module + will be transformed. + """ + + @abstractmethod + def postprocess(self) -> nn.Module: + r""" + Perform some postprocessing of the model, like binding the weight of embedding layer with + the classifier layer + """ + + def append_or_create_submodule_replacement( + self, + description: Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]], + policy: Dict[Union[str, nn.Module], ModulePolicyDescription], + target_key: Union[str, nn.Module], + ) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + r""" + Append or create a new submodule replacement description to the policy for the given key. + + Args: + submodule_replace_desc (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended + policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated + target_key (Union[str, nn.Module]): the key of the policy to be updated + """ + # convert to list + if isinstance(description, SubModuleReplacementDescription): + description = [description] + + # append or create a new description + if target_key in policy: + if policy[target_key].sub_module_replacement is None: + policy[target_key].sub_module_replacement = description + else: + policy[target_key].sub_module_replacement.extend(description) + else: + policy[target_key] = ModulePolicyDescription(sub_module_replacement=description) + + return policy + + def append_or_create_method_replacement( + self, + description: Dict[str, Callable], + policy: Dict[Union[str, nn.Module], ModulePolicyDescription], + target_key: Union[str, nn.Module], + ) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + r""" + Append or create a new method replacement description to the policy for the given key. + + Args: + description (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended + policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated + target_key (Union[str, nn.Module]): the key of the policy to be updated + """ + if target_key in policy: + if policy[target_key].method_replacement is None: + policy[target_key].method_replacement = description + else: + policy[target_key].method_replacement.update(description) + else: + policy[target_key] = ModulePolicyDescription(method_replacement=description) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get layers that should be held in current stage. This method should be implemented by subclass. + + Returns: + List[Module]: List of layers that should be hold in current stage + """ + raise NotImplementedError + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """Get parameters that should be shared across stages. This method should be implemented by subclass. + + Returns: + List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] + """ + return [] + + @staticmethod + def distribute_layers(num_layers: int, num_stages: int) -> List[int]: + """Divide layers into stages""" + quotient = num_layers // num_stages + remainder = num_layers % num_stages + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages + + # deal with the rest layers + if remainder > 0: + start_position = num_stages // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage + + @staticmethod + def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: + """ + get the start index and end index of layers for each stage. + """ + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + start_idx = num_layers_per_stage_accumulated[stage] + end_idx = num_layers_per_stage_accumulated[stage + 1] + + return [start_idx, end_idx] diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..14146de158aec65ca2c192710795715240179aca --- /dev/null +++ b/colossalai/shardformer/policies/bert.py @@ -0,0 +1,641 @@ +from functools import partial +from typing import Callable, Dict, List + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +import colossalai.shardformer.layer as col_nn + +from ..modeling.bert import ( + BertPipelineForwards, + bert_sequence_parallel_forward_fn, + get_bert_flash_attention_forward, + get_jit_fused_bert_output_forward, + get_jit_fused_bert_self_output_forward, +) +from ..modeling.jit import get_jit_fused_dropout_add_func +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [ + "BertPolicy", + "BertModelPolicy", + "BertForPreTrainingPolicy", + "BertLMdHeadModelPolicy", + "BertForMaskedLMPolicy", + "BertForNextSentencePredictionPolicy", + "BertForSequenceClassificationPolicy", + "BertForTokenClassificationPolicy", + "BertForMultipleChoicePolicy", + "BertForQuestionAnsweringPolicy", +] + + +class BertPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + # TODO: + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.bert.modeling_bert import ( + BertEmbeddings, + BertLayer, + BertModel, + BertOutput, + BertSelfAttention, + BertSelfOutput, + ) + + policy = {} + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap + if self.shard_config.enable_tensor_parallelism: + policy[BertLayer] = ModulePolicyDescription( + attribute_replacement={ + "attention.self.all_head_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "crossattention.self.all_head_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "attention.self.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "crossattention.self.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + policy[BertEmbeddings] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ] + ) + + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=BertModel, + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # Handle bert layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="attention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=BertLayer, + ) + # handle embedding layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="LayerNorm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=BertEmbeddings, + ) + + # use flash attention + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_bert_flash_attention_forward(), + }, + policy=policy, + target_key=BertSelfAttention, + ) + + # use jit operator + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bert_self_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BertSelfOutput, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bert_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BertOutput, + ) + + return policy + + def add_lm_head_policy(self, base_policy): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + + # optimize for tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ), + policy=base_policy, + target_key=BertLMPredictionHead, + ) + + # optimize with fused normalization + if self.shard_config.enable_fused_normalization: + # Handle bert lm prediction head + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="transform.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + policy=base_policy, + target_key=BertLMPredictionHead, + ) + return base_policy + + def add_lm_prediction_policy(self, base_policy): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + + method_replacement = { + "_save_to_state_dict": col_nn.ParallelModule._save_to_state_dict, + "_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict, + } + self.append_or_create_method_replacement( + description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead + ) + return base_policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "BertModel": + module = self.model + else: + module = self.model.bert + + layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "BertModel": + module = self.model + else: + module = self.model.bert + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.pooler) + + return held_layers + + +# BertModel +class BertModelPolicy(BertPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + from transformers.models.bert.modeling_bert import BertModel + + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in bert model""" + return [] + + +# BertForPreTraining +class BertForPreTrainingPolicy(BertPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + policy = self.add_lm_prediction_policy(policy) + from transformers.models.bert.modeling_bert import BertForPreTraining + + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BertForPreTraining, + new_forward=BertPipelineForwards.bert_for_pretraining_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage""" + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(self.model.cls) + + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + model = self.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): + # tie weights + return [ + { + 0: model.bert.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight, + } + ] + return [] + + +# BertLMHeadModel +class BertLMHeadModelPolicy(BertPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + policy = self.add_lm_prediction_policy(policy) + from transformers.models.bert.modeling_bert import BertLMHeadModel + + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(self.model.cls) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + bert_model = self.model.bert + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): + # tie weights + return [ + { + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight, + } + ] + return [] + + +# BertForMaskedLM +class BertForMaskedLMPolicy(BertPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + policy = self.add_lm_prediction_policy(policy) + from transformers.models.bert.modeling_bert import BertForMaskedLM + + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(self.model.cls) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + bert_model = self.model.bert + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): + # tie weights + return [ + { + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight, + } + ] + return [] + + +# BertForSequenceClassification +class BertForSequenceClassificationPolicy(BertPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bert.modeling_bert import BertForSequenceClassification + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + BertForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ] + ) + } + policy.update(addon_module) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BertForSequenceClassification, + new_forward=BertPipelineForwards.bert_for_sequence_classification_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] + + +# BertForTokenClassification +class BertForTokenClassificationPolicy(BertPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bert.modeling_bert import BertForTokenClassification + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + BertForTokenClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ] + ) + } + policy.update(addon_module) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BertForTokenClassification, + new_forward=BertPipelineForwards.bert_for_token_classification_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] + + +# BertForNextSentencePrediction +class BertForNextSentencePredictionPolicy(BertPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + from transformers.models.bert.modeling_bert import BertForNextSentencePrediction + + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BertForNextSentencePrediction, + new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(self.model.cls) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] + + +# BertForMultipleChoice +class BertForMultipleChoicePolicy(BertPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bert.modeling_bert import BertForMultipleChoice + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + BertForMultipleChoice: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ] + ) + } + policy.update(addon_module) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BertForMultipleChoice, + new_forward=BertPipelineForwards.bert_for_multiple_choice_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] + + +class BertForQuestionAnsweringPolicy(BertPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bert.modeling_bert import BertForQuestionAnswering + + policy = super().module_policy() + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BertForQuestionAnswering, + new_forward=BertPipelineForwards.bert_for_question_answering_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py new file mode 100644 index 0000000000000000000000000000000000000000..997643d1a91179b913924ae8cc3ced6599d2be41 --- /dev/null +++ b/colossalai/shardformer/policies/blip2.py @@ -0,0 +1,350 @@ +import colossalai.shardformer.layer as col_nn + +from ..modeling.blip2 import ( + forward_fn, + get_blip2_flash_attention_forward, + get_jit_fused_blip2_QFormer_output_forward, + get_jit_fused_blip2_QFormer_self_output_forward, +) +from ..modeling.jit import get_jit_fused_dropout_add_func +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["BlipPolicy", "BlipModelPolicy"] + + +class BlipPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + # TODO: + vocab_size = self.model.config.qformer_config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.blip_2.modeling_blip_2 import ( + Blip2Attention, + Blip2EncoderLayer, + Blip2QFormerLayer, + Blip2QFormerModel, + Blip2QFormerOutput, + Blip2QFormerSelfOutput, + Blip2VisionModel, + ) + from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[Blip2EncoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.num_heads": self.model.config.vision_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "self_attn.embed_dim": self.model.config.vision_config.hidden_size + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="self_attn.qkv", + target_module=col_nn.FusedLinear1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.projection", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) + + policy[Blip2QFormerModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) + + policy[Blip2QFormerLayer] = ModulePolicyDescription( + attribute_replacement={ + "attention.attention.num_attention_heads": self.model.config.qformer_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": self.model.config.qformer_config.hidden_size + // self.shard_config.tensor_parallel_size, + "crossattention.attention.num_attention_heads": self.model.config.qformer_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "crossattention.attention.all_head_size": self.model.config.qformer_config.hidden_size + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate_query.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output_query.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output_query.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + policy[OPTDecoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.embed_dim": self.model.config.text_config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.text_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) + + policy[OPTForCausalLM] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="model.decoder.embed_tokens", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, + ), + ] + ) + + policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # Handle Blip2EncoderLayer layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=Blip2EncoderLayer, + ) + + # handle Blip2VisionModel layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="post_layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=Blip2VisionModel, + ) + + # handle Blip2VisionModel layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=Blip2QFormerModel, + ) + + # handle Blip2QFormerLayer layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="attention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="output_query.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=Blip2QFormerLayer, + ) + + # handle OPTForCausalLM layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="model.decoder.final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=OPTForCausalLM, + ) + + # handle OPTDecoderLayer layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=OPTDecoderLayer, + ) + + # use flash attention + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_blip2_flash_attention_forward(), + }, + policy=policy, + target_key=Blip2Attention, + ) + + # use jit operator + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_blip2_QFormer_self_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=Blip2QFormerSelfOutput, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_blip2_QFormer_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=Blip2QFormerOutput, + ) + + return policy + + def postprocess(self): + return self.model + + +# Blip2Model +class Blip2ModelPolicy(BlipPolicy): + def __init__(self) -> None: + super().__init__() + + +# Blip2ForConditionalGeneration +class Blip2ForConditionalGenerationPolicy(BlipPolicy): + def __init__(self) -> None: + super().__init__() diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py new file mode 100644 index 0000000000000000000000000000000000000000..13b9dd31345d191f70d786974c099e079e17a3d7 --- /dev/null +++ b/colossalai/shardformer/policies/bloom.py @@ -0,0 +1,400 @@ +from functools import partial +from typing import Callable, Dict, List + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +import colossalai.shardformer.layer as col_nn + +from ..modeling.bloom import ( + BloomPipelineForwards, + build_bloom_alibi_tensor_fn, + get_bloom_flash_attention_forward, + get_bloom_sequence_parallel_forward_fn, + get_jit_fused_bloom_attention_forward, + get_jit_fused_bloom_gelu_forward, + get_jit_fused_bloom_mlp_forward, +) +from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + + +class BloomPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomGelu, BloomMLP, BloomModel + + policy = {} + + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap + if self.shard_config.enable_tensor_parallelism: + policy[BloomBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + ], + ) + + policy[BloomModel] = ModulePolicyDescription( + attribute_replacement={ + "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + method_replacement={ + "build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ], + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # handle bloom model + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="word_embeddings_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=BloomModel, + ) + + # handle bloom block + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=BloomBlock, + ) + + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=BloomModel, + ) + + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_bloom_flash_attention_forward(), + "dropout_add": get_dropout_add_func(), + }, + policy=policy, + target_key=BloomAttention, + ) + + # enable jit fused operator + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bloom_attention_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BloomAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bloom_mlp_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BloomMLP, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bloom_gelu_forward(), + "bloom_gelu_forward": get_jit_fused_gelu_forward_func(), + }, + policy=policy, + target_key=BloomGelu, + ) + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "BloomModel": + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "BloomModel": + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.word_embeddings) + held_layers.append(module.word_embeddings_layernorm) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + + return held_layers + + +class BloomModelPolicy(BloomPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + from transformers.models.bloom.modeling_bloom import BloomModel + + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BloomModel, new_forward=BloomPipelineForwards.bloom_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """no shared params in bloom model""" + return [] + + +class BloomForCausalLMPolicy(BloomPolicy): + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForCausalLM + + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=BloomForCausalLM, + ) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BloomForCausalLM, new_forward=BloomPipelineForwards.bloom_for_causal_lm_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + bloom_model = self.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if id(bloom_model.transformer.word_embeddings.weight) == id(bloom_model.lm_head.weight): + # tie weights + return [ + { + 0: bloom_model.transformer.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight, + } + ] + return [] + + +class BloomForSequenceClassificationPolicy(BloomPolicy): + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification + + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=BloomForSequenceClassification, + ) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BloomForSequenceClassification, + new_forward=BloomPipelineForwards.bloom_for_sequence_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in bloom for sequence classification model""" + return [] + + +class BloomForTokenClassificationPolicy(BloomPolicy): + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForTokenClassification + + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + policy=policy, + target_key=BloomForTokenClassification, + ) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BloomForTokenClassification, + new_forward=BloomPipelineForwards.bloom_for_token_classification_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in bloom for token classification model""" + return [] + + +class BloomForQuestionAnsweringPolicy(BloomPolicy): + # No head sharding as the output features is only 2 + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForQuestionAnswering + + policy = super().module_policy() + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=BloomForQuestionAnswering, + new_forward=BloomPipelineForwards.bloom_for_question_answering_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in bloom for question answering model""" + return [] diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py new file mode 100644 index 0000000000000000000000000000000000000000..3c27c848e738f4d6ff17247c88d157d1ba7dae3a --- /dev/null +++ b/colossalai/shardformer/policies/chatglm2.py @@ -0,0 +1,269 @@ +from functools import partial +from typing import Callable, Dict, List, Union + +import torch.nn as nn +from torch import Tensor + +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + +from ..modeling.chatglm2 import ( + get_chatglm_sequence_parallel_forward_fn, + get_flash_core_attention_forward, + get_jit_fused_glm_block_forward, +) +from ..modeling.jit import get_jit_fused_dropout_add_func +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"] + + +class ChatGLMPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # Resize embedding + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.padded_vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + if self.pipeline_stage_manager is not None: + # the batch_size_dim is bounded to Model + bsz_dim = 1 + setattr(self.model, "batch_size_dim", bsz_dim) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock + + policy = {} + + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap + if self.shard_config.enable_tensor_parallelism: + policy[ChatGLMModel] = ModulePolicyDescription( + attribute_replacement={}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embedding.word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ], + ) + + policy[GLMBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "self_attention.projection_size": ( + self.model.config.kv_channels * self.model.config.num_attention_heads + ) + // self.shard_config.tensor_parallel_size, + "self_attention.qkv_hidden_size": ( + self.model.config.kv_channels * self.model.config.num_attention_heads * 3 + ) + // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels + * self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0}, + ), + SubModuleReplacementDescription( + suffix="self_attention.core_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + if not self.model.config.rmsnorm: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm + ), + ], + policy=policy, + target_key=GLMBlock, + ) + + if self.model.config.post_layer_norm: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm + ) + ], + policy=policy, + target_key=ChatGLMModel, + ) + + else: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm + ), + ], + policy=policy, + target_key=GLMBlock, + ) + + if self.model.config.post_layer_norm: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm + ) + ], + policy=policy, + target_key=ChatGLMModel, + ) + + # use flash attention + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_flash_core_attention_forward(), + }, + policy=policy, + target_key=CoreAttention, + ) + + # use sequence parallel + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=ChatGLMModel, + ) + + # use jit fused operator + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_glm_block_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=GLMBlock, + ) + + return policy + + def postprocess(self): + return self.model + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "ChatGLMModel": + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embedding) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.encoder.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + if module.encoder.post_layer_norm: + held_layers.append(module.encoder.final_layernorm) + + # rotary_pos_emb is needed for all stages + held_layers.append(module.rotary_pos_emb) + + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "ChatGLMModel": + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + +class ChatGLMModelPolicy(ChatGLMPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + pass + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in ChatGLMModel.""" + return [] + + +class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy): + def module_policy(self): + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=ChatGLMForConditionalGeneration, + new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.transformer.output_layer) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in ChatGLMForConditionalGenerationModel.""" + return [] diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py new file mode 100644 index 0000000000000000000000000000000000000000..6f46bfc7ef9f8c36f9008a631e209b7614234314 --- /dev/null +++ b/colossalai/shardformer/policies/gpt2.py @@ -0,0 +1,414 @@ +from functools import partial +from typing import Callable, Dict, List + +from torch import Tensor, nn + +import colossalai.shardformer.layer as col_nn + +from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [ + "GPT2Policy", + "GPT2ModelPolicy", + "GPT2LMHeadModelPolicy", + "GPT2DoubleHeadsModelPolicy", + "GPT2ForTokenClassificationPolicy", + "GPT2ForSequenceClassificationPolicy", +] + + +class GPT2Policy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model + + policy = {} + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap + if self.shard_config.enable_tensor_parallelism: + policy[GPT2Model] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="drop", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) + + policy[GPT2Block] = ModulePolicyDescription( + attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + policy=policy, + target_key=GPT2Model, + ) + + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="ln_2", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True + ), + ], + policy=policy, + target_key=GPT2Block, + ) + + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_gpt2_flash_attention_forward(), + }, + policy=policy, + target_key=GPT2Attention, + ) + + if self.shard_config.enable_sequence_parallelism: + policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} + + return policy + + def postprocess(self): + return self.model + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "GPT2Model": + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.wpe) + held_layers.append(module.drop) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "GPT2Model": + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + +# GPT2Model +class GPT2ModelPolicy(GPT2Policy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2Model, new_forward=GPT2PipelineForwards.gpt2_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2Model.""" + return [] + + +# GPT2LMHeadModel +class GPT2LMHeadModelPolicy(GPT2Policy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2LMHeadModel: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) + } + module_policy.update(addon_module) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy, + ) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """The weights of wte and lm_head are shared.""" + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None: + if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [] + + +# GPT2DoubleHeadsModel +class GPT2DoubleHeadsModelPolicy(GPT2Policy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2DoubleHeadsModel: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) + } + module_policy.update(addon_module) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2DoubleHeadsModel, + new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward, + policy=module_policy, + ) + + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + multiple_choice_head = self.model.multiple_choice_head + held_layers.append(self.model.lm_head) + held_layers.append(multiple_choice_head.summary) + held_layers.append(multiple_choice_head.activation) + held_layers.append(multiple_choice_head.first_dropout) + held_layers.append(multiple_choice_head.last_dropout) + + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """The weights of wte and lm_head are shared.""" + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None: + if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [] + + +# GPT2ForQuestionAnswering +class GPT2ForQuestionAnsweringPolicy(GPT2Policy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering + + module_policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2ForQuestionAnswering, + new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward, + policy=module_policy, + ) + + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared_params in gpt2 for QA.""" + return [] + + +# GPT2ForTokenClassification +class GPT2ForTokenClassificationPolicy(GPT2Policy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2ForTokenClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput) + ] + ) + } + module_policy.update(addon_module) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2ForTokenClassification, + new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward, + policy=module_policy, + ) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2ForTokenClassification.""" + return [] + + +# GPT2ForSequenceClassification +class GPT2ForSequenceClassificationPolicy(GPT2Policy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification + + module_policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2ForSequenceClassification, + new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, + policy=module_policy, + ) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2ForTokenClassification.""" + return [] diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..099995acb44068f19c6b5cf286cf9c148f0d785f --- /dev/null +++ b/colossalai/shardformer/policies/llama.py @@ -0,0 +1,291 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D + +from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] + + +class LlamaPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + + if self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( + self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + ) + + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + ), + ], + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=LlamaModel, + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + ], + policy=policy, + target_key=LlamaDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=LlamaModel, + ) + + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_llama_flash_attention_forward(), + }, + policy=policy, + target_key=LlamaAttention, + ) + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class LlamaModelPolicy(LlamaPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + from transformers.models.llama.modeling_llama import LlamaModel + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class LlamaForCausalLMPolicy(LlamaPolicy): + def module_policy(self): + from transformers import LlamaForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] + + +class LlamaForSequenceClassificationPolicy(LlamaPolicy): + def module_policy(self): + from transformers import LlamaForSequenceClassification + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) + } + policy.update(new_item) + # to be confirmed + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForSequenceClassification, + new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama for sequence classification model""" + return [] diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py new file mode 100644 index 0000000000000000000000000000000000000000..5739d21a3903aed87c8f422998e8216ffb4d9772 --- /dev/null +++ b/colossalai/shardformer/policies/opt.py @@ -0,0 +1,308 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List + +import torch.nn as nn +from torch import Tensor, nn + +from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D + +from .._utils import getattr_ +from ..modeling.jit import get_jit_fused_dropout_add_func +from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [ + "OPTPolicy", + "OPTModelPolicy", + "OPTForCausalLMPolicy", + "OPTForSequenceClassificationPolicy", + "OPTForQuestionAnsweringPolicy", +] + + +class OPTPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer + + policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + + if self.shard_config.enable_tensor_parallelism: + policy[OPTDecoder] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ] + ) + policy[OPTDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ), + ] + ) + + policy[OPTAttention] = ModulePolicyDescription( + attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ], + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True + ), + policy=policy, + target_key=OPTDecoder, + ) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True + ), + ], + policy=policy, + target_key=OPTDecoderLayer, + ) + + # use flash attention + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_opt_flash_attention_forward(), + }, + policy=policy, + target_key=OPTAttention, + ) + + # use jit fused operator + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_opt_decoder_layer_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=OPTDecoderLayer, + ) + + return policy + + def postprocess(self): + return self.model + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "OPTModel": + module = self.model.decoder + else: + module = self.model.model.decoder + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + held_layers.append(module.embed_positions) + held_layers.append(module.project_in) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.final_layer_norm) + held_layers.append(module.project_out) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "OPTModel": + module = self.model.decoder + else: + module = self.model.model.decoder + + layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + +class OPTModelPolicy(OPTPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.opt.modeling_opt import OPTModel + + policy = super().module_policy() + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=OPTModel, new_forward=OPTPipelineForwards.opt_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in OPTModel.""" + return [] + + +class OPTForCausalLMPolicy(OPTPolicy): + def module_policy(self): + from transformers.models.opt.modeling_opt import OPTForCausalLM + + policy = super().module_policy() + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=OPTForCausalLM, + ) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=OPTForCausalLM, new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, policy=policy + ) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + opt_model = self.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + num_stages = self.pipeline_stage_manager.num_stages + if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight): + return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}] + return [] + + def postprocess(self): + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: + binding_map = { + "model.decoder.embed_tokens": "lm_head", + } + + for k, v in binding_map.items(): + src_mod = getattr_(self.model, k) + dst_mod = getattr_(self.model, v) + dst_mod.weight = src_mod.weight + + return self.model + + +class OPTForSequenceClassificationPolicy(OPTPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.opt.modeling_opt import OPTForSequenceClassification + + policy = super().module_policy() + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=OPTForSequenceClassification, + new_forward=OPTPipelineForwards.opt_for_sequence_classification_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + "no shared params in OPTForSequenceClassification" + return [] + + +class OPTForQuestionAnsweringPolicy(OPTPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.opt.modeling_opt import OPTForQuestionAnswering + + policy = super().module_policy() + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=OPTForQuestionAnswering, + new_forward=OPTPipelineForwards.opt_for_question_answering_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + "no shared params in OPTForSequenceClassification" + return [] diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..58a8500e3863fca1ebb62bdeb2c493b11afea306 --- /dev/null +++ b/colossalai/shardformer/policies/sam.py @@ -0,0 +1,233 @@ +import colossalai.shardformer.layer as col_nn + +from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["SamPolicy", "SamModelPolicy"] + + +class SamPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + return self.model + + def module_policy(self): + from transformers.models.sam.modeling_sam import ( + SamAttention, + SamTwoWayAttentionBlock, + SamTwoWayTransformer, + SamVisionAttention, + SamVisionLayer, + ) + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[SamVisionLayer] = ModulePolicyDescription( + attribute_replacement={ + "attn.num_attention_heads": self.model.config.vision_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.qkv", + target_module=col_nn.FusedLinear1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) + policy[SamTwoWayAttentionBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.out_proj", + target_module=col_nn.Linear1D_Row, + ), + ], + ) + policy[SamTwoWayTransformer] = ModulePolicyDescription( + attribute_replacement={ + "final_attn_token_to_image.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.out_proj", + target_module=col_nn.Linear1D_Row, + ), + ], + ) + + # add `DropoutForParallelInput` layer to replace the useage of `nn.functional.dropout` + policy[SamVisionAttention] = ModulePolicyDescription( + attribute_replacement={ + "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout) + }, + method_replacement={"forward": forward_fn()}, + sub_module_replacement=[], + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # Handle SamVisionLayer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=SamVisionLayer, + ) + + # Handle SamTwoWayAttentionBlock + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm3", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm4", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=SamTwoWayAttentionBlock, + ) + + # Handle SamTwoWayTransformer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm_final_attn", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=SamTwoWayTransformer, + ) + + # use flash attention + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_sam_flash_attention_forward(), + }, + policy=policy, + target_key=SamAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_sam_vision_flash_attention_forward(), + }, + policy=policy, + target_key=SamVisionAttention, + ) + + return policy + + def postprocess(self): + return self.model + + +# SamModel +class SamModelPolicy(SamPolicy): + def __init__(self) -> None: + super().__init__() diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..74cc7337e9f16eec1acef00dae21e2bd708dea68 --- /dev/null +++ b/colossalai/shardformer/policies/t5.py @@ -0,0 +1,497 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List, Tuple + +import numpy as np +from torch import Tensor, nn + +from colossalai.shardformer.layer import ( + DropoutForParallelInput, + Embedding1D, + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + VocabParallelEmbedding1D, +) +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription + +from ..modeling.jit import get_jit_fused_dropout_add_func +from ..modeling.t5 import ( + T5PipelineForwards, + get_jit_fused_T5_layer_ff_forward, + get_t5_flash_attention_forward, + get_T5_layer_cross_attention_forward, + get_T5_layer_self_attention_forward, +) +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] + + +class T5BasePolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5DenseActDense, + T5DenseGatedActDense, + T5LayerCrossAttention, + T5LayerFF, + T5LayerSelfAttention, + T5Stack, + ) + + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + + if self.shard_config.enable_tensor_parallelism: + policy[T5Stack] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + ] + ) + policy[T5LayerSelfAttention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5LayerCrossAttention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ] + ) + policy[T5Attention] = ModulePolicyDescription( + attribute_replacement={ + "d_model": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "n_heads": self.model.config.num_heads // self.shard_config.tensor_parallel_size, + "inner_dim": self.model.config.num_heads + * self.model.config.d_kv + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="o", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="relative_attention_bias", + target_module=Embedding1D, + kwargs=dict(gather_output=False), + ignore_if_not_exist=True, + ), + ], + ) + policy[T5LayerFF] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5DenseGatedActDense] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi_0 ", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wi_1", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5DenseActDense] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerSelfAttention, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerCrossAttention, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5Stack, + ) + + # use flash attention + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_t5_flash_attention_forward(), + }, + policy=policy, + target_key=T5Attention, + ) + + # use jit operator + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_T5_layer_ff_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_T5_layer_self_attention_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerSelfAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_T5_layer_cross_attention_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerCrossAttention, + ) + + return policy + + def postprocess(self): + return self.model + + @staticmethod + def distribute_t5_layers( + num_encoder_layers: int, num_decoder_layers: int, num_stages: int + ) -> Tuple[List[int], int]: + """ + Distribute t5 layers into stages when pipeline parallel is used. + Return the layer distribution as a list and the starting stage of decoder. + If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. + """ + + # number of encoder layers must be a positive integer + if num_encoder_layers <= 0: + raise ValueError("The number of encoder layers for T5 must be a positive integer.") + + # number of layers should be large enough to fill in every stage + if num_encoder_layers + num_decoder_layers < num_stages: + raise ValueError("The total number of layers can't be smaller than number of stages.") + + # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist + if num_decoder_layers == 0: + return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages + + # the number of stages distributed between encoder and decoder is optmized in this way: + # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) + # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 + def objective(num_encoder_stages): + return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages)) + + num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 + num_decoder_stages = num_stages - num_encoder_stages + + encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) + return encoder_distribution + decoder_distribution, num_encoder_stages + + @staticmethod + def get_t5_stage_index( + layers_per_stage: List[int], stage: int, decoder_starting_stage: int + ) -> Tuple[bool, int, int]: + """ + Input the distribution of layers among stages, the current stage and the first stage of decoder. + Return the starting/ending idx of layers in encoder/decoder + """ + if stage < decoder_starting_stage: + return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + else: + return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + stage_manager = self.pipeline_stage_manager + + model = self.model + encoder = self.model.encoder + decoder = getattr(self.model, "decoder", None) + + num_encoder_layers = len(encoder.block) + num_decoder_layers = len(decoder.block) if decoder else 0 + + held_layers = [] + layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + start_idx, end_idx = T5BasePolicy.get_t5_stage_index( + layers_per_stage, stage_manager.stage, decoder_starting_stage + ) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in t5's encoder + if stage_manager.is_first_stage(): + held_layers.append(model.shared) + held_layers.append(encoder.embed_tokens) + held_layers.append(encoder.dropout) + if stage_manager.stage == decoder_starting_stage - 1: + held_layers.append(encoder.final_layer_norm) + held_layers.append(encoder.dropout) + held_layers.extend(encoder.block[start_idx:end_idx]) + else: + # current stage is in t5's decoder + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.dropout) + if stage_manager.is_last_stage(): + held_layers.append(decoder.final_layer_norm) + held_layers.append(decoder.dropout) + held_layers.extend(decoder.block[start_idx:end_idx]) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + + encoder = self.model.encoder + decoder = getattr(self.model, "decoder", None) + + num_encoder_layers = len(encoder.block) + num_decoder_layers = len(decoder.block) if decoder else 0 + + layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) + + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + +class T5ModelPolicy(T5BasePolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers import T5Model + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=T5Model, + ) + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None and stage_manager.num_stages > 1: + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages + ) + + if id(module.decoder.embed_tokens.weight) == id(module.shared.weight): + return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}] + return [] + + +class T5ForConditionalGenerationPolicy(T5BasePolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers import T5ForConditionalGeneration + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ), + ], + policy=policy, + target_key=T5ForConditionalGeneration, + ) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=T5ForConditionalGeneration, + new_forward=T5PipelineForwards.t5_for_conditional_generation_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None and stage_manager.num_stages > 1: + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages + ) + + shared_params = [] + shared_embedding = {} + if id(module.decoder.embed_tokens.weight) == id(module.shared.weight): + shared_embedding[0] = module.shared.weight + shared_embedding[decoder_starting_stage] = module.decoder.embed_tokens.weight + + if id(module.lm_head.weight) == id(module.shared.weight): + shared_embedding[0] = module.shared.weight + shared_embedding[stage_manager.num_stages - 1] = module.lm_head.weight + + if len(shared_embedding) > 0: + shared_params.append(shared_embedding) + + return shared_params + + return [] + + +class T5EncoderPolicy(T5BasePolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers import T5EncoderModel + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=T5EncoderModel, + ) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=T5EncoderModel, new_forward=T5PipelineForwards.t5_encoder_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + return [] diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..270cdce9b09157c384a4e20de06f702af326dd98 --- /dev/null +++ b/colossalai/shardformer/policies/vit.py @@ -0,0 +1,257 @@ +import warnings +from typing import Callable, Dict, List, Union + +import torch.nn as nn + +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col + +from ..modeling.jit import get_jit_fused_dropout_add_func +from ..modeling.vit import ( + ViTForImageClassification_pipeline_forward, + ViTForMaskedImageModeling_pipeline_forward, + ViTModel_pipeline_forward, + get_jit_fused_vit_output_forward, + get_vit_flash_self_attention_forward, +) +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["ViTPolicy", "ViTModelPolicy", "ViTForImageClassificationPolicy", "ViTForMaskedImageModelingPolicy"] + + +class ViTPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention + + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + + if self.shard_config.enable_tensor_parallelism: + policy[ViTEmbeddings] = ModulePolicyDescription( + attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForReplicatedInput, + ) + ], + ) + + policy[ViTLayer] = ModulePolicyDescription( + attribute_replacement={ + "attention.attention.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + ) + + # use flash attention + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_vit_flash_self_attention_forward(), + }, + policy=policy, + target_key=ViTSelfAttention, + ) + + # use jit fused operator + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_vit_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=ViTOutput, + ) + return policy + + def new_model_class(self): + return None + + def postprocess(self): + return self.model + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + if self.model.__class__.__name__ == "ViTModel": + module = self.model + else: + module = self.model.vit + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict): + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "ViTModel": + module = self.model + else: + module = self.model.vit + + layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + +# ViTModel +class ViTModelPolicy(ViTPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTModel + + policy = super().module_policy() + + if self.shard_config.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(module.pooler) + + return held_layers + + +# ViTForImageClassification +class ViTForImageClassificationPolicy(ViTPolicy): + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel + + policy = super().module_policy() + if self.shard_config.enable_tensor_parallelism: + new_item = { + ViTForImageClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) + } + policy.update(new_item) + + if self.shard_config.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) + self.set_pipeline_forward( + model_cls=ViTForImageClassification, + pipeline_forward=ViTForImageClassification_pipeline_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + module = self.model.vit + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.classifier) + + return held_layers + + +# ViTForMaskedImageModeling +class ViTForMaskedImageModelingPolicy(ViTPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel + + policy = super().module_policy() + + if self.shard_config.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) + self.set_pipeline_forward( + model_cls=ViTForMaskedImageModeling, + pipeline_forward=ViTForMaskedImageModeling_pipeline_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + module = self.model.vit + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.decoder) + + return held_layers diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..d9af2461cdb84c52568168a25ac2bb7be27abd67 --- /dev/null +++ b/colossalai/shardformer/policies/whisper.py @@ -0,0 +1,532 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List, Tuple + +import numpy as np +import torch.nn as nn +from torch import Tensor + +import colossalai.shardformer.layer as col_nn + +from ..modeling.jit import get_jit_fused_dropout_add_func +from ..modeling.whisper import ( + WhisperPipelineForwards, + get_jit_fused_whisper_decoder_layer_forward, + get_jit_fused_whisper_encoder_layer_forward, + get_whisper_flash_attention_forward, +) +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [ + "WhisperPolicy", + "WhisperModelPolicy", + "WhisperForConditionalGenerationPolicy", + "WhisperForAudioClassificationPolicy", +] + + +class WhisperPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.whisper.modeling_whisper import ( + WhisperAttention, + WhisperDecoder, + WhisperDecoderLayer, + WhisperEncoder, + WhisperEncoderLayer, + ) + + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn( + "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + ) + + # TODO using the jit fused add_and_dropout affect the accuracy + if self.shard_config.enable_jit_fused: + self.shard_config.enable_jit_fused = False + warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.") + + if self.shard_config.enable_tensor_parallelism: + policy[WhisperEncoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.encoder_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) + + policy[WhisperDecoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.decoder_attention_heads + // self.shard_config.tensor_parallel_size, + "encoder_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "encoder_attn.num_heads": self.model.config.encoder_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) + + policy[WhisperDecoder] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ] + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # Handle encoder layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=WhisperEncoderLayer, + ) + + # Handle decoder layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=WhisperDecoderLayer, + ) + + # handle encoder layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperEncoder, + ) + + # handle decoder layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperDecoder, + ) + + # enable flash attention + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_whisper_flash_attention_forward(), + }, + policy=policy, + target_key=WhisperAttention, + ) + + # use jit fused operator + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_whisper_decoder_layer_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=WhisperDecoderLayer, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_whisper_encoder_layer_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=WhisperEncoderLayer, + ) + + return policy + + def add_lm_head_policy(self, base_policy): + from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration + + # optimize for tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ), + policy=base_policy, + target_key=WhisperForConditionalGeneration, + ) + + return base_policy + + def postprocess(self): + return self.model + + @staticmethod + def distribute_whisper_layers( + num_encoder_layers: int, num_decoder_layers: int, num_stages: int + ) -> Tuple[List[int], int]: + """ + Distribute whisper layers into stages when pipeline parallel is used. + Return the layer distribution as a list and the starting stage of decoder. + If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. + """ + + # number of encoder layers must be a positive integer + if num_encoder_layers <= 0: + raise ValueError("The number of encoder layers for whisper must be a positive integer.") + + # number of layers should be large enough to fill in every stage + if num_encoder_layers + num_decoder_layers < num_stages: + raise ValueError("The total number of layers can't be smaller than number of stages.") + + # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist + if num_decoder_layers == 0: + return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages + + # the number of stages distributed between encoder and decoder is optmized in this way: + # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) + # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 + def objective(num_encoder_stages): + return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages)) + + num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 + num_decoder_stages = num_stages - num_encoder_stages + + encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) + return encoder_distribution + decoder_distribution, num_encoder_stages + + @staticmethod + def get_whisper_stage_index( + layers_per_stage: List[int], stage: int, decoder_starting_stage: int + ) -> Tuple[bool, int, int]: + """ + Input the distribution of layers among stages, the current stage and the first stage of decoder. + Return the starting/ending idx of layers in encoder/decoder + """ + if stage < decoder_starting_stage: + return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + else: + return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + + def get_held_layers(self) -> List[nn.Module]: + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + stage_manager = self.pipeline_stage_manager + + if self.model.__class__.__name__ == "WhisperModel": + model = self.model + elif self.model.__class__.__name__ == "WhisperForConditionalGeneration": + model = self.model.model + else: + model = None + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + # whisper for audio classification holds encoder only + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + held_layers = [] + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + start_idx, end_idx = WhisperPolicy.get_whisper_stage_index( + layers_per_stage, stage_manager.stage, decoder_starting_stage + ) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in whisper's encoder + if stage_manager.is_first_stage(): + held_layers.append(encoder.embed_positions) + held_layers.append(encoder.conv1) + held_layers.append(encoder.conv2) + if stage_manager.stage == decoder_starting_stage - 1: + held_layers.append(encoder.layer_norm) + held_layers.extend(encoder.layers[start_idx:end_idx]) + else: + # current stage is in whisper's decoder + # TODO:(Jianghai) We divide encoder and decoder layers into different parts here, + # the case encoder and decoder put in same stage should be add in the future. + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.embed_positions) + if stage_manager.is_last_stage(): + held_layers.append(decoder.layer_norm) + held_layers.extend(decoder.layers[start_idx:end_idx]) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + + if self.model.__class__.__name__ == "WhisperModel": + model = self.model + elif self.model.__class__.__name__ == "WhisperForConditionalGeneration": + model = self.model.model + else: + model = None + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + stage_index = WhisperPolicy.get_whisper_stage_index( + layers_per_stage, stage_manager.stage, decoder_starting_stage + ) + + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + +# WhisperModel +class WhisperModelPolicy(WhisperPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers import WhisperModel + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=WhisperModel, new_forward=WhisperPipelineForwards.whisper_model_forward, policy=policy + ) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + "no shared params in whisper model" + return [] + + +# WhisperForConditionalGeneration +class WhisperForConditionalGenerationPolicy(WhisperPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers import WhisperForConditionalGeneration + + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=WhisperForConditionalGeneration, + new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward, + policy=policy, + ) + return policy + + def postprocess(self): + return self.model + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.proj_out) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + module = self.model + model = module.model + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + stage_manager = self.pipeline_stage_manager + if stage_manager is not None and stage_manager.num_stages > 1: + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + shared_params = [] + shared_embedding = {} + if id(module.proj_out) == id(model.decoder.embed_tokens): + shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens + shared_embedding[stage_manager.num_stages - 1] = module.proj_out + if len(shared_embedding) > 0: + shared_params.append(shared_embedding) + return shared_params + return [] + + +# WhisperForAudioClassification +class WhisperForAudioClassificationPolicy(WhisperPolicy): + def __init__(self) -> None: + super().__init__() + + def preprocess(self): + return self.model + + def module_policy(self): + from transformers import WhisperForAudioClassification + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=WhisperForAudioClassification, + new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.projector) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + return [] diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..acf8a95a41ca093f312e6322637d5514c675b13d --- /dev/null +++ b/colossalai/shardformer/shard/__init__.py @@ -0,0 +1,5 @@ +from .shard_config import ShardConfig +from .sharder import ModelSharder +from .shardformer import ShardFormer + +__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py new file mode 100644 index 0000000000000000000000000000000000000000..a285874d218b3c6bd67427700beebec489d490fe --- /dev/null +++ b/colossalai/shardformer/shard/shard_config.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass +from typing import Optional + +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.pipeline.stage_manager import PipelineStageManager + +__all__ = ["ShardConfig"] + + +@dataclass +class ShardConfig: + r""" + The config for sharding the huggingface model + + Args: + tensor_parallel_process_group (Optional[ProcessGroup]): The process group of tensor parallelism, it's necessary when using tensor parallel. Defaults to None, which is the global process group. + pipeline_stage_manager (Optional[PipelineStageManager]): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism. + enable_tensor_parallelism (bool): Whether to use tensor parallelism. Defaults to True. + enable_fused_normalization (bool): Whether to use fused layernorm. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. + enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalizaion', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. + inference_only (bool): Whether only doing forward passing. Defaults to False. + """ + tensor_parallel_process_group: Optional[ProcessGroup] = None + pipeline_stage_manager: Optional[PipelineStageManager] = None + enable_tensor_parallelism: bool = True + enable_fused_normalization: bool = False + enable_flash_attention: bool = False + enable_jit_fused: bool = False + enable_all_optimization: bool = False + inference_only: bool = False + inference_gptq: bool = False + enable_sequence_parallelism: bool = False + enable_sequence_overlap: bool = False + # pipeline_parallel_size: int + # data_parallel_size: int + # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] + + @property + def tensor_parallel_size(self): + return self._tensor_parallel_size + + def __post_init__(self): + if not self.enable_tensor_parallelism and self.enable_sequence_parallelism: + raise ValueError( + "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True" + ) + if not self.enable_sequence_parallelism and self.enable_sequence_overlap: + raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True") + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + # turn on all optimization if all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() + + def _turn_on_all_optimization(self): + """ + Turn on all optimization. + """ + # you can add all the optimization flag here + self.enable_fused_normalization = True + self.enable_flash_attention = True + self.enable_jit_fused = True + self.enable_sequence_parallelism = True + self.enable_sequence_overlap = True + + def _infer(self): + """ + Set default params for inference. + """ + assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now" diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py new file mode 100644 index 0000000000000000000000000000000000000000..1bed850c65814362436c97a28c18eecff0bb3532 --- /dev/null +++ b/colossalai/shardformer/shard/sharder.py @@ -0,0 +1,237 @@ +from types import MethodType +from typing import Any, Callable, Dict, List, Optional, Set, Union + +import torch.nn as nn +from torch import Tensor + +from colossalai.lazy import LazyInitContext + +from .._utils import getattr_, setattr_ +from ..policies.auto_policy import get_autopolicy +from ..policies.base_policy import Policy, SubModuleReplacementDescription +from .shard_config import ShardConfig +from .utils import set_tensors_to_none + +__all__ = ["ModelSharder", "shard_model"] + + +class ModelSharder(object): + r""" + Shard the original huggingface model according to the policy + + Args: + policy (:class:`Policy`): The policy to shard the model + model (:class:`torch.Module`): The model to shard + shard_config: The setting of distributed model + """ + + def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None: + self.model = model + self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy + self.shard_config = shard_config + + def shard(self) -> List[Dict[int, Tensor]]: + r""" + Shard the model according to the policy + """ + self.policy.set_model(self.model) + self.policy.set_shard_config(self.shard_config) + self._preprocess() + # get shared params before release unheld layers, this avoid misjudgement of shared params (None is None) + shared_params = self.policy.get_shared_params() + held_layers = self._release_unheld_layers() + self._replace_module(include=held_layers) + self._materialize() + self._postprocess() + return shared_params + + def _preprocess(self) -> None: + self.model = self.policy.preprocess() + + def _postprocess(self) -> None: + self.model = self.policy.postprocess() + + def _replace_module(self, include: Optional[Set[nn.Module]] = None) -> None: + r""" + Replace the module according to the policy, and replace the module one by one + + Args: + model (:class:`torch.nn.Module`): The model to shard + """ + module_descriptions = self.policy.module_policy() + for layer_cls, module_description in module_descriptions.items(): + attr_replacement = module_description.attribute_replacement + param_replacement = module_description.param_replacement + sub_module_replacement = module_description.sub_module_replacement + method_replacement = module_description.method_replacement + self._recursive_replace_layer( + self.model, + layer_cls, + attr_replacement, + param_replacement, + method_replacement, + sub_module_replacement, + include=include, + ) + + def _recursive_replace_layer( + self, + module: nn.Module, + origin_cls: Union[str, nn.Module], + attr_replacement: Dict[str, Any], + param_replacement: List[Callable], + method_replacement: Dict[str, Callable], + sub_module_replacement: List[SubModuleReplacementDescription], + include: Optional[Set[nn.Module]] = None, + ) -> None: + r""" + Reverse the replace layer operation + + Args: + module (torch.nn.Module): The object of layer to shard + origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name + attr_replacement (Dict[str, Any]): The attribute dict to modify + param_replacement (List[Callable]): The function list to get parameter shard information in policy + method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement + sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy + include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None + """ + if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or ( + module.__class__ == origin_cls + ): + if attr_replacement is not None: + self._replace_attr(module, attr_replacement) + + if param_replacement is not None and (include is None or module in include): + self._replace_param(module, param_replacement) + + if method_replacement is not None: + self._replace_method(module, method_replacement) + + if sub_module_replacement is not None: + self._replace_sub_module(module, sub_module_replacement, include) + + for name, child in module.named_children(): + self._recursive_replace_layer( + child, + origin_cls, + attr_replacement, + param_replacement, + method_replacement, + sub_module_replacement, + include=include, + ) + + def _replace_attr( + self, + module: nn.Module, + attr_replacement: Dict[str, Any], + ) -> None: + r""" + Replace the attribute of the layer + + Args: + module (:class:`torch.nn.Module`): The object of layer to shard + attr_replacement (Dict): The attribute dict to modify + """ + for k, v in attr_replacement.items(): + setattr_(module, k, v, ignore=True) + + def _replace_param( + self, + module: nn.Module, + param_replacement: List[Callable], + ) -> None: + r""" + Replace the parameter of the layer + + Args: + module (:class:`torch.nn.Module`): The object of layer to shard + param_replacement (List[Callable]): The function list to get parameter shard information in policy + """ + for param_func in param_replacement: + param_func(module) + + def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]): + for method_name, new_method in method_replacement.items(): + # bind the new method to the module + bound_method = MethodType(new_method, module) + setattr(module, method_name, bound_method) + + def _replace_sub_module( + self, + org_layer: nn.Module, + sub_module_replacement: List[SubModuleReplacementDescription], + include: Optional[Set[nn.Module]] = None, + ) -> None: + r""" + Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict + + Args: + org_layer (torch.nn.Module): The origin layer object to shard + sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list + include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None + """ + for description in sub_module_replacement: + suffix = description.suffix + target_module = description.target_module + kwargs = {} if description.kwargs is None else description.kwargs + + assert target_module is not None, "target_module should not be None" + + native_sub_module = getattr_(org_layer, suffix, ignore=True) + + # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled. + if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include): + continue + + assert not isinstance( + native_sub_module, target_module + ), f"The module with suffix {suffix} has been replaced, please check the policy" + + # if it is None and we are allowed to ignore this module + # just skip + if description.ignore_if_not_exist and native_sub_module is None: + continue + + try: + replace_layer = target_module.from_native_module( + native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs + ) + except Exception as e: + raise RuntimeError( + f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}" + f" with {target_module.__qualname__} with the exception: {e}. " + "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well." + ) + + setattr_(org_layer, suffix, replace_layer) + + def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]: + def collect_sub_modules(module: nn.Module): + if module is None: + return + recursive_held_layers.append(module) + for name, child in module.named_children(): + collect_sub_modules(child) + + recursive_held_layers = [] + for module in held_layers: + collect_sub_modules(module) + return recursive_held_layers + + def _release_unheld_layers(self) -> Optional[Set[nn.Module]]: + r""" + Release the unheld layers in the model + """ + if self.shard_config and self.shard_config.pipeline_stage_manager: + held_layers = self.policy.get_held_layers() + set_tensors_to_none(self.model, exclude=set(held_layers)) + return set(self._get_recursive_held_layers(held_layers)) + return None + + def _materialize(self) -> None: + r""" + Materialize the model if lazy initialization is used + """ + LazyInitContext.materialize(self.model) diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7a0d75bf2f2ac47eb29d38ffb7afe46add7689fc --- /dev/null +++ b/colossalai/shardformer/shard/shardformer.py @@ -0,0 +1,51 @@ +from typing import Dict, List, Tuple + +import torch.nn as nn +from torch import Tensor + +from colossalai.cluster import DistCoordinator + +from ..policies.base_policy import Policy +from .shard_config import ShardConfig +from .sharder import ModelSharder + + +class ShardFormer: + """ + Parallelize model based on the given config and policy + + Example: + + ```python + from colossalai.shardformer import ShardFormer, ShardConfig + from transformers import BertForMaskedLM + import colossalai + import torch + + colossalai.launch_from_torch(config={}) + + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') + shard_config = ShardConfig() + shard_former = ShardFormer(shard_config=shard_config) + model, shared_params = shard_former.optimize(org_model) + ``` + """ + + def __init__(self, shard_config: ShardConfig): + self.coordinator = DistCoordinator() + self.shard_config = shard_config + + def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]: + r""" + This method will optimize the model based on the given policy. + + Args: + model (`torch.nn.Model`): the origin huggingface model + shard_config (`ShardConfig`): the config for distribute information + policy (`Policy`): the custom policy for sharding + + Returns: the sharded model and the shared parameters + """ + sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy) + shared_params = sharder.shard() + return model, shared_params diff --git a/colossalai/shardformer/shard/utils.py b/colossalai/shardformer/shard/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2bac37bfedda5bfd7342308ef097746d9e897e81 --- /dev/null +++ b/colossalai/shardformer/shard/utils.py @@ -0,0 +1,19 @@ +from typing import Set + +import torch.nn as nn + + +def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None: + """Set all parameters and buffers of model to None + + Args: + model (nn.Module): The model to set + """ + if model in exclude: + return + for child in model.children(): + set_tensors_to_none(child, exclude=exclude) + for n, p in model.named_parameters(recurse=False): + setattr(model, n, None) + for n, buf in model.named_buffers(recurse=False): + setattr(model, n, None) diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index b2da64e6c33a0f410dd5b7a9e05ff5775cc0a6eb..9ed149f33f2ff21210468fc21f07233c86297c99 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -1,18 +1,18 @@ -from . import distspec from .colo_parameter import ColoParameter from .colo_tensor import ColoTensor from .comm_spec import CollectiveCommPattern, CommSpec -from .compute_spec import ComputePattern, ComputeSpec -from .dist_spec_mgr import DistSpecManager -from .distspec import ReplicaSpec, ShardSpec from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager -from .process_group import ProcessGroup -from .tensor_spec import ColoTensorSpec from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor __all__ = [ - 'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter', - 'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', - 'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', - 'merge_same_dim_mesh_list' + "ColoTensor", + "convert_parameter", + "named_params_with_colotensor", + "ColoParameter", + "ColoParamOpHook", + "ColoParamOpHookManager", + "CommSpec", + "CollectiveCommPattern", + "convert_dim_partition_dict", + "merge_same_dim_mesh_list", ] diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py index b384579feb35ef643fb3206a165ad4e3e0c02a0a..5712505ae2ff66031d1638862b9799790986e116 100644 --- a/colossalai/tensor/colo_parameter.py +++ b/colossalai/tensor/colo_parameter.py @@ -3,9 +3,15 @@ from typing import Optional import torch from colossalai.tensor.colo_tensor import ColoTensor -from colossalai.tensor.const import TensorType from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.tensor.tensor_spec import ColoTensorSpec + +from .colo_tensor import _convert_output + +WHITE_LIST_FUNCS = {torch.Tensor.__getitem__} + + +def is_no_hook_op(func) -> bool: + return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS def filter_colo_parameters(*args, **kwargs): @@ -30,64 +36,34 @@ def filter_colo_parameters(*args, **kwargs): def replace_args(args, kwargs, new_args): - args = new_args[:len(args)] - for k, v in zip(kwargs.keys(), new_args[len(args):]): + args = new_args[: len(args)] + for k, v in zip(kwargs.keys(), new_args[len(args) :]): kwargs[k] = v return tuple(args), kwargs class ColoParameter(ColoTensor, torch.nn.Parameter): - r"""A kind of ColoTensor to be considered as a module parameter. - - """ + r"""A kind of ColoTensor to be considered as a module parameter.""" - def __new__(cls, - data: Optional[torch.Tensor] = None, - requires_grad: bool = True, - spec: ColoTensorSpec = None) -> 'ColoParameter': + def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> "ColoParameter": if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) - def __init__(self, - data: Optional[torch.Tensor] = None, - requires_grad: bool = True, - spec: ColoTensorSpec = None) -> None: - ColoTensor.__init__(self, data, spec) - self._type = TensorType.MODEL - # a list contains modules sharing this ColoParameter with others. - self._shared_param_modules = [] - - @property - def shared_param_modules(self): - return self._shared_param_modules - - @staticmethod - def from_torch_tensor(tensor: torch.Tensor, - requires_grad: bool = True, - spec: ColoTensorSpec = None) -> 'ColoParameter': - tensor = tensor.as_subclass(ColoParameter) - tensor.__init__(tensor, requires_grad=requires_grad, spec=spec) - return tensor - - def __repr__(self): - return super(ColoParameter, self).__repr__() - @classmethod def __torch_function__(cls, func, types, args=..., kwargs=None): - if ColoParamOpHookManager.has_hook(): - if not func.__name__.startswith('__'): - if kwargs is None: - kwargs = {} - params = filter_colo_parameters(*args, **kwargs) - if len(params) > 0: - with torch._C.DisableTorchFunction(): - new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values()) - args, kwargs = replace_args(args, kwargs, new_args) - ret = super().__torch_function__(func, types, args, kwargs) - with torch._C.DisableTorchFunction(): - ret = ColoParamOpHookManager.post_op(params, ret) - return ret + if kwargs is None: + kwargs = {} + if ColoParamOpHookManager.has_hook() and not is_no_hook_op(func): + params = filter_colo_parameters(*args, **kwargs) + if len(params) > 0: + with torch._C.DisableTorchFunction(): + new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values()) + args, kwargs = replace_args(args, kwargs, new_args) + ret = super().__torch_function__(func, types, args, kwargs) + with torch._C.DisableTorchFunction(): + ret = ColoParamOpHookManager.post_op(params, ret) + return _convert_output(ret, func) return super().__torch_function__(func, types, args, kwargs) def __deepcopy__(self, memo): @@ -96,9 +72,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter): else: with torch._C.DisableTorchFunction(): data = self.data.clone() - tensor = ColoParameter(data, - self.requires_grad, - spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec)) + tensor = ColoParameter(data, self.requires_grad) memo[id(self)] = tensor return tensor diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 4d762076461d085cd51aefd1b42d9c77491fa032..c2de9abce371e81e3dc92fd2f343fd4da29cc0db 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -1,17 +1,14 @@ -import operator -from copy import copy -from functools import lru_cache, reduce -from typing import Callable, Optional, Set +from functools import lru_cache +from typing import Callable, Set import torch -from colossalai.tensor.dist_spec_mgr import DistSpecManager -from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec -from colossalai.tensor.process_group import ProcessGroup -from colossalai.tensor.tensor_spec import ColoTensorSpec - -from .const import TensorType -from .op_wrapper import _COLOSSAL_OPS +INPALCE_MAPPING = { + torch.Tensor.add_: torch.Tensor.add, + torch.Tensor.sub_: torch.Tensor.sub, + torch.Tensor.mul_: torch.Tensor.mul, + torch.Tensor.div_: torch.Tensor.div, +} @lru_cache(None) @@ -21,65 +18,42 @@ def _get_my_nowrap_functions() -> Set[Callable]: Tensor._base.__get__, Tensor.grad.__get__, Tensor._grad.__get__, - Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor + Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor } -def _convert_output(output, colo_spec: ColoTensorSpec): - if type(output) == torch.Tensor: - return ColoTensor.from_torch_tensor(output, colo_spec) +def _convert(output): + if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor): + output.__class__ = ColoTensor elif isinstance(output, (list, tuple)): - return type(output)(_convert_output(o, colo_spec) for o in output) - else: - return output + output = type(output)(_convert(o) for o in output) + return output -def _get_spec_from_args(args, kwargs) -> ColoTensorSpec: - for elem in args: - if isinstance(elem, ColoTensor): - pg = elem.get_process_group() - dp = elem.dist_spec - return ColoTensorSpec(pg, dp) - elif isinstance(elem, (list, tuple)): - spec = _get_spec_from_args(elem, {}) - if spec is not None: - return spec - for k, v in kwargs.items(): - if isinstance(v, ColoTensor): - pg = v.get_process_group() - dp = v.dist_spec - return ColoTensorSpec(pg, dp) - return None +def _convert_output(output, func): + if func in _get_my_nowrap_functions(): + return output + return _convert(output) class ColoTensor(torch.Tensor): - """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor. + """Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor. - The Colotensor can be initialized with a PyTorch tensor in the following ways. - - >>> pg = ProcessGroup() - >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())) - >>> # The tensor passed in is a tensor after sharding but not a global tensor. - >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size), - >>> dims=[0], - >>> num_partitions=[world_size]) - >>> tensor_spec = ColoTensorSpec(pg, shard_spec) - >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) + It is only used to trigger the torch function hook. Args: data (torch.Tensor): a torch tensor used as the payload the colotensor. - spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()). """ - torch_major = int(torch.__version__.split('.')[0]) - torch_minor = int(torch.__version__.split('.')[1]) - def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor': + torch_major = int(torch.__version__.split(".")[0]) + torch_minor = int(torch.__version__.split(".")[1]) + + def __new__(cls, data: torch.Tensor) -> "ColoTensor": """ The signature of the __new__ has to be consistent with the torch.Tensor. Args: data (torch.Tensor): a torch tensor used as the payload the colotensor. - spec (TensorSpec, optional): the tensor spec of initialization. Returns: ColoTensor: a ColoTensor wrappers the data. @@ -88,86 +62,6 @@ class ColoTensor(torch.Tensor): data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, data.requires_grad) - def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None: - # If not set spec, use a DP process group and replicate dist spec - if spec is None: - self.has_initialized = False - self.dist_spec = ReplicaSpec() - self.compute_spec = None - self.process_group = ProcessGroup() - else: - self.has_initialized = True - self.dist_spec = spec.dist_attr - self.compute_spec = spec.compute_attr - if spec.pg is None: - self.process_group = ProcessGroup() - else: - self.process_group = spec.pg - - self._type = TensorType.NONMODEL - - def has_compute_spec(self) -> bool: - return self.compute_spec is not None - - def is_model_data(self) -> bool: - return self._type == TensorType.MODEL - - def get_process_group(self) -> 'ProcessGroup': - return self.process_group - - def set_process_group(self, pg: ProcessGroup): - """set_process_group - change the pg of the ColoTensor. Note that the valid use cases is limited. - It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica. - - Args: - pg (ProcessGroup): target pg - - """ - assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid" - # if the new pg is the same as the old pg, just returns - if self.process_group == pg: - return - assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \ - "Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1" - assert self.dist_spec.placement.value == 'r', \ - "Can not set_process_group on a ColoTensor whose dist spec is not Replica" - - self.process_group = pg - - def get_tp_world_size(self) -> int: - return self.process_group.tp_world_size() - - def get_dp_world_size(self) -> int: - """get_dp_world_size - get the dp world size of the tensor. - - Returns: - int: dp world size - """ - return self.process_group.dp_world_size() - - def set_dist_spec(self, dist_spec: _DistSpec): - """set_dist_spec - set dist spec and change the payloads. - - Args: - dist_spec (_DistSpec): target dist spec. - """ - assert isinstance(dist_spec, _DistSpec) - assert self.process_group is not None - self._redistribute(dist_spec) - - def set_tensor_spec(self, dist_spec, compute_spec): - if dist_spec is not None: - assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}" - self.set_dist_spec(dist_spec) - if compute_spec is not None: - self.compute_spec = compute_spec - - def has_compute_pattern(self, compute_pattern): - return self.compute_spec.compute_pattern == compute_pattern - @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: @@ -175,108 +69,27 @@ class ColoTensor(torch.Tensor): if not all(issubclass(cls, t) for t in types): return NotImplemented - global _COLOSSAL_OPS - if func in _COLOSSAL_OPS: - func = _COLOSSAL_OPS[func] if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12): # in order to trigger pre-op hook in the forward of checkpoint module # we have to capture the `backward` function # and make sure that it does not in `torch._C.DisableTorchFunction()` context if func is torch.Tensor.backward: - assert len(args) == 1 # only has 1 parameter + assert len(args) == 1 # only has 1 parameter backward_tensor = torch.Tensor(args[0]) tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()} return backward_tensor.backward(**tensor_kwargs) + # replace the in-place function + if func in INPALCE_MAPPING: + func = INPALCE_MAPPING[func] + # set the 'inplace' kwargs to False + if "inplace" in kwargs: + kwargs["inplace"] = False + with torch._C.DisableTorchFunction(): ret = func(*args, **kwargs) - if func in _get_my_nowrap_functions(): - return ret - else: - colo_spec = _get_spec_from_args(args, kwargs) - return _convert_output(ret, colo_spec) - - def __repr__(self): - output_list = [super(ColoTensor, self).__repr__()] - output_list.append(str(self.process_group)) - output_list.append(str(self.dist_spec)) - if self.compute_spec is not None: - output_list.append(str(self.compute_spec)) - return "\n".join(output_list) - - def _redistribute(self, dist_spec: _DistSpec) -> None: - """_redistribute - Note the function will not handle the logic of backward propagation! - It is used during model tensor initializations as an internal function. - - Args: - dist_spec (_DistSpec): the target dist. spec. - """ - assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted" - with DistSpecManager.no_grad(): - self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group) - self.dist_spec = dist_spec - - def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor': - """redistribute - Redistribute the tensor among processes. The rule is like this: - - 1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the - DP process group not changed. - - 2. If the pg is not not None and not equal to the current process group. - First, convert the tensor as replicated among the TP process group. - Second, reset the process group to the new pg. - Third, convert the tensor (new replicated both among the tp process group) to the new dist_spec. - - Args: - dist_spec (_DistSpec): the new dist spec. - pg (Optional[ProcessGroup], optional): the new process group . Defaults to None. - - Returns: - ColoTensor: a redistributed colotensor - """ - if pg is not None and pg != self.get_process_group(): - # if the pg is not equal, convert the current tensor to replicated - handled = self.redistribute(ReplicaSpec()) - else: - handled = self - pg = self.process_group - - ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg) - return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec)) - - def to_replicate_(self): - """to_replicate_ - - an inline member function, converting dist spec of the tensor to REPLICATE - """ - self._redistribute(dist_spec=ReplicaSpec()) - - def to_replicate(self) -> 'ColoTensor': - """to_replicate - - converting dist spec of the tensor to ReplicaSpec() - """ - return self.redistribute(ReplicaSpec()) - - @staticmethod - def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor': - """from_torch_tensor - - A static method builds a `ColoTensor` from a PyTorch Tensor. - - Args: - tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank not a global tensor. - spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None. - - Returns: - ColoTensor: a ColoTensor - """ - tensor = tensor.as_subclass(ColoTensor) - tensor.__init__(tensor, spec=spec) - return tensor + return _convert_output(ret, func) def __deepcopy__(self, memo): if id(self) in memo: @@ -284,60 +97,6 @@ class ColoTensor(torch.Tensor): else: with torch._C.DisableTorchFunction(): data = self.data.clone() - tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec))) + tensor = ColoTensor(data) memo[id(self)] = tensor return tensor - - # override builtin functions which must use tensor in replicate placement # - - def size_local(self, *args) -> torch.Size: - with torch._C.DisableTorchFunction(): - return super().size(*args) - - def size_global(self, *args) -> torch.Size: - """size_global - - override the torch building size() - the shape passed in must be in a replicate placement. - - Returns: - torch.Size: the global tensor shape - """ - if self.is_replicate(): - return self.size_local(*args) - spec = self.dist_spec - dims = spec.dims - num_partitions = spec.num_partitions - # import inspect - # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()]) - size_list = list(self.size_local()) - for dim, num_partition in zip(dims, num_partitions): - size_list[dim] *= num_partition - if args == (): - return torch.Size(size_list) - else: - return size_list[args[0]] - - def numel_global(self): - """Returns the number of elements in the tensor when it's replicated. - """ - return reduce(operator.mul, self.size_global(), 1) - - # Some API for dist spec check - - def is_replicate(self): - return self.dist_spec.placement == DistPlacementPattern.REPLICATE \ - or (len(self.dist_spec.num_partitions) == 1 - and self.dist_spec.num_partitions[0] == 1) \ - or (self.process_group.tp_world_size() == 1) - - def is_shard_1dcol(self): - return self.dist_spec.placement == DistPlacementPattern.SHARD \ - and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1 - - def is_shard_1drow(self): - return self.dist_spec.placement == DistPlacementPattern.SHARD \ - and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0 - - def is_sharded(self): - return self.dist_spec.placement == DistPlacementPattern.SHARD diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index af38d2a502c25ef76a6fb99650176781e09bda83..de0cba26b52a827f4271edb45d2912d35a123f77 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -7,82 +7,79 @@ import torch.distributed as dist from torch.distributed import ReduceOp __all__ = [ - 'CollectiveCommPattern', - 'CommSpec', + "CollectiveCommPattern", + "CommSpec", ] def _all_gather(tensor, comm_spec): - ''' + """ Implement all gather operation on device mesh based on information provided by comm_spec. - ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) - for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) - ] - # without this contiguous operation, the all gather may get some unexpected results. - tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + """ + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) + for _ in range(comm_spec.device_mesh.shape[comm_spec.logical_process_axis]) + ] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor, comm_spec): - ''' + """ Implement shard operation on device mesh based on information provided by comm_spec. - ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + """ + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor, comm_spec): - ''' + """ Implement all to all operation on device mesh based on information provided by comm_spec. - ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + """ + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor, comm_spec, async_op=False): - ''' + """ Implement all reduce operation on device mesh based on information provided by comm_spec. - ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + """ + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor def _mix_gather(tensor, comm_spec): - ''' + """ Implement mix gather operation on device mesh based on information provided by comm_spec. Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather @@ -127,8 +124,8 @@ def _mix_gather(tensor, comm_spec): leading_group_dim = 1 process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] - ''' - total_slices = comm_spec.device_mesh.mesh_shape[0] + """ + total_slices = comm_spec.device_mesh.shape[0] tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)] leading_group_dim = comm_spec.logical_process_axes[0] assert len(comm_spec.device_mesh.process_groups_dict) == 1 @@ -149,7 +146,7 @@ def _mix_gather(tensor, comm_spec): if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous() else: - mesh_shape = comm_spec.device_meshes.mesh_shape + mesh_shape = comm_spec.device_meshes.shape cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] tmp_tensor_shape = list(tensor.shape) tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0] @@ -158,15 +155,16 @@ def _mix_gather(tensor, comm_spec): torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1]) ] for i in range(cat_slice[1]): - tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]), - comm_spec.gather_dim[0]).contiguous() + tmp_tensor_list[i] = torch.cat( + tuple(tensor_list[i * cat_slice[0] : (i + 1) * cat_slice[0]]), comm_spec.gather_dim[0] + ).contiguous() output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous() return output def _mix_split(tensor, comm_spec): - ''' + """ Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent) Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension. @@ -180,10 +178,10 @@ def _mix_split(tensor, comm_spec): # [[0, 1, 2, 3], # [4, 5, 6, 7]] # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} - ''' - mesh_shape = comm_spec.device_meshes.mesh_shape + """ + mesh_shape = comm_spec.device_meshes.shape dim = comm_spec.gather_dim - total_slices = comm_spec.device_mesh.mesh_shape[0] + total_slices = comm_spec.device_mesh.shape[0] # Get global rank rank = dist.get_rank() @@ -319,11 +317,13 @@ class _AllToAll(torch.autograd.Function): @staticmethod def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) - comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - sharding_spec=comm_spec.sharding_spec, - gather_dim=comm_spec.shard_dim, - shard_dim=comm_spec.gather_dim, - logical_process_axis=comm_spec.logical_process_axis) + comm_spec_for_backward = CommSpec( + comm_pattern=comm_spec.comm_pattern, + sharding_spec=comm_spec.sharding_spec, + gather_dim=comm_spec.shard_dim, + shard_dim=comm_spec.gather_dim, + logical_process_axis=comm_spec.logical_process_axis, + ) ctx.comm_spec = comm_spec_for_backward return output @@ -333,7 +333,6 @@ class _AllToAll(torch.autograd.Function): class _MixGatherForwardMixSplitBackward(torch.autograd.Function): - @staticmethod def symbolic(graph, input_): return _mix_gather(input_) @@ -373,16 +372,16 @@ def mixgather_forward_split_backward(input_, comm_spec): class CollectiveCommPattern(Enum): - GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' - ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' - SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' - ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' - IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd" + ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd" + SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd" + ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd" + IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd" MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd" class CommSpec: - ''' + """ Communication spec is used to record the communication action. It has two main functions: 1. Compute the communication cost which will be used in auto parallel solver. 2. Convert the communication spec to real action which will be used in runtime. @@ -396,16 +395,18 @@ class CommSpec: gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. - ''' - - def __init__(self, - comm_pattern, - sharding_spec, - gather_dim=None, - shard_dim=None, - logical_process_axis=None, - forward_only=False, - mix_gather=False): + """ + + def __init__( + self, + comm_pattern, + sharding_spec, + gather_dim=None, + shard_dim=None, + logical_process_axis=None, + forward_only=False, + mix_gather=False, + ): self.comm_pattern = comm_pattern self.sharding_spec = sharding_spec self.gather_dim = gather_dim @@ -414,7 +415,7 @@ class CommSpec: self.forward_only = forward_only if isinstance(self.logical_process_axis, list): if not mix_gather: - self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.device_mesh = self.sharding_spec.device_mesh.flatten() self.logical_process_axis = 0 else: self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes @@ -452,14 +453,14 @@ class CommSpec: res_list.append(f"gather_dim:{self.gather_dim}, ") res_list.append(f"logical_process_asex:{self.logical_process_axes})") - return ''.join(res_list) + return "".join(res_list) def get_comm_cost(self): - ''' + """ For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to compute the communication cost. For shard operation, it is an on-chip operation, so the communication cost is zero. - ''' + """ comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1) cost_dict = {} if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: @@ -503,13 +504,13 @@ class CommSpec: return cost_dict def covert_spec_to_action(self, tensor): - ''' + """ Convert CommSpec into runtime action, implement real collection communication to target tensor. The collection communication action is directed by the CommSpec. Argument: tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. - ''' + """ if self.comm_pattern in pattern_to_func_dict: tensor = pattern_to_func_dict[self.comm_pattern](tensor, self) else: diff --git a/colossalai/tensor/const.py b/colossalai/tensor/const.py deleted file mode 100644 index 356e8ecc885a3fb24766683b106a91ca2fac44eb..0000000000000000000000000000000000000000 --- a/colossalai/tensor/const.py +++ /dev/null @@ -1,6 +0,0 @@ -from enum import Enum - - -class TensorType(Enum): - MODEL = 0 - NONMODEL = 1 # mainly activations diff --git a/colossalai/tensor/d_tensor/README.md b/colossalai/tensor/d_tensor/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3d862dddbf2036ee31f74d9b02cb47a269aeb703 --- /dev/null +++ b/colossalai/tensor/d_tensor/README.md @@ -0,0 +1,103 @@ +# 🔢 Distributed Tensor + +## 📚 Table of Contents + +- [🔢 Distributed Tensor](#-distributed-tensor) + - [📚 Table of Contents](#-table-of-contents) + - [🔗 Introduction](#-introduction) + - [📝 Design](#-design) + - [🔨 Usage](#-usage) + - [🎈 Progress Log](#-progress-log) + +## 🔗 Introduction + +Distributed tensor is a type of tensor that is distributed across multiple devices. It is a wrapper of PyTorch tensor, and it is used to support distributed training. +It can represent the device topology and tensor placement over the devices in the topology. It also provides a set of APIs to manipulate the distributed tensor. + +## 📝 Design + +Our implementation is inspired by the work [Alpa](https://arxiv.org/abs/2201.12023), which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses notations `S` to represent the sharded dimension and `R` to represent the replicated dimension. For example, given a 2D matrix, `[S, R]` represents the tensor is sharded over the first dimension. + +Each sharded dimension will have a subscript to represent its placement over the devices. Assuming we have 4 GPUs and the GPUs are arranged in a 2 x 2 manner. Let's say we have a 2D matrix like below: + + +```text + [1, 2, 3, 4 ] +A = [4, 5, 6, 7 ] + [8, 9, 10, 11] + [12, 13, 14, 15] +``` + +`[S0, R]` would mean that the first dimension is sharded over the rows in the device topology. + +```text +| --------------------—————————————————————-| +| | | +| [1, 2, 3, 4 ] | [1, 2, 3, 4 ] | +| [4, 5, 6, 7 ] | [4, 5, 6, 7 ] | +| | | +| --------------------——————————————————----- +| | | +| [8, 9, 10, 11] | [8, 9, 10, 11] | +| [12, 13, 14, 15] | [12, 13, 14, 15] | +| | | +| --------------------——————————————————----- +``` + +`[S01, R]` would mean that the first dimension is sharded over both the row and column in the device topology. + +```text +| --------------------—————————————————————-| +| | | +| [1, 2, 3, 4 ] | [4, 5, 6, 7 ] | +| | | +| --------------------——————————————————----- +| | | +| [8, 9, 10, 11] | [12, 13, 14, 15] | +| | | +| --------------------——————————————————----- +``` + +## 🔨 Usage + +A sample API usage is given below. + +```python +import torch + +import colossalai +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.d_tensor import DTensor, ShardingSpec + +colossalai.launch_from_torch(config={}) + +# define your device mesh +# assume you have 4 GPUs +physical_mesh_id = torch.arange(0, 4) +mesh_shape = (2, 2) +device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + +# define a tensor +a = torch.rand(16, 32).cuda() + +# create sharding spec for the tensor +# assume the sharding spec is [S0, R] +dim_partition_dict = {0: [0]} +sharding_spec = ShardingSpec(a.dim(), dim_partition_dict) + +# create a distributed tensor +d_tensor = DTensor(a, device_mesh, sharding_spec) +print(d_tensor) + +global_tensor = d_tensor.to_global() +print(global_tensor) +``` + + +## 🎈 Progress Log + +- [x] Support layout conversion +- [x] Support sharding on 2D device mesh +- [ ] Support sharding on 3D device mesh +- [ ] Support sharding 4D device mesh +- [ ] Support sharding info saving and offline tensor merge (we can save tensor as dtensor and gather the tensors back to the global tensor based on the sharding info in a single process in CPU, useful for distributed training checkpoint load and save.) diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..fad5101d380cd1476f65600707e20b914b3b3303 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -0,0 +1,43 @@ +from .api import ( + compute_global_numel, + customized_distributed_tensor_to_param, + distribute_tensor, + distribute_tensor_with_customization, + get_device_mesh, + get_global_shape, + get_layout, + get_sharding_spec, + is_customized_distributed_tensor, + is_distributed_tensor, + is_sharded, + redistribute, + shard_colwise, + shard_rowwise, + sharded_tensor_to_param, + to_global, + to_global_for_customized_distributed_tensor, +) +from .layout import Layout +from .sharding_spec import ShardingSpec + +__all__ = [ + "is_distributed_tensor", + "distribute_tensor", + "to_global", + "is_sharded", + "shard_rowwise", + "shard_colwise", + "sharded_tensor_to_param", + "compute_global_numel", + "get_sharding_spec", + "get_global_shape", + "get_device_mesh", + "redistribute", + "get_layout", + "is_customized_distributed_tensor", + "distribute_tensor_with_customization", + "to_global_for_customized_distributed_tensor", + "customized_distributed_tensor_to_param", + "Layout", + "ShardingSpec", +] diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py new file mode 100644 index 0000000000000000000000000000000000000000..178bac428ea94dd0c4155838f1085d373ea41ea4 --- /dev/null +++ b/colossalai/tensor/d_tensor/api.py @@ -0,0 +1,461 @@ +import copy +import operator +from functools import reduce +from typing import Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.device.device_mesh import DeviceMesh + +from .layout import Layout +from .layout_converter import LayoutConverter +from .sharding_spec import ShardingSpec + +layout_converter = LayoutConverter() + + +def clear_layout_converter(): + global layout_converter + layout_converter.cached_solution.clear() + + +def is_distributed_tensor(tensor: torch.Tensor) -> bool: + """ + Check whether the given tensor is a distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a distributed tensor. + """ + return hasattr(tensor, "dist_layout") + + +def is_sharded(dtensor: torch.Tensor) -> bool: + """ + Check if a tensor is sharded. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: True if the tensor is sharded, False otherwise. + """ + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." + return list(dtensor.shape) == list(dtensor.dist_layout.global_shape) + + +def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + dtensor._old_detach = dtensor.detach + dtensor._old_clone = dtensor.clone + + def new_detach(self): + t_ = self._old_detach() + t_.dist_layout = copy.deepcopy(self.dist_layout) + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._old_clone(*args, **kwargs) + t_.dist_layout = copy.deepcopy(self.dist_layout) + return t_ + + # bind the new methods to the tensor + dtensor.detach = new_detach.__get__(dtensor) + dtensor.clone = new_clone.__get__(dtensor) + return dtensor + + +def _construct_default_sharding_spec( + tensor: torch.Tensor, +) -> ShardingSpec: + """ + Construct the default sharding specification for the tensor. + + Args: + tensor (`torch.Tensor`): the tensor to be sharded. + + Returns: + A `ShardingSpec` object without any sharding specified. + """ + return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={}) + + +def _apply_layout(tensor, layout): + """ + Apply the layout to the local tensor during initializing process. + """ + # layout converter requires a source and target laytout + # we construct the source layer for an unsharded tensor + # and use self.dist_layer as the targer layout for the sharded tensor + source_spec = _construct_default_sharding_spec(tensor) + source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape) + sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout) + return sharded_tensor + + +def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: + """ + Convert the given tensor to a distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be converted. + device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices. + sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded. + + Returns: + torch.Tensor: The distributed tensor. + """ + assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." + dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape) + + # shard tensor + sharded_tensor = _apply_layout(tensor, dist_layout) + + # hack some tensor methods + _hijack_detach_and_clone(sharded_tensor) + + return sharded_tensor + + +def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: + """ + Convert the layout of the tensor from source_spec to target_spec. + This will update the `local_tensor` and `dist_layout` in place. + + Args: + dtensor (torch.Tensor): the distributed tensor to be converted. + device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices. + target_layout (Layout): the target layout specification. + """ + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." + global_shape = get_global_shape(dtensor) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) + resharded_tensor = layout_converter.apply( + tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout + ) + return resharded_tensor + + +def to_global(dtensor: torch.Tensor) -> torch.Tensor: + """ + Convert a distributed tensor to the global tensor with the given layout. + This function returns a native `torch.Tensor` object. + + Args: + dtensor (torch.Tensor): the distributed tensor to be converted. + + Returns: + torch.Tensor: the global tensor. + """ + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." + layout_converter = LayoutConverter() + + global_sharding_spec = ShardingSpec(dtensor.dim(), {}) + device_mesh = get_device_mesh(dtensor) + global_shape = get_global_shape(dtensor) + global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape) + + global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout) + return global_tensor + + +def shard_rowwise( + tensor: torch.Tensor, + group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None, +) -> torch.Tensor: + """ + Shard the first dim of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be sharded. + group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor. + If None, the tensor will be sharded with respect to the global process group. + Defaults to None. + inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False. + + Returns: + torch.Tensor: The sharded tensor. + """ + # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group + if group_or_device_mesh is None: + group_or_device_mesh = dist.GroupMember.WORLD + + if isinstance(group_or_device_mesh, ProcessGroup): + device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) + else: + assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding." + device_mesh = group_or_device_mesh + + sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]}) + + return distribute_tensor(tensor, device_mesh, sharding_spec) + + +def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> torch.Tensor: + """ + Shard the first dim of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be sharded. + group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor. + If None, the tensor will be sharded with respect to the global process group. + Defaults to None. + inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False. + + Returns: + torch.Tensor: The sharded tensor. + """ + # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group + if group_or_device_mesh is None: + group_or_device_mesh = dist.GroupMember.WORLD + + if isinstance(group_or_device_mesh, ProcessGroup): + device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) + else: + assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding." + device_mesh = group_or_device_mesh + sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]}) + + return distribute_tensor(tensor, device_mesh, sharding_spec) + + +def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." + param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) + + # make it distributed as well + param.dist_layout = dtensor.dist_layout + _hijack_detach_and_clone(param) + + return param + + +def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None: + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." + param.data = dtensor + # make it distributed as well + param.dist_layout = dtensor.dist_layout + _hijack_detach_and_clone(param) + + +def compute_global_numel(dtensor: torch.Tensor) -> int: + """ + Compute the global number of elements in the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + int: The global number of elements in the distributed tensor. + """ + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." + numel = reduce(operator.mul, dtensor.dist_layout.global_shape) + return numel + + +def get_layout(dtensor: torch.Tensor) -> Layout: + """ + Get the layout of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + Layout: The layout of the distributed tensor. + + """ + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." + return dtensor.dist_layout + + +def get_global_shape(dtensor: torch.Tensor) -> torch.Size: + """ + Get the global shape of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + torch.Size: The global shape of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." + return dtensor.dist_layout.global_shape + + +def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh: + """ + Get the device mesh of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + DeviceMesh: The device mesh of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." + return dtensor.dist_layout.device_mesh + + +def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec: + """ + Get the sharding spec of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + ShardingSpec: The sharding spec of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." + return dtensor.dist_layout.sharding_spec + + +# ====================================================== +# Some sharding does not obey the SPMD style +# e.g. Fused QKV layer in GPT2 +# we support customize sharding with the following APIs +# ====================================================== +def is_customized_distributed_tensor(tensor: torch.Tensor): + """ + Check whether the given tensor is a customized distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a customized distributed tensor. + """ + return hasattr(tensor, "shard_fn") and hasattr(tensor, "gather_fn") + + +def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + dtensor._old_detach = dtensor.detach + dtensor._old_clone = dtensor.clone + + def new_detach(self): + t_ = self._old_detach() + t_.shard_fn = self.shard_fn + t_.gather_fn = self.gather_fn + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._old_clone(*args, **kwargs) + t_.shard_fn = self.shard_fn + t_.gather_fn = self.gather_fn + return t_ + + # bind the new methods to the tensor + dtensor.detach = new_detach.__get__(dtensor) + dtensor.clone = new_clone.__get__(dtensor) + return dtensor + + +def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_fn: callable): + """ + Distribute the given tensor with the given shard_fn and gather_fn. + + Example: + + ```python + # define shard and gather functions + def shard_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + return tensor.chunk(world_size, dim=0)[rank] + + def gather_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + shard_list = [torch.zeros_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(shard_list, tensor) + return torch.cat(shard_list, dim=0) + + # create a distributed tensor + tensor = torch.rand(4, 4) + dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn) + ``` + + Args: + tensor (torch.Tensor): The tensor to be distributed. + shard_fn (callable): The function to shard the tensor. + gather_fn (callable): The function to gather the tensor. + + Returns: + torch.Tensor: The distributed tensor. + """ + assert callable(shard_fn), "The shard_fn must be callable." + assert callable(gather_fn), "The gather_fn must be callable." + assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." + + sharded_tensor = shard_fn(tensor) + + # set the shard_fn and gather_fn as attributes of the distributed tensor + sharded_tensor.shard_fn = shard_fn + sharded_tensor.gather_fn = gather_fn + + # set the shard_fn and gather_fn as attributes of the distributed tensor + _hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor) + + return sharded_tensor + + +def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: + """ + Gather the given tensor to the global tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + torch.Tensor: The global tensor. + """ + assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor." + return dtensor.gather_fn(dtensor) + + +def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): + """ + Convert the given customized distributed tensor to a parameter. + """ + assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor." + + param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) + + # make it distributed as well + param.shard_fn = dtensor.shard_fn + param.gather_fn = dtensor.gather_fn + _hijack_detach_and_clone_for_customized_distributed_tensor(param) + return param + + +def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter): + """ + Convert the given customized distributed tensor to an existing parameter. + """ + assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor." + + param.data = dtensor.data + param.shard_fn = dtensor.shard_fn + param.gather_fn = dtensor.gather_fn + _hijack_detach_and_clone_for_customized_distributed_tensor(param) diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py index 765d8ec1b01a7857aa74e4ea4cb599869a308f1c..8f5b52aab8f88bcb80d9a1ae878b4eab3108f312 100644 --- a/colossalai/tensor/d_tensor/comm_spec.py +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -6,46 +6,48 @@ import torch.distributed as dist from torch.distributed import ReduceOp __all__ = [ - 'CollectiveCommPattern', - 'CommSpec', + "CollectiveCommPattern", + "CommSpec", ] class CollectiveCommPattern(Enum): - GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' - ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' - SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' - ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' - IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd" + ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd" + SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd" + ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd" + IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd" MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd" class CommSpec: - ''' + """ Communication spec is used to record the communication action. It converts the communication spec to real action which will be used in runtime. It contains comm_pattern to determine the - communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim + communication method, process_group_dict to determine the process groups, gather_dim and shard_dim to determine the buffer shape, and logical_process_axis Argument: - comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. - process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec. + comm_pattern(CollectiveCommPattern): describe the communication method used in this spec. + process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec. gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. - ''' - - def __init__(self, - comm_pattern: CollectiveCommPattern, - process_groups_dict: Dict, - gather_dim: int = None, - shard_dim: int = None, - logical_process_axis: int = None): + """ + + def __init__( + self, + comm_pattern: CollectiveCommPattern, + process_group_dict: Dict, + gather_dim: int = None, + shard_dim: int = None, + logical_process_axis: int = None, + ): self.comm_pattern = comm_pattern self.gather_dim = gather_dim self.shard_dim = shard_dim self.logical_process_axis = logical_process_axis - self.process_groups_dict = process_groups_dict + self.process_group_dict = process_group_dict def __repr__(self): res_list = ["CommSpec:("] @@ -71,16 +73,16 @@ class CommSpec: res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ") res_list.append(f"logical_process_axis:{self.logical_process_axis})") - return ''.join(res_list) + return "".join(res_list) def covert_spec_to_action(self, tensor): - ''' + """ Convert CommSpec into runtime action, implement real collection communication to target tensor. The collection communication action is directed by the CommSpec. Argument: tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. - ''' + """ if self.comm_pattern in pattern_to_func_dict: tensor = pattern_to_func_dict[self.comm_pattern](tensor, self) else: @@ -89,71 +91,59 @@ class CommSpec: def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): - ''' + """ Implement all gather operation on device mesh based on information provided by comm_spec. - ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - # without this contiguous operation, the all gather may get some unexpected results. - tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + """ + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor: torch.Tensor, comm_spec: CommSpec): - ''' + """ Implement shard operation on device mesh based on information provided by comm_spec. - ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + """ + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): - ''' + """ Implement all to all operation on device mesh based on information provided by comm_spec. - ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + """ + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False): - ''' + """ Implement all reduce operation on device mesh based on information provided by comm_spec. - ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + """ + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor class _ReduceGrad(torch.autograd.Function): @@ -268,11 +258,13 @@ class _AllToAll(torch.autograd.Function): @staticmethod def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) - comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - process_groups_dict=comm_spec.process_groups_dict, - gather_dim=comm_spec.shard_dim, - shard_dim=comm_spec.gather_dim, - logical_process_axis=comm_spec.logical_process_axis) + comm_spec_for_backward = CommSpec( + comm_pattern=comm_spec.comm_pattern, + process_group_dict=comm_spec.process_group_dict, + gather_dim=comm_spec.shard_dim, + shard_dim=comm_spec.gather_dim, + logical_process_axis=comm_spec.logical_process_axis, + ) ctx.comm_spec = comm_spec_for_backward return output diff --git a/colossalai/tensor/d_tensor/d_tensor.py b/colossalai/tensor/d_tensor/d_tensor.py deleted file mode 100644 index c1fe9d50a048397c3fee6297cd80fda61c745bc3..0000000000000000000000000000000000000000 --- a/colossalai/tensor/d_tensor/d_tensor.py +++ /dev/null @@ -1,142 +0,0 @@ -from typing import Optional - -import torch -from torch.utils._pytree import tree_map - -from .layout import Layout -from .layout_converter import LayoutConverter, to_global -from .sharding_spec import ShardingSpec - -layout_converter = LayoutConverter() - - -class DTensor(torch.Tensor): - - def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout): - self.local_tensor = local_tensor - self.data_type = local_tensor.dtype - self.entire_shape = local_tensor.shape - self.dist_layout = dist_layout - self._apply_layout() - - @staticmethod - def __new__(cls, local_tensor, layout): - return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad) - - def __repr__(self): - return f"DTensor({self.to_global()}, {self.dist_layout})" - - def __str__(self): - return self.__repr__() - - def layout_convert(self, target_layout): - ''' - Convert the layout of the tensor from source_spec to target_spec. - ''' - self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout) - self.dist_layout = target_layout - - def _apply_layout(self): - ''' - Apply the layout to the local tensor during initializing process. - ''' - source_spec = construct_default_sharding_spec(self.local_tensor) - source_layout = Layout(device_mesh=self.dist_layout.device_mesh, - device_type=self.dist_layout.device_type, - sharding_spec=source_spec, - entire_shape=self.entire_shape) - self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - - def filter_arg(arg): - if isinstance(arg, DTensor): - return arg.local_tensor - else: - return arg - - args = tree_map(filter_arg, args) - kwargs = tree_map(filter_arg, kwargs) - # if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors - # and op type. - - return func(*args, **kwargs) - - @property - def device_mesh(self): - ''' - Return the device mesh of the tensor. - ''' - return self.dist_layout.device_mesh - - @property - def sharding_spec(self): - ''' - Return the sharding specification of the tensor. - ''' - return self.dist_layout.sharding_spec - - def to(self, *args, **kwargs): - ''' - Move the tensor to a new device or convert the tensor to a new dtype. - ''' - self.local_tensor = self.local_tensor.to(*args, **kwargs) - self.data_type = self.local_tensor.dtype - self.dist_layout.device_type = self.local_tensor.device - # TODO: update the device mesh process groups or we should just cache - # both the cpu process groups and the cuda process groups? - return self - - def to_local(self): - ''' - Return the local tensor in this rank. - ''' - return self.local_tensor - - def to_global(self): - ''' - Recover the global tensor from the distributed tensor. - - Note: This function will all_gather the local tensor to the global tensor and it - will not change the layout of the DTensor. This function is mainly used for debugging or - check the correctness of the distributed tensor. - ''' - return to_global(self.local_tensor, self.dist_layout) - - -def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor: - ''' - Distribute the local tensor to the distributed tensor according to the dist_layout specified. - - Args: - local_tensor: tensor to be distributed. - dist_layout: the layout specification of the distributed tensor. - - Returns: - A 'DTensor' object. - ''' - return DTensor(local_tensor, dist_layout) - - -def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module: - ''' - This function converts all the parameters in the module to DTensor(DParam). - - Note: This function is subject to future change as the DParam has not been implemented yet. - ''' - for name, param in module.named_parameters(): - if param is not None and not isinstance(param, DTensor): - # TODO: we could convert the parameter to DParam here, - # the type of the parameter could be an optional argument. - setattr(module, name, torch.nn.Parameter(partition_fn(name, param.data))) - return module - - -def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec: - ''' - Construct the default sharding specification for the tensor. - ''' - return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={}) diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index ee7ef74a99aed377a3956ffe53fd96417a4b7aef..6d4c5dbe3c09e1559816d1eb2679e4b1b174e11c 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -1,12 +1,11 @@ import operator -from dataclasses import dataclass from functools import reduce import torch from colossalai.device.device_mesh import DeviceMesh -from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError +from .misc import DuplicatedShardingDimensionError, ShardingNotDivisibleError from .sharding_spec import ShardingSpec @@ -15,29 +14,27 @@ class Layout: Attributes: device_mesh: the device mesh to store the tensor distributed. - device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'. sharding_spec: the sharding specification to describe how the tensor is sharded. - entire_shape: the entire shape of the global tensor. + global_shape: the entire shape of the global tensor. """ - def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec, - entire_shape: torch.Size): + def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size): self.device_mesh = device_mesh - self.device_type = device_type self.sharding_spec = sharding_spec - self.entire_shape = entire_shape + self.global_shape = global_shape self._sanity_check() def __hash__(self) -> int: - return hash(f'{self.sharding_spec}') + return hash(f"{self.sharding_spec}") def get_sharded_shape_per_device(self): - sharded_shape = list(self.entire_shape) + sharded_shape = list(self.global_shape) for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): - mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) - assert sharded_shape[ - dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' + assert ( + sharded_shape[dim] % shard_partitions == 0 + ), f"Cannot shard dimension {dim} into {shard_partitions} partitions." sharded_shape[dim] //= shard_partitions return torch.Size(sharded_shape) @@ -45,24 +42,26 @@ class Layout: sharding_spec = self.sharding_spec # make sure all axes in logical device mesh only be used once - dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) - for dim, shard_list in sharding_spec.dim_partition_dict.items(): - for element in shard_list: - if element in dim_check_list: - dim_check_list.remove(element) - else: - raise DuplicatedShardingDimensionError( - f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + if self.device_mesh.logical_mesh_id is not None: + dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) + for dim, shard_list in sharding_spec.dim_partition_dict.items(): + for element in shard_list: + if element in dim_check_list: + dim_check_list.remove(element) + else: + raise DuplicatedShardingDimensionError( + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}." + ) # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): - tensor_dim_size = self.entire_shape[dim] + tensor_dim_size = self.global_shape[dim] num_devices = 1 for element in shard_list: - num_devices *= self.device_mesh.mesh_shape[element] + num_devices *= self.device_mesh.shape[element] if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( - f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.' + f"The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices." ) diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index cf02aac309f40d3bd7b5037000094601f5b2f2e3..e031e0472b0b7a6f6c2abd83519af3c1476c9ff0 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -3,10 +3,8 @@ from copy import deepcopy from dataclasses import dataclass from typing import Dict, List, Tuple -import numpy as np import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * from colossalai.tensor.d_tensor.layout import Layout @@ -16,7 +14,7 @@ from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, from .sharding_spec import ShardingSpec from .utils import get_comm_cost -__all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_options'] +__all__ = ["LayoutConverter", "LayoutConverterOptions", "set_layout_converting_options"] @dataclass @@ -24,20 +22,8 @@ class LayoutConverterOptions: """ LayoutConverterOptions is a dataclass which specifies the preferences for layout converting. """ - # TODO: layout converter option is not implemented yet - pass - -def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor: - layout_converter = LayoutConverter() - global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {}) - global_layout = Layout(device_mesh=layout.device_mesh, - device_type=layout.device_type, - sharding_spec=global_sharding_spec, - entire_shape=layout.entire_shape) - with torch.no_grad(): - global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout) - return global_tensor + # TODO: layout converter option is not implemented yet def set_layout_converting_options(options: LayoutConverterOptions): @@ -49,6 +35,9 @@ def set_layout_converting_options(options: LayoutConverterOptions): class LayoutConverter(metaclass=SingletonMeta): + """ + LayoutConverter is a singleton class which converts the layout of a distributed tensor. + """ def __init__(self): self._options = None @@ -74,7 +63,7 @@ class LayoutConverter(metaclass=SingletonMeta): self._forward_only = value def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with single all-gather operation. For the all-gather operation, we just care about the S dimension. @@ -91,15 +80,14 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) for layout, comm_spec in rst_dict.items(): @@ -108,11 +96,16 @@ class LayoutConverter(metaclass=SingletonMeta): Output: [R, S1, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:0) [S0, R, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1) - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + for target_pair in source_spec.dim_partition_dict.items(): shard_list = all_gather_simulator(target_pair) index = target_pair[0] @@ -130,19 +123,21 @@ class LayoutConverter(metaclass=SingletonMeta): logical_process_axis = target_pair[1][-1] comm_spec = CommSpec( comm_pattern, - process_groups_dict=process_groups_dict, + process_group_dict=process_group_dict, gather_dim=gather_dim, - # shard_dim will be used during backward + # shard_dim will be used during backward shard_dim=gather_dim, - logical_process_axis=logical_process_axis) + logical_process_axis=logical_process_axis, + ) # generate new sharding spec try: new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) - new_layout = Layout(device_mesh=source_layout.device_mesh, - sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + new_layout = Layout( + device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + global_shape=source_layout.global_shape, + ) valid_spec_dict[new_layout] = comm_spec except LayoutException: @@ -150,7 +145,7 @@ class LayoutConverter(metaclass=SingletonMeta): return valid_spec_dict def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with single all-to-all operation. For the all-to-all operation, we just care about the pairs containing S dimension. @@ -167,15 +162,14 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_to_all_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -185,10 +179,15 @@ class LayoutConverter(metaclass=SingletonMeta): [S01, R, R]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:0, logical_process_axis: 1) [R, S1, S0]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:0, shard_dim:2, logical_process_axis: 0) [S0, R, S1]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:2, logical_process_axis: 1) - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + source_spec = source_layout.sharding_spec tensor_dims = source_spec.dims for f_index in range(tensor_dims - 1): @@ -228,11 +227,13 @@ class LayoutConverter(metaclass=SingletonMeta): gather_dim = b_index shard_dim = f_index logical_process_axis = b_target_pair[1][-1] - comm_spec = CommSpec(comm_pattern, - process_groups_dict, - gather_dim=gather_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis) + comm_spec = CommSpec( + comm_pattern, + process_group_dict=process_group_dict, + gather_dim=gather_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + ) new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) @@ -250,10 +251,11 @@ class LayoutConverter(metaclass=SingletonMeta): # generate new sharding spec try: new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) - new_layout = Layout(device_mesh=source_layout.device_mesh, - sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + new_layout = Layout( + device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + global_shape=source_layout.global_shape, + ) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -261,7 +263,7 @@ class LayoutConverter(metaclass=SingletonMeta): return valid_spec_dict def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with single shard operation. For the sharding operation, we just care about legal sharding dimensions. @@ -278,16 +280,15 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0]} # [S0,R,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.shard_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -297,14 +298,18 @@ class LayoutConverter(metaclass=SingletonMeta): [S01, R, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:0, shard_dim:0, logical_process_axis:1) [S0, S1, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1) [S0, R, S1]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:2, shard_dim:2, logical_process_axis:1) - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] # legal sharding dims means the mesh_id is still available to use. - legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))] + legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))] for dim, shard_list in source_spec.dim_partition_dict.items(): for element in shard_list: legal_sharding_dims.remove(element) @@ -328,27 +333,31 @@ class LayoutConverter(metaclass=SingletonMeta): # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec shard_dim = index logical_process_axis = shard_list[-1] - comm_spec = CommSpec(comm_pattern, - process_groups_dict, - gather_dim=shard_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis) + comm_spec = CommSpec( + comm_pattern, + process_group_dict=process_group_dict, + gather_dim=shard_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + ) # generate new sharding spec try: - new_sharding_spec = ShardingSpec(dim_size=source_spec.dims, - dim_partition_dict=new_dim_partition_dict) - new_layout = Layout(device_mesh=source_layout.device_mesh, - sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + new_sharding_spec = ShardingSpec( + dim_size=source_spec.dims, dim_partition_dict=new_dim_partition_dict + ) + new_layout = Layout( + device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + global_shape=source_layout.global_shape, + ) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass return valid_spec_dict def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with one step transform. Note: @@ -361,16 +370,17 @@ class LayoutConverter(metaclass=SingletonMeta): Return: valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with one step transform. - ''' + """ valid_spec_dict = {} valid_spec_dict.update(self.all_gather_transform_layouts(source_layout)) valid_spec_dict.update(self.all_to_all_transform_layout(source_layout)) valid_spec_dict.update(self.shard_transform_layout(source_layout)) return valid_spec_dict - def layout_converting(self, source_layout: Layout, - target_layout: Layout) -> Tuple[List[Layout], List[CommSpec], float]: - ''' + def layout_converting( + self, source_layout: Layout, target_layout: Layout + ) -> Tuple[List[Layout], List[CommSpec], float]: + """ This method will find a path to transform source_layout to target_layout with a greedy algorithm. The basic idea is: @@ -399,7 +409,7 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} @@ -407,16 +417,14 @@ class LayoutConverter(metaclass=SingletonMeta): # [R,S01,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [S01,R,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) @@ -424,7 +432,7 @@ class LayoutConverter(metaclass=SingletonMeta): output: [R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R] - ''' + """ source_spec = source_layout.sharding_spec target_spec = target_layout.sharding_spec MAX_TRANSFORM_STEPS = 20 @@ -475,11 +483,11 @@ class LayoutConverter(metaclass=SingletonMeta): raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> Dict[str, float]: - ''' + """ Get the total communication cost of the layout converting process. - ''' + """ transform_path, comm_action_sequence = self.layout_converting(source_layout, target_layout) - total_cost = {'forward': 0.0, 'backward': 0.0, 'total': 0.0} + total_cost = {"forward": 0.0, "backward": 0.0, "total": 0.0} for layout, comm_spec in zip(transform_path, comm_action_sequence): cost_dict = get_comm_cost(layout, comm_spec, self.forward_only) for key in total_cost: @@ -487,7 +495,7 @@ class LayoutConverter(metaclass=SingletonMeta): return total_cost def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layout) -> torch.Tensor: - ''' + """ Apply target_layout to tensor with source layout, the transform path is generated by the layout_converting method. @@ -505,21 +513,19 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) # [S0,R,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [R,S0,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) if rank in (0, 1): sharded_tensor_0 = torch.zeros(2, 1) @@ -549,8 +555,9 @@ class LayoutConverter(metaclass=SingletonMeta): [1.], [3.], [3.]]) - ''' + """ _, comm_action_sequence = self.layout_converting(source_layout, target_layout) for comm_spec in comm_action_sequence: tensor = comm_spec.covert_spec_to_action(tensor) + tensor.dist_layout = target_layout return tensor diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py index 2ea0c4db89fd3ca73fc920f84185c27c446b9573..2ac0ca73e4b8301aa5b9711aef4fdc7645f7e38e 100644 --- a/colossalai/tensor/d_tensor/sharding_spec.py +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -4,16 +4,16 @@ from typing import Dict, List from ..utils import merge_same_dim_mesh_list from .misc import ShardingOutOfIndexError -__all__ = ['DimSpec', 'ShardingException', 'ShardingSpec'] +__all__ = ["DimSpec", "ShardingException", "ShardingSpec"] ALLGATHER_COST = 20 SHARD_COST = 5 STEP_PENALTY = 6 -NAN = 'nan' +NAN = "nan" class DimSpec: - ''' + """ Sharding spec for single dimension of the sharded tensor describe the sharding dimension of logical device mesh and give a method to compute the difference between them. This class is used internally in ShardingSpec. @@ -21,7 +21,7 @@ class DimSpec: Argument: shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. Otherwise, the element in shard_list means the data will be sharded in that dimension. - ''' + """ def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 @@ -33,41 +33,40 @@ class DimSpec: def __repr__(self): if self.is_replica: - return 'R' - target = 'S' + return "R" + target = "S" for dim in self.shard_list: target += str(dim) return target def _convert_str_to_shard_list(self, str_spec): - ''' - Conver str_spec into shard_list. + """ + Convert str_spec into shard_list. Argument: str_spec(str): dim spec in str type. - ''' + """ - if str_spec == 'R': + if str_spec == "R": return [] - if str_spec == 'S0': + if str_spec == "S0": return [0] - if str_spec == 'S1': + if str_spec == "S1": return [1] - if str_spec == 'S01': + if str_spec == "S01": return [0, 1] def build_difference_2d_dict(self): - ''' - Build a difference maping for 2D device mesh case. It will be used to + """ + Build a difference mapping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. - ''' + """ - source_spec_list = ['R', 'S0', 'S1', 'S01'] - target_spec_list = ['R', 'S0', 'S1', 'S01'] + source_spec_list = ["R", "S0", "S1", "S01"] + target_spec_list = ["R", "S0", "S1", "S01"] difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - legal_sharding_dims = [] spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) source_shard_list = self._convert_str_to_shard_list(source_spec) target_shard_list = self._convert_str_to_shard_list(target_spec) @@ -77,14 +76,17 @@ class DimSpec: difference = 0 # all_gather(source) -> target - elif len(source_shard_list - ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list + ): difference = ALLGATHER_COST # shard(source) -> target - elif len(source_shard_list) == len( - target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[ - -1] not in source_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) - 1 + and source_shard_list == target_shard_list[:-1] + and target_shard_list[-1] not in source_shard_list + ): difference = SHARD_COST # S1 -> S0 or S0 -> S1 @@ -115,7 +117,7 @@ class DimSpec: self.difference_dict = difference_dict def dim_diff(self, other): - ''' + """ The difference between two _DimSpec. Argument: @@ -131,13 +133,13 @@ class DimSpec: Output: 5 - ''' + """ difference = self.difference_dict[(str(self), str(other))] return difference class ShardingSpec: - ''' + """ Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like [R, R, S0, S1], which means @@ -145,23 +147,27 @@ class ShardingSpec: dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, and the value of the key describe which logical axis will be sharded in that dimension. sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. - ''' + """ - def __init__(self, - dim_size: int, - dim_partition_dict: Dict[int, List[int]] = None, - sharding_sequence: List[DimSpec] = None): + def __init__( + self, dim_size: int, dim_partition_dict: Dict[int, List[int]] = None, sharding_sequence: List[DimSpec] = None + ): self.dims = dim_size self.dim_partition_dict = dim_partition_dict self.sharding_sequence = sharding_sequence if self.sharding_sequence is None: - assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.' - self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=self.dims, - dim_partition_dict=self.dim_partition_dict) + assert ( + self.dim_partition_dict is not None + ), f"dim_partition_dict should not be None, if sharding_sequence is NoneType object." + self.dim_partition_dict = merge_same_dim_mesh_list( + dim_size=self.dims, dim_partition_dict=self.dim_partition_dict + ) self.sharding_sequence = self.convert_dict_to_shard_sequence() elif self.dim_partition_dict is None: - assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.' + assert ( + self.sharding_sequence is not None + ), f"sharding_sequence should not be None, if dim_partition_dict is NoneType object." self.dim_partition_dict = self.convert_shard_sequence_to_dict() self._sanity_check() @@ -169,31 +175,32 @@ class ShardingSpec: def _sanity_check(self): if len(self.sharding_sequence) > self.dims: raise ShardingOutOfIndexError( - f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.') + f"sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}." + ) if list(self.dim_partition_dict.keys()) and max(list(self.dim_partition_dict.keys())) >= self.dims: raise ShardingOutOfIndexError( - f'the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}.' + f"the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}." ) def __repr__(self): res_list = ["ShardingSpec:"] res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) - return ' '.join(res_list) + return " ".join(res_list) def convert_dict_to_shard_sequence(self): - ''' + """ Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence. - ''' + """ sharding_sequence = [DimSpec([])] * self.dims for dim, shard_list in self.dim_partition_dict.items(): sharding_sequence[dim] = DimSpec(shard_list) return sharding_sequence def convert_shard_sequence_to_dict(self): - ''' + """ Convert sharding_sequence into dim_partition_dict. - ''' + """ new_dim_partition_dict = {} for index, dim_spec in enumerate(self.sharding_sequence): if not dim_spec.is_replica: @@ -203,7 +210,7 @@ class ShardingSpec: return new_dim_partition_dict def spec_diff(self, other): - ''' + """ This function is a naive version of difference computation. It just simply accumulates difference every dimension between the pair of sharding sequence. @@ -228,9 +235,10 @@ class ShardingSpec: Return: difference(int): Difference between two ShardingSpec. - ''' + """ assert len(self.sharding_sequence) == len( - other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.' + other.sharding_sequence + ), f"Cannot compare difference for two sharding specs with different length." difference = 0 for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence): difference += orig_dim_spec.dim_diff(other_dim_spec) diff --git a/colossalai/tensor/d_tensor/utils.py b/colossalai/tensor/d_tensor/utils.py index 644bb6306b42c3ac28717396aeea169370748868..8f0081246fb32e7b57d5cd15d3d11899f7fee20b 100644 --- a/colossalai/tensor/d_tensor/utils.py +++ b/colossalai/tensor/d_tensor/utils.py @@ -7,7 +7,7 @@ from colossalai.tensor.d_tensor.layout import Layout def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = False) -> Dict[str, float]: - ''' + """ This method is used to compute the communication cost for a given layout and comm_spec. For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to @@ -18,7 +18,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals comm_spec: the comm_spec to instruct the communication operation. forward_only: if it is True, we will just count the forward communication cost. If it is False, we will count both forward and backward communication cost. - ''' + """ comm_size = reduce(operator.mul, layout.get_sharded_shape_per_device(), 1) device_mesh = layout.device_mesh comm_pattern = comm_spec.comm_pattern @@ -29,7 +29,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals # the comm size for all gather is the size of the gathered tensor gather_dim = comm_spec.gather_dim all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1] - all_gather_size = device_mesh.mesh_shape[all_gather_axis] + all_gather_size = device_mesh.shape[all_gather_axis] comm_size_for_all_gather = comm_size * all_gather_size forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis) # give a tiny cost to shard diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py deleted file mode 100644 index 8657989235db49623f31191b2b4df84896893db6..0000000000000000000000000000000000000000 --- a/colossalai/tensor/dist_spec_mgr.py +++ /dev/null @@ -1,189 +0,0 @@ -from contextlib import contextmanager - -import torch -import torch.distributed as dist -# from colossalai.nn.layer.utils import divide -from numpy import prod -from packaging import version - -from colossalai.logging import get_dist_logger -from colossalai.tensor.distspec import _DistSpec -from colossalai.tensor.process_group import ProcessGroup - - -# TODO(jiaruifang) circle import, move the divide to colossalai.commons. -# colossalai.tensor shall not import any submodule from colossal.nn -def divide(numerator, denominator): - """Only allow exact division. - - Args: - numerator (int): Numerator of the division. - denominator (int): Denominator of the division. - - Returns: - int: the result of exact division. - """ - assert denominator != 0, 'denominator can not be zero' - assert numerator % denominator == 0, \ - '{} is not divisible by {}'.format(numerator, denominator) - return numerator // denominator - - -class TransformDistSpec(torch.autograd.Function): - - @staticmethod - def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func): - ctx.old_dist_spec = old_dist_spec - ctx.dist_spec = dist_spec - ctx.backward_trans_func = backward_trans_func - ctx.pg = pg - return forward_trans_func(tensor, old_dist_spec, dist_spec, pg) - - @staticmethod - def backward(ctx, grad_outputs): - return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, - ctx.pg), None, None, None, None, None - - -class DistSpecManager: - - _use_autograd_function: bool = True - - @staticmethod - def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None: - pass - - @staticmethod - def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, - pg: ProcessGroup) -> torch.Tensor: - """_shard_as: shard the tensor w.r.t a distributed specification. - Assuming the tensor passed in is a global (replicated) tensor. - Args: - tensor (torch.Tensor): a global (replicated) tensor before shard - dist_spec (_DistSpec): the distributed spec. to be sharded as. - pg (ProcessGroup): the process group of the corresponding colotensor - Returns: - torch.Tensor: a torch tensor after sharded. - """ - assert old_dist_spec.placement.value == 'r', f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!" - DistSpecManager._sanity_check(old_dist_spec, dist_spec) - - chunk = tensor - idx = pg.tp_local_rank() - num_parts = prod(dist_spec.num_partitions) - for i, dim in enumerate(dist_spec.dims): - num_parts //= dist_spec.num_partitions[i] - - chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i]) - chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size) - idx %= num_parts - return chunk.clone().detach().contiguous() - - @staticmethod - def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: - """_gather gather sharded tensors to a replicated one. - Args: - tensor (torch.Tensor): a shared torch tensor - old_dist_spec (_DistSpec): the distributed spec. of the tensor. - - Returns: - torch.Tensor: a replicated tensor. - """ - assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!" - is_cpu_tensor = False - if tensor.device.type == 'cpu': - # pytorch lower than 1.11 dose not support gather a cpu tensor. - # Therefore, we transfer tensor to GPU before gather. - saved_dev = tensor.device - tensor.data = tensor.data.cuda() - is_cpu_tensor = True - - buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())] - assert tensor.device.type == 'cuda' - dist.all_gather(buffer, tensor, group=pg.tp_process_group()) - for i in range(len(old_dist_spec.dims) - 1, -1, -1): - new_buffer = [] - dim = old_dist_spec.dims[i] - num_parts = old_dist_spec.num_partitions[i] - for start in range(0, len(buffer), num_parts): - new_buffer.append(torch.cat(buffer[start:start + num_parts], dim)) - buffer = new_buffer - assert len(buffer) == 1 - - if is_cpu_tensor: - buffer[0].data = buffer[0].data.to(saved_dev) - return buffer[0] - - @staticmethod - def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, - pg: ProcessGroup) -> torch.Tensor: - world_size = pg.tp_world_size() - if world_size == 1: - return tensor - - assert tensor.device.type == "cuda", \ - "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \ - f"collective function, however, we got {tensor.device.type} device" - - gather_dim = old_dist_spec.dims[0] - scatter_dim = dist_spec.dims[0] - shapes = list(tensor.shape) - scattered_dim_size = shapes[scatter_dim] // world_size - gathered_dim_size = shapes[gather_dim] * world_size - shapes[scatter_dim] = scattered_dim_size - - scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)] - gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] - dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group()) - - output_ = torch.cat(gather_list, dim=gather_dim).contiguous() - assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size - return output_ - - @staticmethod - def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: - DistSpecManager._sanity_check(old_dist_spec, dist_spec) - return tensor - - @staticmethod - def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: - DistSpecManager._sanity_check(old_dist_spec, dist_spec) - return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg) - - @staticmethod - def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: - DistSpecManager._sanity_check(old_dist_spec, dist_spec) - return DistSpecManager._gather(tensor, old_dist_spec, pg) - - @staticmethod - def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor: - DistSpecManager._sanity_check(old_dist_spec, dist_spec) - if old_dist_spec == dist_spec: - return tensor - if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1: - # use all-to-all to save memory - return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec, pg) - tensor = DistSpecManager._gather(tensor, old_dist_spec, pg) - return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg) - - @staticmethod - def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, - pg: ProcessGroup) -> torch.Tensor: - assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec" - assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec" - forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}') - if not DistSpecManager._use_autograd_function: - return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg) - backward_trans_handle = getattr(DistSpecManager, - f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}') - return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, - backward_trans_handle) - - @staticmethod - @contextmanager - def no_grad(): - try: - DistSpecManager._use_autograd_function = False - yield - finally: - DistSpecManager._use_autograd_function = True diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py index 9c2e0d4adbf1bdc10b25a2602098eef3d36ebbaf..1fe99cd89a4ef687ff6f87fdb8366d2577378267 100644 --- a/colossalai/tensor/param_op_hook.py +++ b/colossalai/tensor/param_op_hook.py @@ -3,9 +3,7 @@ from contextlib import contextmanager from typing import Any, List, Tuple import torch - -from colossalai.tensor.colo_tensor import ColoTensor -from colossalai.tensor.tensor_spec import ColoTensorSpec +from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten class ColoParamOpHook(ABC): @@ -38,6 +36,7 @@ class ColoParamOpHookManager: Manage your param op hooks. It only has static methods. The only static method you should call is ``use_hooks(*hooks)``. """ + hooks: Tuple[ColoParamOpHook, ...] = tuple() @staticmethod @@ -82,26 +81,18 @@ class ColoParamOpHookManager: @staticmethod def pre_op(params: List[torch.Tensor], *args: Any) -> list: ColoParamOpHookManager._trigger_pre_forward(params) - grad_args, rear_args = _get_grad_args(*args) - colo_info = _get_colo_tensors_info(*grad_args) - rets = PreFwdPostBwd.apply(params, *grad_args) - update_args = _update_colo_tensors(colo_info, *rets) - if rear_args is None: - return update_args - else: - arg_zero = (tuple(update_args),) - return arg_zero + rear_args + # auto grad function can only recognize torch.Tensor, thus we have to flatten the input + # if one of the input requires grad, all the output will be treated as requires grad + # and will have grad fn even the corresponding input does not require grad + # we have to extract tensors requiring grad into flat list and then merge them back + grad_args, other_args, grad_flags, spec = _flatten_grad_args(args) + new_grad_args = PreFwdPostBwd.apply(params, *grad_args) + return _merge_args(new_grad_args, other_args, grad_flags, spec) @staticmethod def post_op(params: List[torch.Tensor], arg: Any) -> Any: ColoParamOpHookManager._trigger_post_forward(params) - colo_info = _get_colo_tensors_info(arg) - ret = PostFwdPreBwd.apply(params, arg) - res = _update_colo_tensors(colo_info, ret) - if len(res) == 1: - return res[0] - else: - return res + return PostFwdPreBwd.apply(params, arg) @staticmethod def has_hook() -> bool: @@ -109,7 +100,6 @@ class ColoParamOpHookManager: class PreFwdPostBwd(torch.autograd.Function): - @staticmethod def forward(ctx, params, *args): ctx.params = params @@ -122,7 +112,6 @@ class PreFwdPostBwd(torch.autograd.Function): class PostFwdPreBwd(torch.autograd.Function): - @staticmethod def forward(ctx, params, args): ctx.params = params @@ -141,57 +130,24 @@ def _is_grad_tensor(obj) -> bool: return False -def _has_grad_tensor(obj) -> bool: - if isinstance(obj, tuple) or isinstance(obj, list): - for x in obj: - if _has_grad_tensor(x): - return True - return False - elif isinstance(obj, dict): - for x in obj.values(): - if _has_grad_tensor(x): - return True - return False - else: - return _is_grad_tensor(obj) - - -def _get_grad_args(*args): - # if there is no grad tensors, do nothing - if not _has_grad_tensor(args): - return args, None - # returns the identical args if there is a grad tensor - for obj in args: - if _is_grad_tensor(obj): - return args, None - # otherwise, the first arguement should be a tuple of grad tensors - # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered - arg_zero = args[0] - if not isinstance(arg_zero, tuple): - raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.") - check_grad_flag = False - for obj in arg_zero: - check_grad_flag |= _is_grad_tensor(obj) - if not check_grad_flag: - raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.") - return arg_zero, args[1:] - - -def _get_colo_tensors_info(*args) -> list: - info = [] - for arg in args: - if isinstance(arg, ColoTensor): - info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec))) +def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]: + flat_args, spec = tree_flatten(args) + grad_args = [] + other_args = [] + grad_flags = [] + for arg in flat_args: + flag = _is_grad_tensor(arg) + grad_flags.append(flag) + if flag: + grad_args.append(arg) else: - info.append(None) - return info - - -def _update_colo_tensors(info, *args) -> list: - ret = [] - for t_info, arg in zip(info, args): - if t_info is not None: - t_cls, spec = t_info - arg = t_cls.from_torch_tensor(arg, spec=spec) - ret.append(arg) - return ret + other_args.append(arg) + assert len(grad_args) > 0 + return grad_args, other_args, grad_flags, spec + + +def _merge_args(grad_args, other_args, grad_flags, spec): + grad_iter = iter(grad_args) + other_iter = iter(other_args) + flat_args = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags] + return tree_unflatten(flat_args, spec) diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index 0a840006f086353c1ae8cc6db6a695e230e8e59c..409561b3a26b4a47a7d246319eb306c507afc29c 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -13,7 +13,7 @@ from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, from .comm_spec import * -__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options'] +__all__ = ["ShapeConsistencyManager", "ShapeConsistencyOptions", "set_shape_consistency_options"] @dataclass @@ -21,16 +21,17 @@ class ShapeConsistencyOptions: """ ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency. """ + # TODO: shape consistency option is not implemented yet - pass def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor: shape_consistency_manager = ShapeConsistencyManager() global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {}) with torch.no_grad(): - global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec, - global_sharding_spec) + global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime( + distributed_tensor, sharding_spec, global_sharding_spec + ) return global_tensor @@ -43,7 +44,6 @@ def set_shape_consistency_options(options: ShapeConsistencyOptions): class ShapeConsistencyManager(metaclass=SingletonMeta): - def __init__(self): self._options = None self._forward_only = False @@ -69,9 +69,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): assert isinstance(value, bool) self._forward_only = value - def get_all_all_gather_spec(self, source_spec: ShardingSpec, - orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: - ''' + def get_all_all_gather_spec( + self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float] + ) -> Dict[ShardingSpec, float]: + """ Get all valid sharding specs from source_spec with single all-gather operation, and accumulate communication cost on origin cost which will finally be used in auto sharding solver. For the all-gather operation, we just care about the S dimension. @@ -99,7 +100,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,R device_mesh_shape: (4, 4): 0} - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD for target_pair in source_spec.dim_partition_dict.items(): @@ -121,19 +122,20 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): comm_pattern, sharding_spec=source_spec, gather_dim=gather_dim, - # shard_dim will be used during backward + # shard_dim will be used during backward shard_dim=gather_dim, logical_process_axis=logical_process_axis, - forward_only=self.forward_only) + forward_only=self.forward_only, + ) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -141,9 +143,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): pass return valid_spec_dict - def get_all_all_to_all_spec(self, source_spec: ShardingSpec, - orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: - ''' + def get_all_all_to_all_spec( + self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float] + ) -> Dict[ShardingSpec, float]: + """ Get all valid sharding specs from source_spec with single all-to-all operation, and accumulate communication cost on origin cost which will finally be used in auto sharding solver. For the all-to-all operation, we just care about the pairs containing S dimension. @@ -173,7 +176,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,S1 device_mesh_shape: (4, 4): 0} - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD tensor_dims = len(source_spec.entire_shape) @@ -214,12 +217,14 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): gather_dim = b_index shard_dim = f_index logical_process_axis = b_target_pair[1][-1] - comm_spec = CommSpec(comm_pattern, - sharding_spec=source_spec, - gather_dim=gather_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis, - forward_only=self.forward_only) + comm_spec = CommSpec( + comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only, + ) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() @@ -238,9 +243,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -250,9 +255,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): return valid_spec_dict def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): - ''' + """ Get all valid sharding specs from source_spec with single shard operation, and - accumulate commucation cost on origin cost which will finally be used in auto sharding solver. + accumulate communication cost on origin cost which will finally be used in auto sharding solver. For the sharding operation, we just care about legal sharding dimensions. Argument: @@ -280,12 +285,12 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,S1 device_mesh_shape: (4, 4): 0} - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD # legal sharding dims means the mesh_id is still available to use. - legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))] + legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.shape))] for dim, shard_list in source_spec.dim_partition_dict.items(): for element in shard_list: legal_sharding_dims.remove(element) @@ -308,21 +313,23 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec shard_dim = index logical_process_axis = shard_list[-1] - comm_spec = CommSpec(comm_pattern, - sharding_spec=source_spec, - gather_dim=shard_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis, - forward_only=self.forward_only) + comm_spec = CommSpec( + comm_pattern, + sharding_spec=source_spec, + gather_dim=shard_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only, + ) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -330,16 +337,17 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): pass return valid_spec_dict - def get_all_mix_gather_spec(self, source_spec: ShardingSpec, - orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: - ''' + def get_all_mix_gather_spec( + self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float] + ) -> Dict[ShardingSpec, float]: + """ S0S1 -> RR S1S0 -> RR S01R -> RR RS01 -> RR - ''' + """ valid_spec_dict = {} - comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD + comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD tensor_dims = len(source_spec.entire_shape) for f_index in range(tensor_dims - 1): for b_index in range(f_index + 1, tensor_dims): @@ -362,19 +370,21 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): b_target_pair = (b_index, []) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - comm_spec = CommSpec(comm_pathern, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=self.forward_only, - mix_gather=True) + comm_spec = CommSpec( + comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=self.forward_only, + mix_gather=True, + ) cost_dict = comm_spec.get_comm_cost() new_dim_partition_dict = {} # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -384,9 +394,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): return valid_spec_dict def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]: - ''' + """ Get all valid sharding specs from source_spec with one step transform, and - accumulate commucation cost on origin cost which will finally be used in auto sharding solver. + accumulate communication cost on origin cost which will finally be used in auto sharding solver. Note: all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before, and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive, @@ -398,7 +408,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. - ''' + """ valid_spec_dict = {} valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict)) valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict)) @@ -435,7 +445,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): """ input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) - output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + output_numel = input_numel * comm_spec.device_mesh.shape[comm_spec.logical_process_axis] peak_numel = max(peak_numel, alloc_numel + output_numel * 2) alloc_numel += output_numel if discard_input: @@ -461,7 +471,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): # generate a new tensor input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) - output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + output_numel = input_numel // comm_spec.device_mesh.shape[comm_spec.logical_process_axis] alloc_numel += output_numel peak_numel = max(peak_numel, alloc_numel) if discard_input: @@ -545,18 +555,22 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)): # the first forward comm action will not discard input fwd_action, comm_spec = action_spec_pair - fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, - fwd_peak_numel) if idx == 0 else fwd_action( - comm_spec, True, fwd_alloc_numel, fwd_peak_numel) + fwd_alloc_numel, fwd_peak_numel = ( + fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel) + if idx == 0 + else fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel) + ) # analyze memory footprint for backward comm actions sequence bwd_alloc_numel = 0 bwd_peak_numel = 0 for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))): bwd_action, comm_spec = action_spec_pair - bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel, - bwd_peak_numel) if idx == 0 else bwd_action( - comm_spec, True, bwd_alloc_numel, bwd_peak_numel) + bwd_alloc_numel, bwd_peak_numel = ( + bwd_action(comm_spec, False, bwd_alloc_numel, bwd_peak_numel) + if idx == 0 + else bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel) + ) fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel) bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel) @@ -564,9 +578,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): return TrainCycleItem(fwd_mem, bwd_mem, total_mem) - def shape_consistency(self, source_spec: ShardingSpec, - target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]: - ''' + def shape_consistency( + self, source_spec: ShardingSpec, target_spec: ShardingSpec + ) -> Tuple[List[ShardingSpec], List[CommSpec], float]: + """ This method will find a path to transform source_spec to target_spec with a greedy algorithm. The basic idea is: @@ -577,7 +592,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): Step3: Repeat above steps until the source spec transform to target spec. - During finding the transform path, commucation cost will be accumulated, and it + During finding the transform path, communication cost will be accumulated, and it will be finally used in auto parallel solver. Additionally, to avoid repeating the path search in runtime, we cached all solved path @@ -623,9 +638,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0), CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)] total_cost: 12294.402000000002 - ''' + """ MAX_TRANSFORM_STEPS = 20 - total_cost_dict = {'forward': 0, 'backward': 0, 'total': 0} + total_cost_dict = {"forward": 0, "backward": 0, "total": 0} total_steps = 0 transform_path = [] comm_action_sequence = [] @@ -672,7 +687,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor: - ''' + """ Apply target_spec to tensor with source sharding spec, the transform path is generated by the shape_consistency method. @@ -729,7 +744,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): [1.], [3.], [3.]]) - ''' + """ _, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec) for comm_spec in comm_action_sequence: tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec) diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index bed320130ccdc09cbd2e82f2c1d1ba6d01f4e295..b78ef6d97dd44b0f318213c4565d1b2a70899de7 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -8,16 +8,16 @@ from colossalai.device.device_mesh import DeviceMesh from .utils import merge_same_dim_mesh_list -__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec'] +__all__ = ["_DimSpec", "ShardingException", "ShardingSpec"] ALLGATHER_COST = 20 SHARD_COST = 5 STEP_PENALTY = 6 -NAN = 'nan' +NAN = "nan" class _DimSpec: - ''' + """ Sharding spec for single dimension of the sharded tensor describe the sharding dimension of logical device mesh and give a method to compute the difference between them. This class is used internally in ShardingSpec. @@ -25,7 +25,7 @@ class _DimSpec: Argument: shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. Otherwise, the element in shard_list means the data will be sharded in that dimension. - ''' + """ def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 @@ -37,41 +37,40 @@ class _DimSpec: def __repr__(self): if self.is_replica: - return 'R' - target = 'S' + return "R" + target = "S" for dim in self.shard_list: target += str(dim) return target def _convert_str_to_shard_list(self, str_spec): - ''' - Conver str_spec into shard_list. + """ + Convert str_spec into shard_list. Argument: str_spec(str): dim spec in str type. - ''' + """ - if str_spec == 'R': + if str_spec == "R": return [] - if str_spec == 'S0': + if str_spec == "S0": return [0] - if str_spec == 'S1': + if str_spec == "S1": return [1] - if str_spec == 'S01': + if str_spec == "S01": return [0, 1] def build_difference_2d_dict(self): - ''' - Build a difference maping for 2D device mesh case. It will be used to + """ + Build a difference mapping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. - ''' + """ - source_spec_list = ['R', 'S0', 'S1', 'S01'] - target_spec_list = ['R', 'S0', 'S1', 'S01'] + source_spec_list = ["R", "S0", "S1", "S01"] + target_spec_list = ["R", "S0", "S1", "S01"] difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - legal_sharding_dims = [] spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) source_shard_list = self._convert_str_to_shard_list(source_spec) target_shard_list = self._convert_str_to_shard_list(target_spec) @@ -81,14 +80,17 @@ class _DimSpec: difference = 0 # all_gather(source) -> target - elif len(source_shard_list - ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list + ): difference = ALLGATHER_COST # shard(source) -> target - elif len(source_shard_list) == len( - target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[ - -1] not in source_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) - 1 + and source_shard_list == target_shard_list[:-1] + and target_shard_list[-1] not in source_shard_list + ): difference = SHARD_COST # S1 -> S0 or S0 -> S1 @@ -119,7 +121,7 @@ class _DimSpec: self.difference_dict = difference_dict def difference(self, other): - ''' + """ The difference between two _DimSpec. Argument: @@ -135,7 +137,7 @@ class _DimSpec: Output: 5 - ''' + """ difference = self.difference_dict[(str(self), str(other))] return difference @@ -157,7 +159,7 @@ class ShardingNotDivisibleError(ShardingSpecException): class ShardingSpec: - ''' + """ Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong to, the entire shape of the tensor before sharded, and the sharding sequence looks like [R, R, S0, S1]. @@ -166,15 +168,13 @@ class ShardingSpec: device_mesh(DeviceMesh): A logical view of a physical mesh. entire_shape(torch.Size): The entire shape of tensor before sharded. dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, - and the value of the key decribe which logical axis will be sharded in that dimension. + and the value of the key describe which logical axis will be sharded in that dimension. sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. - ''' + """ - def __init__(self, - device_mesh: DeviceMesh, - entire_shape: torch.Size, - dim_partition_dict=None, - sharding_sequence=None): + def __init__( + self, device_mesh: DeviceMesh, entire_shape: torch.Size, dim_partition_dict=None, sharding_sequence=None + ): self.device_mesh = device_mesh if isinstance(entire_shape, (list, tuple)): @@ -183,20 +183,25 @@ class ShardingSpec: self.dim_partition_dict = dim_partition_dict self.sharding_sequence = sharding_sequence if self.sharding_sequence is None: - assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.' - self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape), - dim_partition_dict=self.dim_partition_dict) + assert ( + self.dim_partition_dict is not None + ), f"dim_partition_dict should not be None, if sharding_sequence is NoneType object." + self.dim_partition_dict = merge_same_dim_mesh_list( + dim_size=len(entire_shape), dim_partition_dict=self.dim_partition_dict + ) self.convert_dict_to_shard_sequence() elif self.dim_partition_dict is None: - assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.' + assert ( + self.sharding_sequence is not None + ), f"sharding_sequence should not be None, if dim_partition_dict is NoneType object." self.convert_shard_sequence_to_dict() self._sanity_check() def __repr__(self): res_list = ["DistSpec:"] res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) - res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.mesh_shape}") - return ' '.join(res_list) + res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}") + return " ".join(res_list) def _sanity_check(self): # make sure all axes in logical device mesh only be used once @@ -207,7 +212,8 @@ class ShardingSpec: dim_check_list.remove(element) else: raise DuplicatedShardingDimensionError( - f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}." + ) # make sure that the dimension is not out of index for dim in self.dim_partition_dict.keys(): @@ -222,26 +228,26 @@ class ShardingSpec: num_devices = 1 for element in shard_list: - num_devices *= self.device_mesh.mesh_shape[element] + num_devices *= self.device_mesh.shape[element] if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( - f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.' + f"The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices." ) def convert_dict_to_shard_sequence(self): - ''' + """ Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence. - ''' + """ sharding_sequence = [_DimSpec([])] * len(self.entire_shape) for dim, shard_list in self.dim_partition_dict.items(): sharding_sequence[dim] = _DimSpec(shard_list) self.sharding_sequence = sharding_sequence def convert_shard_sequence_to_dict(self): - ''' + """ Convert sharding_sequence into dim_partition_dict. - ''' + """ new_dim_partition_dict = {} for index, dim_spec in enumerate(self.sharding_sequence): if not dim_spec.is_replica: @@ -251,7 +257,7 @@ class ShardingSpec: self.dim_partition_dict = new_dim_partition_dict def sharding_sequence_difference(self, other): - ''' + """ This function is a naive version of difference computation. It just simply accumulates difference every dimension between the pair of sharding sequence. @@ -276,21 +282,22 @@ class ShardingSpec: Return: difference(int): Difference between two ShardingSpec. - ''' + """ assert len(self.sharding_sequence) == len( - other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.' + other.sharding_sequence + ), f"Cannot compare difference for two sharding specs with different length." difference = 0 for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence): difference += orig_dim_spec.difference(other_dim_spec) return difference def get_sharded_shape_per_device(self): - sharded_shape = list(self.entire_shape) for dim, shard_list in self.dim_partition_dict.items(): - mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) - assert sharded_shape[ - dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' + assert ( + sharded_shape[dim] % shard_partitions == 0 + ), f"Cannot shard dimension {dim} into {shard_partitions} partitions." sharded_shape[dim] //= shard_partitions return torch.Size(sharded_shape) diff --git a/colossalai/tensor/utils.py b/colossalai/tensor/utils.py index 6e30f97fef0388ac4d65d7e83aa458257026f3ee..19dde8febf84b7b8d4d9f4768851f90b65957a60 100644 --- a/colossalai/tensor/utils.py +++ b/colossalai/tensor/utils.py @@ -7,7 +7,7 @@ from colossalai.tensor.colo_tensor import ColoTensor def all_gather_simulator(target_pair): - ''' + """ Simulating all-gather operation, analyze the communication cost and simulate the influence of the DimSpec. @@ -19,7 +19,7 @@ def all_gather_simulator(target_pair): Argument: target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, and the second element describes which logical axis will be sharded in that dimension. - ''' + """ _, shard_list = target_pair new_shard_list = shard_list[:-1] @@ -27,7 +27,7 @@ def all_gather_simulator(target_pair): def all_to_all_simulator(f_target_pair, b_target_pair): - ''' + """ Simulating all-to-all operation, analyze the communication cost and simulate the influence of the DimSpec. @@ -47,7 +47,7 @@ def all_to_all_simulator(f_target_pair, b_target_pair): Argument: target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, and the second element describes which logical axis will be sharded in that dimension. - ''' + """ _, f_shard_list = f_target_pair _, b_shard_list = b_target_pair if not len(b_shard_list): @@ -61,7 +61,7 @@ def all_to_all_simulator(f_target_pair, b_target_pair): def shard_simulator(target_pair, legal_sharding_dims): - ''' + """ Simulating shard operation, analyze the communication cost(always ZERO) and simulate the influence of the DimSpec. @@ -77,8 +77,8 @@ def shard_simulator(target_pair, legal_sharding_dims): Argument: target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, - and the second element decribes which logical axis will be sharded in that dimension. - ''' + and the second element describes which logical axis will be sharded in that dimension. + """ _, shard_list = target_pair shard_list_list = [] for dim in legal_sharding_dims: @@ -91,7 +91,7 @@ def shard_simulator(target_pair, legal_sharding_dims): def mix_gather_simulator(f_target_pair, b_target_pair): - ''' + """ Assume index of f and b target pairs are 'f' and 'b' S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0) S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1) @@ -99,7 +99,7 @@ def mix_gather_simulator(f_target_pair, b_target_pair): RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1) S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0) RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0) - ''' + """ if f_target_pair[1] and b_target_pair[1]: leading_dim = b_target_pair[1] > f_target_pair[1] return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)] @@ -118,7 +118,7 @@ def mix_gather_simulator(f_target_pair, b_target_pair): # The function is credited to PyTorch Team def named_params_with_colotensor( module: nn.Module, - prefix: str = '', + prefix: str = "", recurse: bool = True, ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]: r"""Returns an iterator over module parameters (together with the @@ -154,7 +154,7 @@ def named_params_with_colotensor( for name, val in vars(mod).items(): if isinstance(val, ColoTensor) and val not in memo: memo.add(val) - name = mod_prefix + ('.' if mod_prefix else '') + name + name = mod_prefix + ("." if mod_prefix else "") + name yield name, val # find all nn.Parameters @@ -169,15 +169,16 @@ def _convert_tensor(tensor: torch.Tensor) -> ColoTensor: def convert_parameter(module: torch.nn.Module, param_name: str): # Perform some validation first. if not hasattr(module, param_name): - raise ValueError(f'module: {module} does not have parameter with name: {param_name}') + raise ValueError(f"module: {module} does not have parameter with name: {param_name}") tensor = getattr(module, param_name) if not isinstance(tensor, torch.Tensor): raise ValueError( - f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}') + f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}" + ) if not tensor.is_contiguous(): - raise ValueError(f'param: {param_name} is not a contiguous Tensor') + raise ValueError(f"param: {param_name} is not a contiguous Tensor") st = _convert_tensor(tensor) @@ -193,9 +194,9 @@ def convert_parameter(module: torch.nn.Module, param_name: str): def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: - ''' + """ This method is used to convert the negative dim value to positive. - ''' + """ dims_to_convert = [] for dim, mesh_list in dim_partition_dict.items(): if dim < 0: @@ -207,13 +208,13 @@ def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: - ''' + """ This method is used to merge the different key value which points to same physical position. For example: dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position. In this method, above dim_partition_dict will be converted to {1: [0, 1]} - ''' + """ converted_dim_partition_dict = {} for dim, mesh_list in dim_partition_dict.items(): if dim < 0: diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py index c53e0f44c7e0a4d9c925844a84c5e4e775fd6883..c6956e81fbde8d7e95674fa3937499b98ecd7147 100644 --- a/colossalai/testing/__init__.py +++ b/colossalai/testing/__init__.py @@ -1,4 +1,12 @@ -from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal +from .comparison import ( + assert_close, + assert_close_loose, + assert_equal, + assert_equal_in_group, + assert_hf_output_close, + assert_not_equal, + check_state_dict_equal, +) from .pytest_wrapper import run_on_environment_flag from .utils import ( clear_cache_before_run, @@ -11,7 +19,19 @@ from .utils import ( ) __all__ = [ - 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', - 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn', - 'clear_cache_before_run', 'run_on_environment_flag' + "assert_equal", + "assert_not_equal", + "assert_close", + "assert_close_loose", + "assert_equal_in_group", + "parameterize", + "rerun_on_exception", + "rerun_if_address_is_in_use", + "skip_if_not_enough_gpus", + "free_port", + "spawn", + "clear_cache_before_run", + "run_on_environment_flag", + "check_state_dict_equal", + "assert_hf_output_close", ] diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index e00d0da168c72c286f7515056bd8acbccb22695d..816bc0d7b6d78402241cff469b173e1025876a0e 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -1,20 +1,30 @@ +from typing import Any, List, OrderedDict + import torch import torch.distributed as dist from torch import Tensor from torch.distributed import ProcessGroup from torch.testing import assert_close +from torch.utils._pytree import tree_flatten def assert_equal(a: Tensor, b: Tensor): - assert torch.all(a == b), f'expected a and b to be equal but they are not, {a} vs {b}' + assert torch.all(a == b), f"expected a and b to be equal but they are not, {a} vs {b}" def assert_not_equal(a: Tensor, b: Tensor): - assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}' + assert not torch.all(a == b), f"expected a and b to be not equal but they are, {a} vs {b}" def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3): - assert_close(a, b, rtol=rtol, atol=atol) + assert_close( + a, + b, + rtol=rtol, + atol=atol, + msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \ + dtype: {a.dtype} vs {b.dtype}", + ) def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): @@ -27,4 +37,93 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): for i in range(world_size - 1): a = tensor_list[i] b = tensor_list[i + 1] - assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}' + assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}" + + +def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True): + assert len(list(d1.keys())) == len( + list(d2.keys()) + ), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}" + for k, v1 in d1.items(): + assert k in d2 + v2 = d2[k] + if isinstance(v1, dict): + assert isinstance(v2, dict) + check_state_dict_equal(v1, v2, ignore_device) + elif isinstance(v1, list): + assert isinstance(v2, list) + for v1_i, v2_i in zip(v1, v2): + if isinstance(v1_i, torch.Tensor): + assert isinstance(v2_i, torch.Tensor) + if not ignore_device: + v1_i = v1_i.to("cpu") + v2_i = v2_i.to("cpu") + assert_close_loose(v1_i, v2_i) + elif isinstance(v1_i, dict): + assert isinstance(v2_i, dict) + check_state_dict_equal(v1_i, v2_i, ignore_device) + else: + assert v1_i == v2_i, f"{v1_i} not equals to {v2_i}" + elif isinstance(v1, torch.Tensor): + assert isinstance(v2, torch.Tensor) + if not ignore_device: + v1 = v1.to("cpu") + v2 = v2.to("cpu") + assert_close_loose(v1, v2) + else: + assert v1 == v2, f"{v1} not equals to {v2}" + + +def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True): + flat_d1, _ = tree_flatten(d1) + flat_d2, _ = tree_flatten(d2) + assert len(flat_d1) == len(flat_d2) + for v1, v2 in zip(flat_d1, flat_d2): + if isinstance(v1, torch.Tensor): + assert isinstance(v2, torch.Tensor) + if not ignore_device: + v1 = v1.to("cpu") + v2 = v2.to("cpu") + assert_close_loose(v1, v2) + else: + assert v1 == v2, f"{v1} not equals to {v2}" + + +def assert_hf_output_close( + out1: Any, out2: Any, ignore_keys: List[str] = None, track_name: str = "", atol=1e-5, rtol=1e-5 +): + """ + Check if two outputs from huggingface are equal. + + Args: + out1 (Any): the first output + out2 (Any): the second output + ignore_keys (List[str]): the keys to ignore when comparing two dicts + track_name (str): the name of the value compared, used to track the path + """ + if isinstance(out1, dict) and isinstance(out2, dict): + # if two values are dict + # we recursively check the keys + assert set(out1.keys()) == set(out2.keys()) + for k in out1.keys(): + if ignore_keys is not None and k in ignore_keys: + continue + assert_hf_output_close( + out1[k], out2[k], track_name=f"{track_name}.{k}", ignore_keys=ignore_keys, atol=atol, rtol=rtol + ) + elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)): + # if two values are list + # we recursively check the elements + assert len(out1) == len(out2) + for i in range(len(out1)): + assert_hf_output_close( + out1[i], out2[i], track_name=f"{track_name}.{i}", ignore_keys=ignore_keys, atol=atol, rtol=rtol + ) + elif isinstance(out1, Tensor) and isinstance(out2, Tensor): + if out1.shape != out2.shape: + raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}") + assert torch.allclose( + out1, out2, atol=atol, rtol=rtol + ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}" + else: + assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}" diff --git a/colossalai/testing/pytest_wrapper.py b/colossalai/testing/pytest_wrapper.py index a472eb3723ec1e8ce8705aa978355a460179a328..b1e82b469c9632819a57f4e2d095ebd10c1edc4d 100644 --- a/colossalai/testing/pytest_wrapper.py +++ b/colossalai/testing/pytest_wrapper.py @@ -1,10 +1,9 @@ """ This file will not be automatically imported by `colossalai.testing` -as this file has a dependency on `pytest`. Therefore, you need to +as this file has a dependency on `pytest`. Therefore, you need to explicitly import this file `from colossalai.testing.pytest_wrapper import `.from """ -import pytest import os @@ -30,11 +29,18 @@ def run_on_environment_flag(name: str): pytest test_for_something.py """ + try: + import pytest + except ImportError: + raise ImportError( + "This function requires `pytest` to be installed, please do `pip install pytest` and try again." + ) + assert isinstance(name, str) - flag = os.environ.get(name.upper(), '0') + flag = os.environ.get(name.upper(), "0") - reason = f'Environment varialbe {name} is {flag}' - if flag == '1': + reason = f"Environment variable {name} is {flag}" + if flag == "1": return pytest.mark.skipif(False, reason=reason) else: return pytest.mark.skipif(True, reason=reason) diff --git a/colossalai/testing/random.py b/colossalai/testing/random.py index ad6d24a4b94b152f20f404f5b1bb1cae2d74a0b5..4525dff3fe80d0a2445fee8354f511d2ba4a751c 100644 --- a/colossalai/testing/random.py +++ b/colossalai/testing/random.py @@ -11,7 +11,7 @@ def seed_all(seed, cuda_deterministic=False): if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - if cuda_deterministic: # slower, more reproducible + if cuda_deterministic: # slower, more reproducible torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False else: diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index 6583eeb12bf43e8036fef3d583cc2a5e77540441..fdbda9a598bf2bcee48cb06dd10bf81b0a57d306 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -55,7 +55,6 @@ def parameterize(argument: str, values: List[Any]) -> Callable: """ def _wrapper(func): - def _execute_function_by_param(**kwargs): for val in values: arg_map = {argument: val} @@ -120,11 +119,11 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non return False def _wrapper(func): - def _run_until_success(*args, **kwargs): try_count = 0 - assert max_try is None or isinstance(max_try, int), \ - f'Expected max_try to be None or int, but got {type(max_try)}' + assert max_try is None or isinstance( + max_try, int + ), f"Expected max_try to be None or int, but got {type(max_try)}" while max_try is None or try_count < max_try: try: @@ -132,14 +131,14 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non ret = func(*args, **kwargs) return ret except exception_type as e: - error_lines = str(e).split('\n') + error_lines = str(e).split("\n") if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)): - print('Exception is caught, retrying...') + print("Exception is caught, retrying...") # when pattern is not specified, we always skip the exception # when pattern is specified, we only skip when pattern is matched continue else: - print('Maximum number of attempts is reached or pattern is not matched, no more retrying...') + print("Maximum number of attempts is reached or pattern is not matched, no more retrying...") raise e # Override signature @@ -167,10 +166,10 @@ def rerun_if_address_is_in_use(): """ # check version torch_version = version.parse(torch.__version__) - assert torch_version.major == 1 + assert torch_version.major >= 1 # only torch >= 1.8 has ProcessRaisedException - if torch_version.minor >= 8: + if torch_version >= version.parse("1.8.0"): exception = torch.multiprocessing.ProcessRaisedException else: exception = Exception @@ -198,7 +197,6 @@ def skip_if_not_enough_gpus(min_gpus: int): """ def _wrap_func(f): - def _execute_by_gpu_num(*args, **kwargs): num_avail_gpu = torch.cuda.device_count() if num_avail_gpu >= min_gpus: @@ -263,7 +261,6 @@ def clear_cache_before_run(): """ def _wrap_func(f): - def _clear_cache(*args, **kwargs): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() diff --git a/colossalai/trainer/__init__.py b/colossalai/trainer/__init__.py deleted file mode 100644 index 84e53dc4e87ac5b10a93aacc0fce975cc49c66eb..0000000000000000000000000000000000000000 --- a/colossalai/trainer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from ._trainer import Trainer - -__all__ = ['Trainer'] diff --git a/colossalai/trainer/hooks/__init__.py b/colossalai/trainer/hooks/__init__.py deleted file mode 100644 index 4d36093833d99429b15fb35962c930646b5cbf64..0000000000000000000000000000000000000000 --- a/colossalai/trainer/hooks/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from ._base_hook import BaseHook -from ._checkpoint_hook import SaveCheckpointHook -from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook, - TensorboardHook) -from ._lr_scheduler_hook import LRSchedulerHook -from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook - -__all__ = [ - 'BaseHook', 'MetricHook', 'LossHook', 'AccuracyHook', 'LogMetricByEpochHook', 'TensorboardHook', - 'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook', 'ThroughputHook', 'LogMetricByStepHook', - 'SaveCheckpointHook' -] diff --git a/colossalai/trainer/hooks/_base_hook.py b/colossalai/trainer/hooks/_base_hook.py deleted file mode 100644 index cca8e081ec883b8b5f3d88633ec9f57cc9fd6dfc..0000000000000000000000000000000000000000 --- a/colossalai/trainer/hooks/_base_hook.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from abc import ABC - -from torch import Tensor - - -class BaseHook(ABC): - """This class allows users to add desired actions in specific time points - during training or evaluation. - - :param priority: Priority in the printing, hooks with small priority will be printed in front - :type priority: int - """ - - def __init__(self, priority: int) -> None: - self.priority = priority - - def after_hook_is_attached(self, trainer): - """Actions after hooks are attached to trainer. - """ - pass - - def before_train(self, trainer): - """Actions before training. - """ - pass - - def after_train(self, trainer): - """Actions after training. - """ - pass - - def before_train_iter(self, trainer): - """Actions before running a training iteration. - """ - pass - - def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): - """Actions after running a training iteration. - - Args: - trainer (:class:`Trainer`): Trainer which is using this hook. - output (:class:`torch.Tensor`): Output of the model. - label (:class:`torch.Tensor`): Labels of the input data. - loss (:class:`torch.Tensor`): Loss between the output and input data. - """ - pass - - def before_train_epoch(self, trainer): - """Actions before starting a training epoch. - """ - pass - - def after_train_epoch(self, trainer): - """Actions after finishing a training epoch. - """ - pass - - def before_test(self, trainer): - """Actions before evaluation. - """ - pass - - def after_test(self, trainer): - """Actions after evaluation. - """ - pass - - def before_test_epoch(self, trainer): - """Actions before starting a testing epoch. - """ - pass - - def after_test_epoch(self, trainer): - """Actions after finishing a testing epoch. - """ - pass - - def before_test_iter(self, trainer): - """Actions before running a testing iteration. - """ - pass - - def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): - """Actions after running a testing iteration. - - Args: - trainer (:class:`Trainer`): Trainer which is using this hook - output (:class:`torch.Tensor`): Output of the model - label (:class:`torch.Tensor`): Labels of the input data - loss (:class:`torch.Tensor`): Loss between the output and input data - """ - pass - - def init_runner_states(self, trainer, key, val): - """Initializes trainer's state. - - Args: - trainer (:class:`Trainer`): Trainer which is using this hook - key: Key of state to be reset - val: Value of state to be reset - """ - if key not in trainer.states: - trainer.states[key] = val diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/trainer/hooks/_checkpoint_hook.py deleted file mode 100644 index 3bcb32cd2dcbc46a9e57dfec1f72abe9bd4aabda..0000000000000000000000000000000000000000 --- a/colossalai/trainer/hooks/_checkpoint_hook.py +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- -import torch -from colossalai.logging import get_dist_logger - -from colossalai.registry import HOOKS -from colossalai.trainer.hooks import BaseHook -from colossalai.utils.checkpointing import save_checkpoint -from ._lr_scheduler_hook import LRSchedulerHook - - -@HOOKS.register_module -class SaveCheckpointHook(BaseHook): - """Saves the model by interval in training process. - - Args: - interval (int, optional): Number of epochs between saving the checkpoint, defaults to 1. - if save_by_iter is True, this arg refers to the number of iters between saving. - checkpoint_dir (str, optional): File name to save the checkpoint, defaults to None. - model (torch.nn.Module, Optional): The model to save, defaults to None. When not passing, - 'trainer.engine.model' will be used. We encourage you to pass the model in it to avoid some - unexpected bugs, especially when using **DDP**. - save_by_iter (bool, optional): Whether saving the checkpoint by iter, default to False. - priority (int, optional): Priority in the printing, hooks with small priority will be printed in front - defaults to 10. If different hooks share same priority, the order of printing would - depend on the hooks order in the hook list. - """ - - def __init__(self, - interval: int = 1, - checkpoint_dir: str = None, - model: torch.nn.Module = None, - save_by_iter: bool = False, - priority: int = 10): - super().__init__(priority=priority) - self.interval = interval - self.checkpoint_dir = checkpoint_dir - self.model = model - self.save_by_iter = save_by_iter - self.logger = get_dist_logger() - - # get lr scheduler from the LRSchedulerHook before train - self._lr_scheduler = None - - def after_hook_is_attached(self, trainer): - # get lr scheduler if exists - for hook in trainer.hooks: - if isinstance(hook, LRSchedulerHook): - self._lr_scheduler = hook.lr_scheduler - break - self.model = self.model if self.model is not None else trainer.engine.model - - def after_train_iter(self, trainer, output, label, loss): - """Saves the model after a training iter. - """ - # save by interval - if self.save_by_iter and trainer.cur_step % self.interval == 0: - save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, - self._lr_scheduler) - self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}', - ranks=[0]) - else: - pass - - def after_train_epoch(self, trainer): - """Saves the model after a training epoch. - """ - # save by interval - if trainer.cur_epoch % self.interval == 0: - save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, - self._lr_scheduler) - self.logger.info(f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0]) diff --git a/colossalai/trainer/hooks/_commons_.py b/colossalai/trainer/hooks/_commons_.py deleted file mode 100644 index 4923b8cba6c04e482bd4f5163b33767d34e83ebb..0000000000000000000000000000000000000000 --- a/colossalai/trainer/hooks/_commons_.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch - - -def _format_number(val, prec=5): - if isinstance(val, float): - return f'{val:.{prec}g}' - elif torch.is_tensor(val) and torch.is_floating_point(val): - return f'{val.item():.{prec}g}' - return val diff --git a/colossalai/trainer/hooks/_log_hook.py b/colossalai/trainer/hooks/_log_hook.py deleted file mode 100644 index 5b1f33983422b11389927190acbef90165a15e2a..0000000000000000000000000000000000000000 --- a/colossalai/trainer/hooks/_log_hook.py +++ /dev/null @@ -1,301 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import os -import os.path as osp - -from typing import List -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.registry import HOOKS -from colossalai.logging import DistributedLogger -from colossalai.utils import report_memory_usage, is_dp_rank_0, \ - is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer -from ._base_hook import BaseHook -from ._commons_ import _format_number -from colossalai.trainer.hooks._metric_hook import ThroughputMetric - - -class LogByEpochHook(BaseHook): - """Hook to log by epoch. - - Args: - logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. - interval (int, optional): Interval of printing log information, defaults to 1. - priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, - defaults to 1. If different hooks share same priority, the order of printing would - depend on the hooks order in the hook list. - """ - - def __init__(self, logger, interval: int = 1, priority: int = 1): - super().__init__(priority) - self.logger = logger - self._interval = interval - - def _is_epoch_to_log(self, trainer): - return trainer.cur_epoch % self._interval == 0 - - -@HOOKS.register_module -class LogMetricByStepHook(BaseHook): - """Hook to log metric by step. - - Args: - priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, - defaults to 10. If different hooks share same priority, the order of printing would - depend on the hooks order in the hook list. - """ - - def __init__(self, priority: int = 10): - super().__init__(priority) - - def after_train_iter(self, trainer, *args): - trainer.states['step_metrics'] = dict() - for metric_name, metric_calculator in trainer.states['metrics']['train'].items(): - if isinstance(metric_calculator, ThroughputMetric): - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info() - else: - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() - - def after_test_iter(self, trainer, *args): - trainer.states['step_metrics'] = dict() - for metric_name, metric_calculator in trainer.states['metrics']['test'].items(): - if isinstance(metric_calculator, ThroughputMetric): - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info() - else: - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() - - -@HOOKS.register_module -class LogMetricByEpochHook(LogByEpochHook): - """Specialized hook to record the metric to log. - - Args: - logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. - interval (int, optional): Interval of printing log information, defaults to 1. - priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, - defaults to 10. If different hooks share same priority, the order of printing would - depend on the hooks order in the hook list. - """ - - def __init__(self, logger, interval: int = 1, priority: int = 10) -> None: - super().__init__(logger, interval, priority) - self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() - - def _get_str(self, trainer, mode): - msg = [] - for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): - msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}') - msg = ' | '.join(msg) - return msg - - def after_train_epoch(self, trainer): - if self._is_epoch_to_log(trainer): - msg = self._get_str(trainer=trainer, mode='train') - - if self._is_rank_to_log: - self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}') - # f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') - - def after_test_epoch(self, trainer): - if self._is_epoch_to_log(trainer): - msg = self._get_str(trainer=trainer, mode='test') - if self._is_rank_to_log: - self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}') - # f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') - - -@HOOKS.register_module -class TensorboardHook(BaseHook): - """Specialized hook to record the metric to Tensorboard. - - Args: - log_dir (str): Directory of log. - ranks (list): Ranks of processors. - parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode used in trainer, - defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL. - priority (int, optional): Priority in the printing, hooks with small priority will be printed in front, - defaults to 10. If different hooks share same priority, the order of printing would - depend on the hooks order in the hook list. - """ - - def __init__( - self, - log_dir: str, - ranks: List = None, - parallel_mode: ParallelMode = ParallelMode.GLOBAL, - priority: int = 10, - ) -> None: - super().__init__(priority=priority) - from torch.utils.tensorboard import SummaryWriter - - # create log dir - if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: - os.makedirs(log_dir, exist_ok=True) - - # determine the ranks to generate tensorboard logs - self._is_valid_rank_to_log = False - if not gpc.is_initialized(parallel_mode): - self._is_valid_rank_to_log = True - else: - local_rank = gpc.get_local_rank(parallel_mode) - - if ranks is None or local_rank in ranks: - self._is_valid_rank_to_log = True - - # check for - if gpc.is_initialized(ParallelMode.PIPELINE) and \ - not gpc.is_last_rank(ParallelMode.PIPELINE) and self._is_valid_rank_to_log: - raise ValueError("Tensorboard hook can only log on the last rank of pipeline process group") - - if self._is_valid_rank_to_log: - # create workspace on only one rank - if gpc.is_initialized(parallel_mode): - rank = gpc.get_local_rank(parallel_mode) - else: - rank = 0 - - # create workspace - log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}') - os.makedirs(log_dir, exist_ok=True) - - self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_rank_{rank}') - - def _log_by_iter(self, trainer, mode: str): - for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): - if metric_calculator.epoch_only: - continue - val = metric_calculator.get_last_step_value() - - if self._is_valid_rank_to_log: - self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step) - - def _log_by_epoch(self, trainer, mode: str): - for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): - if metric_calculator.epoch_only: - val = metric_calculator.get_accumulated_value() - if self._is_valid_rank_to_log: - self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step) - - def after_test_iter(self, trainer, *args): - self._log_by_iter(trainer, mode='test') - - def after_test_epoch(self, trainer): - self._log_by_epoch(trainer, mode='test') - - def after_train_iter(self, trainer, *args): - self._log_by_iter(trainer, mode='train') - - def after_train_epoch(self, trainer): - self._log_by_epoch(trainer, mode='train') - - -@HOOKS.register_module -class LogTimingByEpochHook(LogByEpochHook): - """Specialized hook to write timing record to log. - - Args: - timer (:class:`colossalai.utils.MultiTimer`): Timer for the hook. - logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. - interval (int, optional): Interval of printing log information, defaults to 1. - priority (int, optional): Priority in the printing, hooks with small priority will be printed in front - defaults to 10. If different hooks share same priority, the order of printing would - depend on the hooks order in the hook list. - log_eval (bool, optional): Whether writes in evaluation, defaults to True. - ignore_num_train_steps (int, optional): Number of training steps to ignore, defaults to 0. - """ - - def __init__(self, - timer: MultiTimer, - logger: DistributedLogger, - interval: int = 1, - priority: int = 10, - log_eval: bool = True, - ignore_num_train_steps: int = 0) -> None: - super().__init__(logger=logger, interval=interval, priority=priority) - self._timer = timer - self._log_eval = log_eval - self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() - - # extra handling to avoid the unstable readings of the first - # few training steps to affect the history mean time - self._ignore_num_train_steps = ignore_num_train_steps - self._is_train_step_history_trimmed = False - - def _get_message(self, mode): - msg = [] - for timer_name, timer in self._timer: - if timer_name.startswith(mode): - last_elapsed_time = timer.get_elapsed_time() - if timer.has_history: - if timer_name == 'Train-step' and not self._is_train_step_history_trimmed: - timer._history = timer._history[self._ignore_num_train_steps:] - self._is_train_step_history_trimmed = True - history_mean = timer.get_history_mean() - history_sum = timer.get_history_sum() - msg.append( - f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s' - ) - else: - msg.append(f'{timer_name}: last = {_format_number(last_elapsed_time)} s') - - msg = ' | '.join(msg) - return msg - - def after_train_epoch(self, trainer): - """Writes log after finishing a training epoch. - """ - if self._is_epoch_to_log(trainer) and self._is_rank_to_log: - msg = self._get_message('Train') - self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}') - - def after_test_epoch(self, trainer): - """Writes log after finishing a testing epoch. - """ - if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: - msg = self._get_message('Test') - self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}') - - -@HOOKS.register_module -class LogMemoryByEpochHook(LogByEpochHook): - """Specialized Hook to write memory usage record to log. - - Args: - logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information. - interval (int, optional): Interval of printing log information, defaults to 1. - priority (int, optional): Priority in the printing, hooks with small priority will be printed in front - defaults to 1. If different hooks share same priority, the order of printing would - depend on the hooks order in the hook list. - log_eval (bool, optional): Whether writes in evaluation, defaults to True. - """ - - def __init__( - self, - logger: DistributedLogger, - interval: int = 1, - priority: int = 10, - log_eval: bool = True, - report_cpu: bool = False, # no reference - ) -> None: - super().__init__(logger=logger, interval=interval, priority=priority) - self._log_eval = log_eval - self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() - - def before_train(self, trainer): - """Resets before training. - """ - if self._is_epoch_to_log(trainer) and self._is_rank_to_log: - report_memory_usage('Before-train', self.logger) - - def after_train_epoch(self, trainer): - """Writes log after finishing a training epoch. - """ - if self._is_epoch_to_log(trainer) and self._is_rank_to_log: - report_memory_usage(f'[Epoch {trainer.cur_epoch} / Train]', self.logger) - - def after_test(self, trainer): - """Reports after testing. - """ - if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: - report_memory_usage(f'[Epoch {trainer.cur_epoch} / Test]', self.logger) diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 7b2e8480c66ce09bb5ad99070a2c22c7cb380697..3ec39b949a2371d6610c7adf29d6048bbe8e8b78 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -1,75 +1,32 @@ -from .activation_checkpoint import checkpoint -from .checkpointing import load_checkpoint, save_checkpoint from .common import ( - clip_grad_norm_fp32, + _cast_float, conditional_context, - copy_tensor_parallel_attributes, - count_zeros_fp32, disposable, ensure_path_exists, + free_storage, is_ddp_ignored, - is_dp_rank_0, - is_model_parallel_parameter, - is_no_pp_or_last_stage, - is_tp_rank_0, - is_using_ddp, - is_using_pp, - is_using_sequence, - multi_tensor_applier, - param_is_not_tensor_parallel_duplicate, - print_rank_0, - switch_virtual_pipeline_parallel_rank, - sync_model_param, -) -from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize -from .data_sampler import DataParallelSampler, get_dataloader -from .memory import ( - colo_device_memory_capacity, - colo_device_memory_used, - colo_get_cpu_memory_capacity, - colo_set_cpu_memory_capacity, - colo_set_process_memory_fraction, - report_memory_usage, + set_seed, ) +from .cuda import empty_cache, get_current_device, set_device, set_to_cuda, synchronize +from .multi_tensor_apply import multi_tensor_applier from .tensor_detector import TensorDetector from .timer import MultiTimer, Timer __all__ = [ - 'checkpoint', - 'print_rank_0', - 'sync_model_param', - 'is_ddp_ignored', - 'is_dp_rank_0', - 'is_tp_rank_0', - 'is_no_pp_or_last_stage', - 'is_using_ddp', - 'is_using_pp', - 'is_using_sequence', - 'conditional_context', - 'is_model_parallel_parameter', - 'clip_grad_norm_fp32', - 'count_zeros_fp32', - 'copy_tensor_parallel_attributes', - 'param_is_not_tensor_parallel_duplicate', - 'get_current_device', - 'synchronize', - 'empty_cache', - 'set_to_cuda', - 'report_memory_usage', - 'colo_device_memory_capacity', - 'colo_device_memory_used', - 'colo_set_process_memory_fraction', - 'Timer', - 'MultiTimer', - 'multi_tensor_applier', - 'DataParallelSampler', - 'get_dataloader', - 'switch_virtual_pipeline_parallel_rank', - 'TensorDetector', - 'load_checkpoint', - 'save_checkpoint', - 'ensure_path_exists', - 'disposable', - 'colo_set_cpu_memory_capacity', - 'colo_get_cpu_memory_capacity', + "conditional_context", + "get_current_device", + "synchronize", + "empty_cache", + "set_to_cuda", + "Timer", + "MultiTimer", + "multi_tensor_applier", + "TensorDetector", + "ensure_path_exists", + "disposable", + "_cast_float", + "free_storage", + "set_seed", + "is_ddp_ignored", + "set_device", ] diff --git a/colossalai/utils/checkpoint/__init__.py b/colossalai/utils/checkpoint/__init__.py deleted file mode 100644 index 1795b4ce36f41d2a09da0c324db4cb1ef21c5e2c..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .module_checkpoint import save_checkpoint, load_checkpoint - -__all__ = ['save_checkpoint', 'load_checkpoint'] diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py deleted file mode 100644 index d390da864cd387445991260b31331a6978248970..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint/module_checkpoint.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch -import torch.distributed as dist -from colossalai.tensor import ColoTensor -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor -from typing import Optional, Dict - - -def save_checkpoint(path: str, - epoch: int, - model: torch.nn.Module, - optimizer: Optional[ColossalaiOptimizer] = None, - lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, - *args, - **kwargs): - """save_checkpoint - save a model, whose parameters are `ColoTensor`s. - Args: - path (str): directory to save the checkpoint files. - epoch (int): the number of epoch - model (torch.nn.Module): a torch module initialized by ColoInitContext - optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None. - lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. - """ - rank = dist.get_rank() - model_state = model.state_dict() - # save the dist context about the tensors in a new dict, while still maintain the original dict. - for k, v in model_state.items(): - if isinstance(v, ColoTensor): - gather_tensor(v) # gather shared tensors to rank0 - # don't recover tensors in rank0, since the dict is only a copy of model - - if rank == 0: - # sanity check - for k, v in model_state.items(): - if isinstance(v, ColoTensor): - assert v.save_ready - assert v.is_replicate() - delattr(v, 'save_ready') - # model saving - save_state = {'epoch': epoch, 'model': model_state} - torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs) - - # delete old dicts - del model_state - # synchronize all the processes - dist.barrier() - - if optimizer is not None: - mapping = dict() - optim_state = optimizer.state_dict() - for k, v in optim_state['state'].items(): - for n, t in v.items(): - if isinstance(t, ColoTensor): - mapping[(k, n)] = t.dist_spec - gather_tensor(t) - - if rank == 0: - save_state = {'epoch': epoch, 'optim': optim_state} - torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs) - # recover colo tensors in rank0 - for k, v in optimizer.state_dict()['state'].items(): - for n, t in v.items(): - if isinstance(t, ColoTensor): - assert hasattr(t, 'save_ready') - t.set_dist_spec(mapping[(k, n)]) - delattr(t, 'save_ready') - - del optim_state - del mapping - dist.barrier() - - -def load_checkpoint(path: str, - epoch: int, - model: torch.nn.Module, - optimizer: Optional[ColossalaiOptimizer] = None, - lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, - torch_load_kwargs: Optional[Dict] = None, - load_state_dict_kwargs: Optional[Dict] = None): - """load_checkpoint - load a model, whose parameters are `ColoTensor`s. - Args: - path (str): directory to save the checkpoint files. - epoch (int): the number of epoch - model (torch.nn.Module): a torch module initialized by ColoInitContext - optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None. - lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. - torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function - load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function - """ - # initialize the default parameters - if not torch_load_kwargs: - torch_load_kwargs = dict() - if not load_state_dict_kwargs: - load_state_dict_kwargs = dict() - - rank = dist.get_rank() - mapping = dict() - for n, p in model.named_parameters(): - if isinstance(p, ColoTensor): - mapping[n] = p.dist_spec - gather_tensor(p) - - if rank == 0: - load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs) - model.load_state_dict(load_state['model'], **load_state_dict_kwargs) - dist.barrier() - - # scatter loaded parameters - for n, p in model.named_parameters(): - if isinstance(p, ColoTensor): - scatter_tensor(p, mapping[n]) - if rank == 0: - assert hasattr(p, 'save_ready') - delattr(p, 'save_ready') - del mapping - - if optimizer is not None: - mapping = dict() - for k, v in optimizer.state_dict()['state'].items(): - for n, t in v.items(): - if isinstance(t, ColoTensor): - mapping[(k, n)] = t.dist_spec - gather_tensor(t) - - if rank == 0: - colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs) - optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs) - dist.barrier() - - for k, v in optimizer.state_dict()['state'].items(): - for n, t in v.items(): - if isinstance(t, ColoTensor): - scatter_tensor(t, mapping[(k, n)]) - - del mapping diff --git a/colossalai/utils/checkpoint/utils.py b/colossalai/utils/checkpoint/utils.py deleted file mode 100644 index 682cd0903d5b3b8028ea0741132f256b384d4cce..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint/utils.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -import torch.distributed as dist -from colossalai.tensor import ColoTensor, ColoTensorSpec -from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern - - -def robust_broadcast(tensor): - with torch.no_grad(): - is_cpu_ten = tensor.device.type == 'cpu' - if is_cpu_ten: - b_data = tensor.cuda() - else: - b_data = tensor - - dist.broadcast(b_data, 0) - - if is_cpu_ten: - tensor.copy_(b_data) - - -def gather_tensor(colo_tensor: ColoTensor) -> None: - """Make colo_tensor replicated when the rank is 0 - """ - if not colo_tensor.is_replicate(): - pg = colo_tensor.get_process_group() - # for the group which contains rank 0 - if pg.dp_local_rank() == 0: - old_dist_spec = colo_tensor.dist_spec - colo_tensor.to_replicate_() - if dist.get_rank() != 0: - colo_tensor.set_dist_spec(old_dist_spec) - - # synchronize all processes for unexpected problems - dist.barrier() - - if dist.get_rank() == 0: - setattr(colo_tensor, 'save_ready', True) # set saving signature - - -def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: - """Reversal operation of `gather_tensor`. - """ - if dist_spec.placement == DistPlacementPattern.REPLICATE: - robust_broadcast(colo_tensor.data) - else: - global_size = colo_tensor.size_global() - - if dist.get_rank() == 0: - entire_data = colo_tensor.data - else: - entire_data = torch.empty(global_size, device=colo_tensor.device) - robust_broadcast(entire_data) - - if dist.get_rank() == 0: - colo_tensor.set_dist_spec(dist_spec) - else: - rep_tensor = ColoTensor( - entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)) - rep_tensor.set_dist_spec(dist_spec) - with torch.no_grad(): - colo_tensor.data.copy_(rep_tensor.data) - # synchronize all processes for unexpected problems - dist.barrier() diff --git a/colossalai/utils/checkpoint_io/__init__.py b/colossalai/utils/checkpoint_io/__init__.py deleted file mode 100644 index fe030866894f7666195f1563f6699ae88aa78a61..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint_io/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .io import load, merge, redist, save -from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta) diff --git a/colossalai/utils/checkpoint_io/backend.py b/colossalai/utils/checkpoint_io/backend.py deleted file mode 100644 index 140192c05f12cf4843df36d43c66723469ef6cad..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint_io/backend.py +++ /dev/null @@ -1,74 +0,0 @@ -import shutil -import tempfile -from abc import ABC, abstractmethod -from typing import Dict, List, Type - -from .reader import CheckpointReader, DiskCheckpointReader -from .writer import CheckpointWriter, DiskCheckpointWriter - -_backends: Dict[str, Type['CheckpointIOBackend']] = {} - - -def register(name: str): - assert name not in _backends, f'"{name}" is registered' - - def wrapper(cls): - _backends[name] = cls - return cls - - return wrapper - - -def get_backend(name: str) -> 'CheckpointIOBackend': - assert name in _backends, f'Unsupported backend "{name}"' - return _backends[name]() - - -class CheckpointIOBackend(ABC): - - def __init__(self) -> None: - super().__init__() - self.temps: List[str] = [] - - @abstractmethod - def get_writer(self, - base_name: str, - overwrite: bool = False, - rank: int = 0, - world_size: int = 1) -> CheckpointWriter: - pass - - @abstractmethod - def get_reader(self, base_name: str) -> CheckpointReader: - pass - - @abstractmethod - def get_temp(self, base_name: str) -> str: - pass - - @abstractmethod - def clean_temp(self) -> None: - pass - - -@register('disk') -class CheckpointDiskIO(CheckpointIOBackend): - - def get_writer(self, - base_name: str, - overwrite: bool = False, - rank: int = 0, - world_size: int = 1) -> CheckpointWriter: - return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size) - - def get_reader(self, base_name: str) -> CheckpointReader: - return DiskCheckpointReader(base_name) - - def get_temp(self, base_name: str) -> str: - temp_dir_name = tempfile.mkdtemp(dir=base_name) - self.temps.append(temp_dir_name) - return temp_dir_name - - def clean_temp(self) -> None: - for temp_dir_name in self.temps: - shutil.rmtree(temp_dir_name) diff --git a/colossalai/utils/checkpoint_io/constant.py b/colossalai/utils/checkpoint_io/constant.py deleted file mode 100644 index 2199484741bf5bb934d5c0583dd55f51f0cdbffe..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint_io/constant.py +++ /dev/null @@ -1,9 +0,0 @@ -import re - -GLOBAL_META_FILE_NAME = 'global_meta.bin' -MODEL_CKPT_FILE_NAME = 'model.bin' -OPTIM_CKPT_FILE_NAME = 'optim.bin' -META_CKPT_FILE_NAME = 'meta.bin' -OTHER_CKPT_FILE_NAME = 'other.bin' - -CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other') diff --git a/colossalai/utils/checkpoint_io/convertor.py b/colossalai/utils/checkpoint_io/convertor.py deleted file mode 100644 index 529ceb86829b511d05911fe6103d5ef53ecf325d..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint_io/convertor.py +++ /dev/null @@ -1,227 +0,0 @@ -from abc import ABC, abstractmethod -from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional - -from torch import Tensor - -from .distributed import merge_param, unmerge_param -from .meta import ParamDistMeta, RedistMeta -from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none) - - -class CheckpointConvertor(ABC): - - @abstractmethod - def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: - pass - - @abstractmethod - def complete(self) -> None: - pass - - -class ModelCheckpointConvertor(CheckpointConvertor): - - def __init__(self, param_count: Dict[str, int]) -> None: - super().__init__() - self.param_count = param_count - self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict) - - @abstractmethod - def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: - pass - - def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: - for rank, state_dict in shard_dict.items(): - for k, tensor in state_dict.items(): - self.buffer[k][rank] = tensor - converted_keys = set() - for k, rank_dict in self.buffer.items(): - if len(rank_dict) == self.param_count[k]: - tensors = [] - dist_metas = [] - for rank, tensor in rank_dict.items(): - tensors.append(tensor) - if dist_meta_list[rank] is not None: - dist_metas.append(dist_meta_list[rank][k]) - self.convert_tensors(k, tensors, dist_metas) - converted_keys.add(k) - for k in converted_keys: - del self.buffer[k] - - def complete(self) -> None: - assert len(self.buffer) == 0 - - -class ModelCheckpointMerger(ModelCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None: - super().__init__(param_count) - self.sharder = ModelCheckpointSharder(max_shard_size) - self.save_fn = save_fn - - def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: - assert len(dist_metas) == len(tensors) - tensor = merge_param(tensors, dist_metas) - shard = self.sharder.append(key, tensor) - run_if_not_none(self.save_fn, shard) - - def complete(self) -> None: - super().complete() - run_if_not_none(self.save_fn, self.sharder.complete()) - - -class ModelCheckpointRedistor(ModelCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], - redist_meta: RedistMeta) -> None: - super().__init__(param_count) - self.save_fns = save_fns - self.redist_meta = redist_meta - nprocs = len(save_fns) - self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)] - self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) - for k, rank_meta in redist_meta.rank_meta.items(): - for rank, rank_info in rank_meta.items(): - self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) - - def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: - if len(dist_metas) == 0: - # already global - tensor = tensors[0] - else: - assert len(dist_metas) == len(tensors) - tensor = merge_param(tensors, dist_metas) - for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])): - for dp_rank, t in enumerate(tensor_list): - for rank in self.rank_map[key][tp_rank][dp_rank]: - shard = self.sharders[rank].append(key, t) - run_if_not_none(self.save_fns[rank], shard) - - def complete(self) -> None: - super().complete() - for rank, save_fn in enumerate(self.save_fns): - run_if_not_none(save_fn, self.sharders[rank].complete()) - - -class OptimizerCheckpointConvertor(CheckpointConvertor): - - def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]], - paired_os: Optional[Dict[int, dict]]) -> None: - super().__init__() - self.param_count = param_count - self.param_to_os = param_to_os - self.paired_os = paired_os - self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict) - self.os_to_param = {v: k for k, v in param_to_os.items()} - - @abstractmethod - def setup(self, param_groups: dict) -> None: - pass - - @abstractmethod - def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: - pass - - def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: - for rank, state_dict in shard_dict.items(): - self.setup(state_dict['param_groups']) - for idx, state in state_dict['state'].items(): - self.buffer[idx][rank] = state - converted_indices = set() - for idx, rank_dict in self.buffer.items(): - if len(rank_dict) == self.param_count[self.os_to_param[idx]]: - states = [] - dist_metas = [] - for rank, state in rank_dict.items(): - states.append(state) - if dist_meta_list[rank] is not None: - dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]]) - self.convert_states(idx, states, dist_metas) - converted_indices.add(idx) - for idx in converted_indices: - del self.buffer[idx] - - def complete(self) -> None: - assert len(self.buffer) == 0 - - -class OptimizerCheckpointMerger(OptimizerCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int], - param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None: - super().__init__(param_count, param_to_os, paired_os) - self.max_shard_size = max_shard_size - self.save_fn = save_fn - self.sharder = None - - def setup(self, param_groups: dict) -> None: - if self.sharder is None: - self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups) - - def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: - assert len(dist_metas) == len(states) - new_state = {} - for state_key, state_tensor in states[0].items(): - if self.paired_os[idx][state_key]: - new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas) - else: - new_state[state_key] = state_tensor - shard = self.sharder.append(idx, new_state) - run_if_not_none(self.save_fn, shard) - - def complete(self) -> None: - super().complete() - run_if_not_none(self.save_fn, self.sharder.complete()) - - -class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], - param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]], - redist_meta: RedistMeta) -> None: - super().__init__(param_count, param_to_os, paired_os) - self.max_shard_size = max_shard_size - self.save_fns = save_fns - self.redist_meta = redist_meta - self.sharders: List[OptimizerCheckpointSharder] = [] - self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) - for k, rank_meta in redist_meta.rank_meta.items(): - for rank, rank_info in rank_meta.items(): - self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) - - def setup(self, param_groups: dict) -> None: - if len(self.sharders) == 0: - nprocs = len(self.save_fns) - for _ in range(nprocs): - self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups)) - - def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: - need_merge: bool = True - if len(dist_metas) == 0: - need_merge = False - else: - assert len(dist_metas) == len(states) - new_states = [{} for _ in range(len(self.save_fns))] - for state_key, state_tensor in states[0].items(): - if self.paired_os[idx][state_key]: - if need_merge: - tensor = merge_param([state[state_key] for state in states], dist_metas) - else: - tensor = state_tensor - for tp_rank, tensor_list in enumerate( - unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])): - for dp_rank, t in enumerate(tensor_list): - for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]: - new_states[rank][state_key] = t - else: - for new_state in new_states: - new_state[state_key] = state_tensor - for rank, new_state in enumerate(new_states): - shard = self.sharders[rank].append(idx, new_state) - run_if_not_none(self.save_fns[rank], shard) - - def complete(self) -> None: - super().complete() - for rank, save_fn in enumerate(self.save_fns): - run_if_not_none(save_fn, self.sharders[rank].complete()) diff --git a/colossalai/utils/checkpoint_io/distributed.py b/colossalai/utils/checkpoint_io/distributed.py deleted file mode 100644 index bf720437c41a9114dc805647e415179a9b1243a9..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint_io/distributed.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -from numpy import prod -from torch import Tensor -from typing import List, Optional, Tuple -from collections import defaultdict -from .meta import ParamDistMeta, ParamRedistMeta - - -def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: - assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) - for dist_meta in dist_metas[1:]: - assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.' - if not dist_metas[0].used_zero: - # tensors are replicate - return tensors[0] - numel = dist_metas[0].zero_numel - orig_shape = dist_metas[0].zero_orig_shape - tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)] - assert numel == sum(t.numel() for t in tensors), 'Expect numel of all params is equal to zero_numel.' - return torch.cat(tensors).reshape(orig_shape) - - -def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: - assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) - for dist_meta in dist_metas[1:]: - assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.' - for t in tensors[1:]: - assert t.shape == tensors[0].shape, 'Expect all params have the same shape.' - if not dist_metas[0].used_tp: - # tensors are replicate - return tensors[0] - total_parts = prod(dist_meta.tp_num_parts) - assert dist_meta.tp_world_size == total_parts, \ - f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_meta.tp_world_size}.' - shard_info = sorted(zip(dist_meta.tp_shard_dims, dist_meta.tp_num_parts), key=lambda t: t[0], reverse=True) - for dim, num_parts in shard_info: - buffer = [] - for start in range(0, len(tensors), num_parts): - buffer.append(torch.cat(tensors[start:start + num_parts], dim)) - tensors = buffer - assert len(tensors) == 1 - return tensors[0] - - -def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None: - assert len(dist_metas) > 0 - # check world size - for dist_meta in dist_metas[1:]: - assert dist_meta.dp_world_size == dist_metas[ - 0].dp_world_size, 'Expect all dist meta have the same dp_world_size' - assert dist_meta.tp_world_size == dist_metas[ - 0].tp_world_size, 'Expect all dist meta have the same tp_world_size' - - -def deduplicate_params(tensors: List[Tensor], - dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]: - unique_dist_meta = [] - unique_idx = [] - for i, dist_meta in enumerate(dist_metas): - if dist_meta not in unique_dist_meta: - unique_dist_meta.append(dist_meta) - unique_idx.append(i) - return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx] - - -def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: - assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) - # validate parallel info - validate_parallel_info(dist_metas) - tensors, dist_metas = deduplicate_params(tensors, dist_metas) - unflattened_tensors = [] - # group zero params by tp rank - tensor_dict = defaultdict(list) - dist_meta_dict = defaultdict(list) - for t, dist_meta in zip(tensors, dist_metas): - tensor_dict[dist_meta.tp_rank].append(t) - dist_meta_dict[dist_meta.tp_rank].append(dist_meta) - assert len(tensor_dict - ) == dist_metas[0].tp_world_size, f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}' - for tp_rank in tensor_dict.keys(): - unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank])) - return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()]) - - -def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: - if not redist_meta.used_tp: - assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.' - return [tensor] - total_parts = prod(redist_meta.tp_num_parts) - assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.' - shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0]) - tensors = [tensor] - for dim, num_parts in shard_info: - buffer = [] - for t in tensors: - assert t.size(dim) % num_parts == 0, \ - f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.' - chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)] - buffer.extend(chunks) - tensors = buffer - assert len(tensors) == redist_meta.tp_world_size - return tensors - - -def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: - if not redist_meta.used_zero: - return [tensor] * redist_meta.dp_world_size - tensors: List[Optional[Tensor]] = [ - torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank) - ] - offsets = redist_meta.zero_offsets + [tensor.numel()] - for i, offset in enumerate(offsets[:-1]): - end = offsets[i + 1] - tensors.append(tensor.view(-1)[offset:end]) - if len(tensors) < redist_meta.dp_world_size: - tensors.extend([ - torch.empty(0, dtype=tensor.dtype, device=tensor.device) - for _ in range(redist_meta.dp_world_size - len(tensors)) - ]) - assert len(tensors) == redist_meta.dp_world_size - return tensors - - -def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]: - tensors = split_tp_param(tensor, redist_meta) - tensors = [flatten_zero_param(t, redist_meta) for t in tensors] - return tensors diff --git a/colossalai/utils/checkpoint_io/io.py b/colossalai/utils/checkpoint_io/io.py deleted file mode 100644 index f00212cdf85986ff7b3526a2fd4d56210459dda2..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint_io/io.py +++ /dev/null @@ -1,170 +0,0 @@ -import warnings -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple - -import torch.distributed as dist -from torch.nn import Module -from torch.optim import Optimizer - -from .backend import get_backend -from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger, - OptimizerCheckpointRedistor) -from .meta import ParamDistMeta, RedistMeta -from .utils import build_checkpoints, optimizer_load_state_dict - - -def save(path: str, - model: Module, - optimizer: Optional[Optimizer] = None, - param_to_os: Optional[Dict[str, int]] = None, - dist_meta: Optional[Dict[str, ParamDistMeta]] = None, - max_shard_size_gb: float = 0.0, - overwrite: bool = False, - backend: str = 'disk', - **kwargs: Any) -> None: - io_backend = get_backend(backend) - if dist.is_initialized(): - rank = dist.get_rank() - world_size = dist.get_world_size() - else: - rank = 0 - world_size = 1 - if world_size == 1: - # global doesn't need dist_meta - dist_meta = None - else: - assert dist_meta is not None - max_shard_size = int(max_shard_size_gb * 1024**3) - model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer, - param_to_os, dist_meta) - writer = io_backend.get_writer(path, overwrite, rank, world_size) - writer.save_others(kwargs) - for model_checkpoint in model_checkpoints: - writer.save_model(model_checkpoint) - for optimizer_checkpoint in optimizer_checkpoints: - writer.save_optimizer(optimizer_checkpoint) - writer.save_meta(meta_checkpoint) - - -def merge(path: str, - output_path: str, - max_shard_size_gb: float = 0.0, - overwrite: bool = False, - backend: str = 'disk') -> bool: - io_backend = get_backend(backend) - if dist.is_initialized() and dist.get_rank() != 0: - return False - reader = io_backend.get_reader(path) - if len(reader.meta_list) == 1: - # already global - warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.') - return False - dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() - writer = io_backend.get_writer(output_path, overwrite=overwrite) - writer.save_others(reader.load_others()) - max_shard_size = int(max_shard_size_gb * 1024**3) - _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(), - dist_meta_list) - _convert_shards( - OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os), - reader.load_optimizers(), dist_meta_list) - meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())} - if param_to_os is not None: - meta_checkpoint['param_to_os'] = param_to_os - meta_checkpoint['paired_os'] = paired_os - writer.save_meta(meta_checkpoint) - return True - - -def redist(path: str, - output_path: str, - redist_meta: RedistMeta, - dist_metas: List[Dict[str, ParamDistMeta]], - max_shard_size_gb: float = 0.0, - overwrite: bool = False, - backend: str = 'disk') -> bool: - io_backend = get_backend(backend) - if dist.is_initialized() and dist.get_rank() != 0: - return False - nprocs = len(dist_metas) - reader = io_backend.get_reader(path) - dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() - do_redist: bool = False - if len(dist_meta_list) == nprocs: - for a, b in zip(dist_metas, dist_meta_list): - if a != b: - do_redist = True - break - else: - do_redist = True - if not do_redist: - warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.') - return False - - writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)] - writers[0].save_others(reader.load_others()) - max_shard_size = int(max_shard_size_gb * 1024**3) - _convert_shards( - ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta), - reader.load_models(), dist_meta_list) - _convert_shards( - OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count, - param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list) - for writer, dist_meta in zip(writers, dist_metas): - meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())} - if param_to_os is not None: - meta_checkpoint['param_to_os'] = param_to_os - meta_checkpoint['paired_os'] = paired_os - writer.save_meta(meta_checkpoint) - return True - - -def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None], - dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: - for shard_dict in shard_generator: - convertor.append(shard_dict, dist_meta_list) - convertor.complete() - - -def load(path: str, - model: Module, - optimizer: Optional[Optimizer] = None, - redist_meta: Optional[RedistMeta] = None, - dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None, - max_shard_size_gb: float = 0.0, - backend: str = 'disk') -> dict: - is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1 - rank: int = dist.get_rank() if dist.is_initialized() else 0 - is_main_process: bool = rank == 0 - # validate args - if redist_meta is None or dist_metas is None: - assert is_global - io_backend = get_backend(backend) - read_path: str = path - if is_main_process: - # pre-process checkpoints - temp_path = io_backend.get_temp(path) - if is_global: - wrote = merge(path, temp_path, max_shard_size_gb, backend=backend) - else: - wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend) - if wrote: - read_path = temp_path - if not is_global: - bcast_list = [read_path] if is_main_process else [None] - dist.broadcast_object_list(bcast_list) - read_path = bcast_list[0] - reader = io_backend.get_reader(read_path) - # load model - for shard in reader.load_model(rank): - model.load_state_dict(shard, strict=False) - if optimizer is not None: - for shard in reader.load_optimizer(rank): - # optimizer.load_state_dict(shard) - optimizer_load_state_dict(optimizer, shard) - others_dict = reader.load_others() - if not is_global: - dist.barrier() - # clean up temp - if is_main_process: - io_backend.clean_temp() - return others_dict diff --git a/colossalai/utils/checkpoint_io/meta.py b/colossalai/utils/checkpoint_io/meta.py deleted file mode 100644 index 994f08b4b5e44be753aa838a0c03ea693a78cc27..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint_io/meta.py +++ /dev/null @@ -1,81 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Set, Dict - - -@dataclass -class ParamDistMeta: - # parallel info - dp_rank: int - dp_world_size: int - tp_rank: int - tp_world_size: int - # tp info - tp_shard_dims: Optional[List[int]] = None - tp_num_parts: Optional[List[int]] = None - # zero info - zero_numel: Optional[int] = None - zero_orig_shape: Optional[List[int]] = None - - @property - def used_tp(self) -> bool: - return self.tp_shard_dims is not None and self.tp_num_parts is not None - - @property - def used_zero(self) -> bool: - return self.zero_numel is not None and self.zero_orig_shape is not None - - @property - def parallel_meta(self) -> tuple: - return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size - - @property - def tp_meta(self) -> tuple: - return self.tp_shard_dims, self.tp_num_parts - - @property - def zero_meta(self) -> tuple: - return self.zero_numel, self.zero_orig_shape - - @staticmethod - def from_dict(d: dict) -> 'ParamDistMeta': - return ParamDistMeta(**d) - - -@dataclass -class ParamRedistMeta: - # parallel info - dp_world_size: int - tp_world_size: int - # tp info - tp_shard_dims: Optional[List[int]] = None - tp_num_parts: Optional[List[int]] = None - # zero info - zero_start_dp_rank: Optional[int] = None - zero_offsets: Optional[List[int]] = None - - @property - def used_tp(self) -> bool: - return self.tp_shard_dims is not None and self.tp_num_parts is not None - - @property - def used_zero(self) -> bool: - return self.zero_start_dp_rank is not None and self.zero_offsets is not None - - -@dataclass -class RankRedistMeta: - dp_rank: int - tp_rank: int - pp_rank: int - - -@dataclass -class PipelineRedistMeta: - params: Set[str] - - -@dataclass -class RedistMeta: - rank_meta: Dict[str, Dict[int, RankRedistMeta]] - pipeline_meta: List[PipelineRedistMeta] - param_meta: Dict[str, ParamRedistMeta] diff --git a/colossalai/utils/checkpoint_io/reader.py b/colossalai/utils/checkpoint_io/reader.py deleted file mode 100644 index 3158c6481263a7a4be4b15425728d2505b9661e2..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint_io/reader.py +++ /dev/null @@ -1,131 +0,0 @@ -import os -from abc import ABC, abstractmethod -from collections import Counter -from typing import Dict, Generator, List, Optional, Tuple - -import torch - -from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME -from .meta import ParamDistMeta -from .utils import is_duplicated_list - - -class CheckpointReader(ABC): - - def __init__(self, base_name: str) -> None: - super().__init__() - self.base_name = base_name - self.meta_list = [] - - @abstractmethod - def read(self, name: str) -> dict: - pass - - @abstractmethod - def load_meta( - self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: - pass - - @abstractmethod - def load_model(self, rank: int) -> Generator[dict, None, None]: - pass - - @abstractmethod - def load_models(self) -> Generator[Dict[int, dict], None, None]: - pass - - @abstractmethod - def load_optimizer(self, rank: int) -> Generator[dict, None, None]: - pass - - @abstractmethod - def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: - pass - - @abstractmethod - def load_others(self) -> dict: - pass - - -class DiskCheckpointReader(CheckpointReader): - - def __init__(self, base_name: str) -> None: - super().__init__(base_name) - assert os.path.isdir(base_name), f'"{base_name}" is not a directory' - global_meta = self.read(GLOBAL_META_FILE_NAME) - for meta_file_name in global_meta['meta']: - meta = self.read(meta_file_name) - if meta.get('dist_meta', None) is None: - # only global checkpoint can have empty dist_meta - assert len(global_meta['meta']) == 1 - self.meta_list.append(meta) - - def read(self, name: str) -> dict: - return torch.load(os.path.join(self.base_name, name)) - - def load_meta( - self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: - meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', - None), meta.get('paired_os', None)) - for meta in self.meta_list] - dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos) - # reduce param_count - param_count = Counter(p for params in params_list for p in params) - # validate param_to_os - assert is_duplicated_list(param_to_os_list) - assert is_duplicated_list(paired_os_list) - return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0] - - def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]: - meta = self.meta_list[rank] - checkpoint_names = meta.get(shard_type, []) - for name in checkpoint_names: - yield self.read(name) - - def load_model(self, rank: int) -> Generator[dict, None, None]: - return self._load_shard('model', rank) - - def load_models(self) -> Generator[Dict[int, dict], None, None]: - indices = [0] * len(self.meta_list) - while True: - shards = {} - for i, meta in enumerate(self.meta_list): - model_checkpoint_names = meta.get('model', []) - if indices[i] < len(model_checkpoint_names): - shards[i] = self.read(model_checkpoint_names[indices[i]]) - indices[i] += 1 - if len(shards) > 0: - yield shards - else: - break - - def load_optimizer(self, rank: int) -> Generator[dict, None, None]: - param_groups = None - for shard in self._load_shard('optimizer', rank): - if param_groups is None: - param_groups = shard['param_groups'] - else: - shard['param_groups'] = param_groups - yield shard - - def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: - indices = [0] * len(self.meta_list) - param_groups = [] - while True: - shards = {} - for i, meta in enumerate(self.meta_list): - optimizer_checkpoint_names = meta.get('optimizer', []) - if indices[i] < len(optimizer_checkpoint_names): - shards[i] = self.read(optimizer_checkpoint_names[indices[i]]) - if indices[i] == 0: - param_groups.append(shards[i]['param_groups']) - else: - shards[i]['param_groups'] = param_groups[i] - indices[i] += 1 - if len(shards) > 0: - yield shards - else: - break - - def load_others(self) -> dict: - return self.read(OTHER_CKPT_FILE_NAME) diff --git a/colossalai/utils/checkpoint_io/utils.py b/colossalai/utils/checkpoint_io/utils.py deleted file mode 100644 index 135385f5737947400b8f6ae860f51001ca9742fd..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint_io/utils.py +++ /dev/null @@ -1,223 +0,0 @@ -import warnings -from copy import deepcopy -from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Tuple - -from torch import Tensor -from torch.nn import Module -from torch.nn.parameter import Parameter -from torch.optim import Optimizer - -from .meta import ParamDistMeta - - -def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any: - if arg is not None: - return fn(arg) - - -def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]: - # ensure all params in optimizer are in model state dict - params_set = set(id(p) for p in model.parameters()) - for group in optimizer.param_groups: - for p in group['params']: - assert id(p) in params_set - param_mappings = {} - start_index = 0 - - def get_group_mapping(group): - nonlocal start_index - param_mappings.update( - {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings}) - start_index += len(group['params']) - - for g in optimizer.param_groups: - get_group_mapping(g) - return {k: param_mappings[id(p)] for k, p in model.named_parameters()} - - -def compute_optimizer_state_size(state: Dict[str, Any]) -> int: - size = 0 - for v in state.values(): - if isinstance(v, Tensor): - size += v.numel() * v.element_size() - return size - - -class ModelCheckpointSharder: - - def __init__(self, max_shard_size: int) -> None: - self.max_shard_size = max_shard_size - self.buffer: Dict[str, Tensor] = {} - self.buffer_size: int = 0 - - def append(self, key: str, tensor: Tensor) -> Optional[dict]: - retval = None - if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: - retval = self.buffer - self.buffer = {} - self.buffer_size = 0 - self.buffer[key] = tensor - self.buffer_size += tensor.numel() * tensor.element_size() - return retval - - def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]: - shards = [] - for key, tensor in state_dict.items(): - shard = self.append(key, tensor) - run_if_not_none(shards.append, shard) - return shards - - def complete(self) -> Optional[dict]: - return self.buffer if len(self.buffer) > 0 else None - - -class OptimizerCheckpointSharder: - - def __init__(self, max_shard_size: int, param_groups: dict) -> None: - self.max_shard_size = max_shard_size - self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups} - self.buffer_size: int = 0 - self.returned_first: bool = False - - def append(self, key: int, state: dict) -> Optional[dict]: - retval = None - if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: - retval = self.buffer - self.buffer = {'state': {}} - self.buffer_size = 0 - self.buffer['state'][key] = state - self.buffer_size += compute_optimizer_state_size(state) - return retval - - def extend(self, state_dict: Dict[str, dict]) -> List[dict]: - shards = [] - for key, state in state_dict['state'].items(): - shard = self.append(key, state) - run_if_not_none(shards.append, shard) - return shards - - def complete(self) -> Optional[dict]: - return self.buffer if len(self.buffer['state']) > 0 else None - - -def shard_checkpoint(max_shard_size: int, - model_state_dict: Dict[str, Tensor], - optimizer_state_dict: Optional[dict] = None, - param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]: - has_optimizer: bool = False - if optimizer_state_dict is not None: - assert param_to_os is not None - os_to_param = {v: k for k, v in param_to_os.items()} - for os_key in optimizer_state_dict['state'].keys(): - assert os_key in os_to_param - assert os_to_param[os_key] in model_state_dict - has_optimizer = True - model_sharder = ModelCheckpointSharder(max_shard_size) - model_shards = model_sharder.extend(model_state_dict) - run_if_not_none(model_shards.append, model_sharder.complete()) - if not has_optimizer: - return model_shards, [] - optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups']) - optimizer_shards = optimizer_sharder.extend(optimizer_state_dict) - run_if_not_none(optimizer_shards.append, optimizer_sharder.complete()) - return model_shards, optimizer_shards - - -def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict: - os_to_param = {v: k for k, v in param_to_os.items()} - paired_os = {} - for idx, state in optimizer_state_dict['state'].items(): - paired_os[idx] = {} - p = model_state_dict[os_to_param[idx]] - for k, v in state.items(): - if isinstance(v, Tensor) and v.shape == p.shape: - paired_os[idx][k] = True - else: - paired_os[idx][k] = False - return paired_os - - -def build_checkpoints(max_size: int, - model: Module, - optimizer: Optional[Optimizer] = None, - param_to_os: Optional[Dict[str, int]] = None, - dist_meta: Optional[Dict[str, ParamDistMeta]] = None, - eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]: - save_global = dist_meta is None - model_state_dict = model.state_dict() - optimizer_state_dict = optimizer.state_dict() if optimizer else None - meta = {'dist_meta': dist_meta} - if optimizer: - param_to_os = param_to_os or get_param_to_os(model, optimizer) - paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os) - meta['param_to_os'] = param_to_os - meta['paired_os'] = paired_os - if not save_global and eliminate_replica: - # filter dp replicated params - model_state_dict = { - k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 - } - if optimizer: - optimizer_state_dict['state'] = { - param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]] - for k in model_state_dict.keys() - if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 - } - meta['params'] = list(model_state_dict.keys()) - if len(model_state_dict) == 0: - warnings.warn('model state dict is empty, checkpoint is not saved') - return [], [], meta - model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict, - param_to_os) - return model_checkpoints, optimizer_checkpoints, meta - - -def is_duplicated_list(list_: List[Any]) -> bool: - if len(list_) == 0: - return True - elem = list_[0] - for x in list_[1:]: - if x != elem: - return False - return True - - -def copy_optimizer_state(src_state: dict, dest_state: dict) -> None: - for k, v in src_state.items(): - if k in dest_state: - old_v = dest_state[k] - if isinstance(old_v, Tensor): - old_v.copy_(v) - else: - dest_state[k] = v - - -def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None: - assert optimizer.state_dict()['param_groups'] == state_dict['param_groups'] - state_dict = deepcopy(state_dict) - groups = optimizer.param_groups - saved_groups = state_dict['param_groups'] - idx_to_p: Dict[str, Parameter] = { - old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups - )), chain.from_iterable((g['params'] for g in groups))) - } - missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys())) - unexpected_keys = [] - error_msgs = [] - for idx, state in state_dict['state'].items(): - if idx in idx_to_p: - old_state = optimizer.state[idx_to_p[idx]] - copy_optimizer_state(state, old_state) - else: - unexpected_keys.append(idx) - if strict: - if len(unexpected_keys) > 0: - error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys))) - if len(missing_keys) > 0: - error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__, - "\n\t".join(error_msgs))) diff --git a/colossalai/utils/checkpoint_io/writer.py b/colossalai/utils/checkpoint_io/writer.py deleted file mode 100644 index 4552accde470d9e0b54b9d658fdcfb6a33e4bc78..0000000000000000000000000000000000000000 --- a/colossalai/utils/checkpoint_io/writer.py +++ /dev/null @@ -1,98 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional -from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME -import torch -import os - - -class CheckpointWriter(ABC): - - def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: - super().__init__() - self.base_name = base_name - self.overwrite = overwrite - self.rank = rank - self.world_size = world_size - self.is_distributed = world_size > 1 - self.is_main_process = rank == 0 - - @abstractmethod - def write(self, name: str, state_dict: dict) -> None: - pass - - @abstractmethod - def save_model(self, model_checkpoint: dict) -> None: - pass - - @abstractmethod - def save_optimizer(self, optimizer_checkpoint: dict) -> None: - pass - - @abstractmethod - def save_meta(self, meta_checkpoint: dict) -> None: - pass - - @abstractmethod - def save_others(self, kwargs: dict) -> None: - pass - - -class DiskCheckpointWriter(CheckpointWriter): - - def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: - super().__init__(base_name, overwrite, rank, world_size) - if not os.path.exists(base_name): - os.makedirs(base_name) - assert os.path.isdir(base_name), f'"{base_name}" is not a directory' - self.model_checkpoint_names = [] - self.optimizer_checkpoint_names = [] - self.is_meta_saved: bool = False - self._save_global_meta() - - def write(self, name: str, state_dict: dict) -> None: - path = os.path.join(self.base_name, name) - if os.path.exists(path) and not self.overwrite: - raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)') - torch.save(state_dict, path) - - def _save_global_meta(self) -> None: - if self.is_main_process: - global_meta = {'meta': []} - if self.is_distributed: - for i in range(self.world_size): - global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin')) - else: - global_meta['meta'].append(META_CKPT_FILE_NAME) - self.write(GLOBAL_META_FILE_NAME, global_meta) - - def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str: - checkpoint_name = base_name - if self.is_distributed: - checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin') - if shard_idx is not None: - checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin') - return checkpoint_name - - def save_model(self, model_checkpoint: dict) -> None: - assert not self.is_meta_saved, 'Cannot save model after saving meta' - name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names)) - self.write(name, model_checkpoint) - self.model_checkpoint_names.append(name) - - def save_optimizer(self, optimizer_checkpoint: dict) -> None: - assert not self.is_meta_saved, 'Cannot save optimizer after saving meta' - name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names)) - self.write(name, optimizer_checkpoint) - self.optimizer_checkpoint_names.append(name) - - def save_meta(self, meta_checkpoint: dict) -> None: - if len(self.model_checkpoint_names) > 0: - meta_checkpoint['model'] = self.model_checkpoint_names - if len(self.optimizer_checkpoint_names) > 0: - meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names - self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint) - self.is_meta_saved = True - - def save_others(self, kwargs: dict) -> None: - if self.is_main_process: - self.write(OTHER_CKPT_FILE_NAME, kwargs) diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 95b3b8014af1ac2ad1f8c54bec701c45e5c8f900..c43caaff4806bde1113cea831b916e654e64d14c 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -3,44 +3,12 @@ import functools import os import random -import socket -from collections import defaultdict from contextlib import contextmanager from pathlib import Path -from typing import Callable, Dict, List, Optional, Union +from typing import Callable +import numpy as np import torch -import torch.distributed as dist -from torch import inf -from torch.nn.parameter import Parameter - -from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.tensor import ColoParameter, ProcessGroup - -from .multi_tensor_apply import multi_tensor_applier - -try: - from colossalai._C import fused_optim -except: - fused_optim = None - - -def print_rank_0(msg: str, logger=None): - """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu. - - Args: - msg (str): A string message to output. - logger (:class:`colossalai.logging.DistributedLogger`, optional): - The logger to record the message, defaults to None. - """ - if gpc.get_global_rank() == 0: - if logger is None: - print(msg, flush=True) - else: - logger.info(msg) def ensure_path_exists(filename: str): @@ -50,47 +18,6 @@ def ensure_path_exists(filename: str): Path(dirpath).mkdir(parents=True, exist_ok=True) -def sync_model_param(model, parallel_mode): - r"""Make sure data parameters are consistent during Data Parallel Mode. - - Args: - model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. - parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel mode to be checked. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: - for param in model.parameters(): - ranks = gpc.get_ranks_in_group(parallel_mode) - dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) - - -def is_dp_rank_0(): - return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA) - - -def is_tp_rank_0(): - return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR) - - -def is_no_pp_or_last_stage(): - return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) - - -def is_using_ddp(): - return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1 - - -def is_using_pp(): - return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1 - - -def is_using_sequence(): - return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1 - - @contextmanager def conditional_context(context_manager, enable=True): if enable: @@ -100,363 +27,8 @@ def conditional_context(context_manager, enable=True): yield -class model_branch_context(object): - - def __enter__(self): - self.env_status = env.save() - - def __exit__(self, *exc_info): - env.load(**self.env_status) - - -def is_model_parallel_parameter(p): - return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) - - def is_ddp_ignored(p): - return getattr(p, '_ddp_to_ignore', False) - - -def _calc_l2_norm(grads): - # we should not - global fused_optim - - if fused_optim is None: - from colossalai.kernel.op_builder import FusedOptimBuilder - fused_optim = FusedOptimBuilder().load() - - norm = 0.0 - if len(grads) > 0: - dummy_overflow_buf = torch.cuda.IntTensor([0]) - norm, _ = multi_tensor_applier( - fused_optim.multi_tensor_l2norm, - dummy_overflow_buf, - [grads], - False # no per-parameter norm - ) - return norm - - -def _calc_lp(grads, norm_type): - norm = 0.0 - for grad in grads: - grad_norm = torch.norm(grad, norm_type) - norm += grad_norm**norm_type - return norm - - -def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: - if torch.is_tensor(norm) and norm.device.type != 'cuda': - norm = norm.to(torch.cuda.current_device()) - return norm - - -def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor: - if isinstance(norm, float): - norm = torch.Tensor([norm]) - if move_to_cuda: - norm = norm.to(torch.cuda.current_device()) - return norm - - -# ======== Gradient Clipping ========= - - -def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float: - if len(params) == 0: - return 0.0 - grads = [p.grad for p in params] - use_cuda_kernel = grads[0].device.type == 'cuda' - if norm_type == inf: - local_lp = max([g.abs().max() for g in grads]) - elif norm_type == 2.0 and use_cuda_kernel: - local_lp = _calc_l2_norm(grads)**norm_type - else: - local_lp = _calc_lp(grads, norm_type) - if isinstance(local_lp, torch.Tensor): - return local_lp.item() - return local_lp - - -def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float: - if len(params) == 0: - return 0.0 - buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list) - for p in params: - if p.is_replicate(): - buckets[None].append(p) - else: - buckets[p.get_process_group().tp_process_group()].append(p) - total_lp = 0.0 - for group, bucket in buckets.items(): - local_lp = _compute_local_lp(bucket, norm_type) - if group is not None: - local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device()) - if norm_type == inf: - dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group) - else: - dist.all_reduce(local_lp_tensor, group=group) - local_lp = local_lp_tensor.item() - if norm_type == inf: - total_lp = max(total_lp, local_lp) - else: - total_lp += local_lp - return total_lp - - -def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float: - if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: - total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device()) - if norm_type == inf: - dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE)) - else: - dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE)) - total_lp = total_lp_tensor.item() - return total_lp - - -def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float: - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - grad_dtype = None - cpu_grad_params: List[ColoParameter] = [] - cuda_grad_params: List[ColoParameter] = [] - for p in parameters: - if p.grad is None: - continue - assert isinstance(p, ColoParameter) - if grad_dtype is None: - grad_dtype = p.grad.dtype - assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}' - if p.grad.device.type == 'cuda': - cuda_grad_params.append(p) - else: - cpu_grad_params.append(p) - norm_type = float(norm_type) - cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type) - cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type) - if norm_type == inf: - total_lp = max(cpu_lp, cuda_lp) - else: - total_lp = cpu_lp + cuda_lp - return _compute_pp_grad_lp(total_lp, norm_type) - - -def compute_grad_norm(parameters, norm_type: float = 2.0) -> float: - norm_type = float(norm_type) - total_norm = _compute_grad_lp(parameters, norm_type) - if norm_type != inf: - total_norm = total_norm**(1 / norm_type) - return total_norm - - -def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None: - clip_coef = max_norm / (total_norm + 1e-6) - if clip_coef < 1.0: - cuda_grads: List[torch.Tensor] = [] - cpu_grads: List[torch.Tensor] = [] - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - for p in parameters: - if p.grad is None: - continue - if p.grad.device.type == 'cuda': - cuda_grads.append(p.grad.detach()) - else: - cpu_grads.append(p.grad.detach()) - if len(cuda_grads) > 0: - dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], - clip_coef) - for g in cpu_grads: - g.mul_(clip_coef) - - -def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float: - total_norm = compute_grad_norm(parameters, norm_type) - _clip_grad_norm(parameters, max_norm, total_norm) - return total_norm - - -def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): - """Clips gradient norm of an iterable of parameters whose gradients are in fp32. - - This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and - added functionality to handle model parallel parameters. - - Note: - the gradients are modified in place. - - Args: - parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`): - An iterable of Tensors or a single Tensor that will have gradients normalized. - max_norm (Union[float, int]): Max norm of the gradients. - norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm. - - Returns: - float: Total norm of the parameters. - """ - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism - params: List[Parameter] = [] - has_zero_shared_param: bool = False - for param in parameters: - if param.grad is not None: - # Make sure the grads are in fp32 - assert param.grad.dtype == torch.float, \ - f'expected gradient to be dtype torch.float, but got {param.grad.type()}' - if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded: - has_zero_shared_param = True - params.append(param) - - if len(params) == 0: - enable_cuda_kernels = False - else: - enable_cuda_kernels = params[0].grad.device.type == 'cuda' - # Norm parameters. - max_norm = float(max_norm) - norm_type = float(norm_type) - - # Parameters can be on CPU or CUDA - # If parameters are on CPU, disable CUDA kernerls - - # Calculate norm. - if norm_type == inf: - total_norm = max(p.grad.data.abs().max() for p in params) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - # Take max across all model-parallel GPUs. - if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: - dist.all_reduce(total_norm_cuda, - op=dist.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.MODEL), - async_op=False) - if has_zero_shared_param: - dist.all_reduce(total_norm_cuda, - op=dist.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.DATA), - async_op=False) - total_norm = total_norm_cuda[0].item() - else: - tensor_parallel_grads = [] - no_tensor_parallel_grads = [] - zero_sharded_grads = [] - for p in params: - if is_model_parallel_parameter(p): - reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type) - tensor_parallel_grads.append(p.grad.data / reductor) - elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded: - zero_sharded_grads.append(p.grad.data) - else: - no_tensor_parallel_grads.append(p.grad.data) - - if norm_type == 2.0 and enable_cuda_kernels: - tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type - no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type - zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type - else: - tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) - no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type) - zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type) - # If norm is type of float, then we convert them into torch.Tensor. - tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels) - no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels) - zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels) - # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors - if not enable_cuda_kernels: - tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm) - no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm) - zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm) - - # Sum across all model-parallel GPUs. - if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: - dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) - # Sum across all zero sharded GPUs - if len(zero_sharded_grads) > 0: - dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA)) - no_tensor_parallel_norm += zero_sharded_norm - total_norm = tensor_parallel_norm + no_tensor_parallel_norm - if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: - dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE)) - total_norm = total_norm**(1.0 / norm_type) - if torch.is_tensor(total_norm): - total_norm = total_norm.item() - - # Scale. - clip_coeff = max_norm / (total_norm + 1.0e-6) - if clip_coeff < 1.0: - if enable_cuda_kernels: - grads = [p.grad.detach() for p in params] - dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) - else: - for p in params: - p.grad.detach().mul_(clip_coeff) - return total_norm - - -def count_zeros_fp32(parameters): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism - total_num_zeros = 0.0 - for param in parameters: - grad_not_none = param.grad is not None - is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) - if grad_not_none and is_not_tp_duplicate: - grad = param.grad.detach() - num_zeros = grad.numel() - torch.count_nonzero(grad) - total_num_zeros = num_zeros + total_num_zeros - - total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda() - - # Sum across all model-parallel GPUs. - ops = [] - ops.append( - dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)) - if gpc.is_initialized(ParallelMode.PIPELINE): - ops.append( - dist.all_reduce(total_num_zeros, - op=dist.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PIPELINE), - async_op=True)) - - for req in ops: - req.wait() - total_num_zeros = total_num_zeros.item() - - return total_num_zeros - - -def copy_tensor_parallel_attributes(src_tensor, dst_tensor): - for attr in TENSOR_PARALLEL_ATTRIBUTES: - if hasattr(src_tensor, attr): - val = getattr(src_tensor, attr) - setattr(dst_tensor, attr, val) - - -def param_is_not_tensor_parallel_duplicate(param): - return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank( - ParallelMode.TENSOR) == 0) - - -@contextmanager -def switch_virtual_pipeline_parallel_rank(rank): - prev_rank = gpc.virtual_pipeline_parallel_rank - try: - gpc.set_virtual_pipeline_parallel_rank(rank) - yield - finally: - gpc.set_virtual_pipeline_parallel_rank(prev_rank) + return getattr(p, "_ddp_to_ignore", False) def disposable(func: Callable) -> Callable: @@ -470,3 +42,28 @@ def disposable(func: Callable) -> Callable: return func(*args, **kwargs) return wrapper + + +def free_storage(data: torch.Tensor) -> None: + """Free underlying storage of a Tensor.""" + if data.storage().size() > 0: + # Since we're modifying the Tensor's Storage directly, make sure the Tensor + # is the sole occupant of the Storage. + assert data.storage_offset() == 0 + data.storage().resize_(0) + + +def _cast_float(args, dtype: torch.dtype): + if isinstance(args, torch.Tensor) and torch.is_floating_point(args): + args = args.to(dtype) + elif isinstance(args, (list, tuple)): + args = type(args)(_cast_float(t, dtype) for t in args) + elif isinstance(args, dict): + args = {k: _cast_float(v, dtype) for k, v in args.items()} + return args + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py index 60f3ccb60883e7af56da5f41e1b18c7e20cc098f..6bfb08d1f04aaac1d7e39ff8f049d6361a5e82eb 100644 --- a/colossalai/utils/cuda.py +++ b/colossalai/utils/cuda.py @@ -1,7 +1,10 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from typing import Optional + import torch +import torch.distributed as dist def set_to_cuda(models): @@ -23,12 +26,12 @@ def set_to_cuda(models): def get_current_device() -> torch.device: """ Returns currently selected device (gpu/cpu). - If cuda available, return gpu, otherwise return cpu. + If cuda available, return gpu, otherwise return cpu. """ if torch.cuda.is_available(): - return torch.device(f'cuda:{torch.cuda.current_device()}') + return torch.device(f"cuda:{torch.cuda.current_device()}") else: - return torch.device('cpu') + return torch.device("cpu") def synchronize(): @@ -45,3 +48,9 @@ def empty_cache(): """ if torch.cuda.is_available(): torch.cuda.empty_cache() + + +def set_device(index: Optional[int] = None) -> None: + if index is None: + index = dist.get_rank() % torch.cuda.device_count() + torch.cuda.set_device(index) diff --git a/colossalai/utils/data_sampler/__init__.py b/colossalai/utils/data_sampler/__init__.py deleted file mode 100644 index 12798a94c2d063bb120f805967e748c5a1059a3a..0000000000000000000000000000000000000000 --- a/colossalai/utils/data_sampler/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .base_sampler import BaseSampler -from .data_parallel_sampler import DataParallelSampler, get_dataloader - -__all__ = ['BaseSampler', 'DataParallelSampler', 'get_dataloader'] diff --git a/colossalai/utils/data_sampler/data_parallel_sampler.py b/colossalai/utils/data_sampler/data_parallel_sampler.py deleted file mode 100644 index 945dc54b397a3869232827838d4596e95e188059..0000000000000000000000000000000000000000 --- a/colossalai/utils/data_sampler/data_parallel_sampler.py +++ /dev/null @@ -1,169 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- -# adpated from torch.utils.data.DistributedSampler - -import math -import random -import numpy as np -from typing import TypeVar, Iterator - -import torch -from torch.utils.data import Sampler, Dataset, DataLoader - -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.registry import DATA_SAMPLERS - -T_co = TypeVar('T_co', covariant=True) - - -@DATA_SAMPLERS.register_module -class DataParallelSampler(Sampler): - """A data sampler for distributed data parallelism. - - Args: - dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling. - shuffle (bool, optional): Whether to shuffle data, defaults to False. - seed (int, optional): The random seed used for sampling, defaults to 0. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - """ - - def __init__(self, - dataset: Dataset, - shuffle: bool = False, - seed: int = 0, - drop_last: bool = False) -> None: - self.dataset = dataset - self.num_replicas = gpc.get_world_size(ParallelMode.DATA) - self.rank = gpc.get_local_rank(ParallelMode.DATA) - self.epoch = 0 - self.drop_last = drop_last - # If the dataset length is evenly divisible by # of replicas, then there - # is no need to drop any data, since the dataset will be split equally. - # type: ignore[arg-type] - if self.drop_last and len(self.dataset) % self.num_replicas != 0: - # Split to nearest available length that is evenly divisible. - # This is to ensure each rank receives the same amount of data when - # using this Sampler. - self.num_samples = math.ceil( - # `type:ignore` is required because Dataset cannot provide a default __len__ - # see NOTE in pytorch/torch/utils/data/sampler.py - (len(self.dataset) - self.num_replicas) / \ - self.num_replicas # type: ignore[arg-type] - ) - else: - self.num_samples = math.ceil( - len(self.dataset) / self.num_replicas) # type: ignore[arg-type] - self.total_size = self.num_samples * self.num_replicas - self.shuffle = shuffle - self.seed = seed - - def __iter__(self) -> Iterator[T_co]: - if self.shuffle: - # deterministically shuffle based on epoch and seed - g = torch.Generator() - g.manual_seed(self.seed + self.epoch) - # type: ignore[arg-type] - indices = torch.randperm(len(self.dataset), generator=g).tolist() - - # update for next epoch so that there is no need to call - # set_epoch manually - self.epoch += 1 - else: - indices = list(range(len(self.dataset))) # type: ignore[arg-type] - - if not self.drop_last: - # add extra samples to make it evenly divisible - padding_size = self.total_size - len(indices) - if padding_size <= len(indices): - indices += indices[:padding_size] - else: - indices += (indices * math.ceil(padding_size / - len(indices)))[:padding_size] - else: - # remove tail of data to make it evenly divisible. - indices = indices[:self.total_size] - assert len(indices) == self.total_size - - # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] - assert len(indices) == self.num_samples - - return iter(indices) - - def __len__(self) -> int: - return self.num_samples - - def set_epoch(self, epoch: int) -> None: - r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas - use a different random ordering for each epoch. Otherwise, the next iteration of this - sampler will yield the same ordering. - - Args: - epoch (int): Epoch number. - """ - self.epoch = epoch - - -def get_dataloader(dataset, - shuffle=False, - seed=1024, - add_sampler=True, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): - r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not) - - Note: - When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data - on the 1st stage and label on the last stage. - - Args: - dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - - if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1: - sampler = DataParallelSampler(dataset, shuffle=shuffle) - else: - sampler = None - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - if sampler is None: - return DataLoader(dataset, - worker_init_fn=seed_worker, - shuffle=shuffle, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) - else: - return DataLoader(dataset, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py deleted file mode 100644 index bf3e3d05b99cdc6b8e08922a426d5f3e1096fd07..0000000000000000000000000000000000000000 --- a/colossalai/utils/model/experimental.py +++ /dev/null @@ -1,604 +0,0 @@ -from types import MethodType -from typing import Callable, Optional, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -from torch import Tensor -from torch.utils._pytree import tree_map - -from colossalai._analyzer._subclasses import MetaTensor -from colossalai.tensor.d_tensor.d_tensor import DTensor -from colossalai.tensor.d_tensor.layout import Layout - -# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html -_NORMAL_FACTORY = [ - "arange", - "full", - "empty", - "linspace", - "logspace", - "ones", - "rand", - "randn", - "randint", - "randperm", - "zeros", - "tensor", -] - -# factory function that does not support meta tensor backend -_NO_META_FACTORY = [ - "eye", -] - -_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split'] - -# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset) -# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block. -# These ops cannot be unwrapped using .data -_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__'] - -_LEGACY_TENSOR_CONSTRUCTOR = { - 'FloatTensor': torch.float, - 'DoubleTensor': torch.double, - 'HalfTensor': torch.half, - 'BFloat16Tensor': torch.bfloat16, - 'ByteTensor': torch.uint8, - 'CharTensor': torch.int8, - 'ShortTensor': torch.short, - 'IntTensor': torch.int, - 'LongTensor': torch.long, - 'BoolTensor': torch.bool, -} - -_EMPTY_DATA = torch.empty(0) - - -class _MyTensor(Tensor): - """This class is only for correctness verification. - """ - _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None - - def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor': - cls._pre_op_fn() - if concrete_data is not None: - # uniform api as LazyTensor - data = concrete_data - else: - data = func(*args, **kwargs) - return Tensor._make_subclass(cls, data, require_grad=data.requires_grad) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - cls._pre_op_fn() - return super().__torch_function__(func, types, args, kwargs) - - -def _data_tolist(tensor: torch.Tensor) -> list: - """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor. - """ - return tensor.data.tolist() - - -def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: - """Convert a lazy tensor's class to target's class, with target's data. - - The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models. - If we create a new tensor and update the module by ``setattr(module, name, param)``, the shared parameters will not be updated. And we have to track all shared parameters and update them manually. - - Args: - tensor (LazyTensor): the LazyTensor to be converted - target (torch.Tensor): target tensor - - Returns: - torch.Tensor: the converted tensor - """ - cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor - tensor.__class__ = cls_to_become - tensor.data = target - tensor.requires_grad = target.requires_grad - # subclass of torch.Tensor does not have tolist() method - # overwrite this method after materialization or distribution - tensor.tolist = MethodType(_data_tolist, tensor) - return tensor - - -class LazyTensor(torch.Tensor): - """A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf). - - Usage: - 1. Use ``LazyTensor`` instead of ``torch.Tensor``. - >>> x = LazyTensor(torch.zeros, 2, 3) - >>> x += 1 - >>> y = x * x - >>> y = y.cuda().half() - >>> y[0, 0] = 0 - >>> y = y.materialize() # materialize the tensor - >>> print(y) - tensor([[0., 1., 1.], - [1., 1., 1.]], device='cuda:0', dtype=torch.float16) - - Warnings: - 1. Cases that ``LazyTensor`` can't deal with. - >>> x = LazyTensor(torch.ones, 2, 3) - >>> x[0, 0] = -x[0, 0] # this will cause infinite recursion - >>> y = x.clone() - >>> x.add_(1) # modifying origin tensor after cloning leads to wrong materialization - >>> z = x.tolist() - >>> x.zeros_() # modifying origin tensor after cloning tolist is not allowed - >>> nn.utils.weight_norm(self.conv, name="weight", dim=2) # applying weight norm on a lazy tensor is not allowed - - - 2. Cases that ``LazyTensor`` becomes eager (early materialization). - >>> b = a[:, 2:] # get a slice of a lazy tensor triggers early materialization - >>> chunks = a.split(3) # this also triggers early materialization - >>> x.data = torch.rand(2, 3) # directly setting data of a lazy tensor triggers early materialization - - """ - - _repr = True - _meta_data: Optional[MetaTensor] = None # shape, dtype, device - _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None - - @staticmethod - def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): - if concrete_data is not None: - # some ops don't support meta backend and should have concrete data - elem = concrete_data - else: - if meta_data is None: - device = kwargs.get('device', 'cpu') - elem = func(*args, **{**kwargs, 'device': 'meta'}) - meta_data = MetaTensor(elem, device=device) - elem = meta_data._tensor - # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here - r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad) - r._meta_data = meta_data - return r - - def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): - self._factory_method = (func, args, kwargs) # (func, args, kwargs) - self._op_buffer = [] # (func, args, kwargs, replace) - self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data - - def materialize(self) -> torch.Tensor: - """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace). - - Returns: - torch.Tensor: The materialized tensor (self). - """ - target = self._materialize_data() - self.clean() - return _convert_cls(self, target) - - def distribute(self, layout: Layout) -> torch.Tensor: - """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. - - Args: - layout (Layout): Distribution layout. - - Returns: - torch.Tensor: The distributed tensor (self). - """ - target = self._materialize_data() - self.clean() - local_tensor = DTensor(target, layout).local_tensor - return _convert_cls(self, local_tensor) - - def clean(self) -> None: - """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized. - """ - self._factory_method = None - self._op_buffer = None - self._materialized_data = None - self._meta_data = None - - @staticmethod - def _replace_with_materialized(x): - if isinstance(x, LazyTensor): - return x._materialize_data() - return x - - def _materialize_data(self) -> torch.Tensor: - # self._materialized_data should be generated after the first call of this function - if self._materialized_data is None: - # apply factory method - func, args, kwargs = self._factory_method - - # apply cached sequence - self._pre_op_fn() - - try: - init_val = func(*tree_map(self._replace_with_materialized, args), - **tree_map(self._replace_with_materialized, kwargs)) - except TypeError as e: - print(f'init fn: {func.__name__}') - raise e - - self._materialized_data = self._rerun_ops(init_val) - return self._materialized_data - - def _rerun_ops(self, target=None) -> torch.Tensor: - """Do lazy execution by rerunning all (stored) related operations. - - Args: - target (torc.Tensor, optional): Intial value of the target tensor (self). Defaults to None. - """ - - def replace(x): - if x is self: - return target - elif isinstance(x, LazyTensor): - return x._materialize_data() - return x - - packed = None - - for (func, args, kwargs) in self._op_buffer: - if func == torch.Tensor.requires_grad_: - packed = func, args, kwargs # requires grad should be set at last - else: - self._pre_op_fn() - o = func(*tree_map(replace, args), **tree_map(replace, kwargs)) - target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value - - # super-dainiu: set requires_grad after all inplace-ops are done - if packed is not None: - func, args, kwargs = packed - func(*tree_map(replace, args), **tree_map(replace, kwargs)) - - return target - - # cache everything with __torch_function__ - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - if func.__name__ in _EARLY_MATERIALIZED_OPS: - # These OPs cannot be lazy and related tensors should be early materialized - tree_map(cls._replace_with_materialized, args) - tree_map(cls._replace_with_materialized, kwargs) - is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__')) - or func.__name__ in ('__setitem__', '__set__')) - - is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS - - if isinstance(func, torch._C.ScriptMethod): - # FIXME(ver217): torch script functions are not verified - - target = None - - def unwrap(x): - if isinstance(x, LazyTensor): - return x._meta_data - return x - - target: LazyTensor = args[0].clone() - target._op_buffer.append((func, args, kwargs)) - target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]), - **tree_map(unwrap, kwargs)) - return target - else: - - meta_to_lazy = {} - - def unwrap(x): - if isinstance(x, LazyTensor): - if x._materialized_data is not None: - # for early materialized tensor, use its materialized data directly - return x._materialized_data if is_change_meta_op else x._materialized_data.data - t = x if is_inplace else x.clone() - t._op_buffer.append((func, args, kwargs)) - meta = x._meta_data if is_change_meta_op else x._meta_data.data - meta_to_lazy[meta] = t - return meta - return x - - def wrap(y, i=None): - if isinstance(y, MetaTensor): - if y in meta_to_lazy: - # inplace op, just return origin lazy tensor - return meta_to_lazy[y] - else: - # out of place op, create new lazy tensor - fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i] - lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs) - return lazy_y - elif type(y) is Tensor: - # for early materialized tensor - return LazyTensor(lambda: None, concrete_data=y) - return y - - cls._pre_op_fn() - o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) - if isinstance(o, (tuple, list)): - return type(o)(wrap(y, i=i) for i, y in enumerate(o)) - return wrap(o) - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - pass # skip - - def clone(self) -> "LazyTensor": - - def factory_fn(): - # if self is materialized, return self - new_tensor = self.materialize() if type(self) is LazyTensor else self - return new_tensor.clone() - - target = LazyTensor(factory_fn, meta_data=self._meta_data) - - return target - - def detach(self) -> Tensor: - return self - - def __deepcopy__(self, memo): - if not self.is_leaf: - raise RuntimeError("Only Tensors created explicitly by the user " - "(graph leaves) support the deepcopy protocol at the moment") - if id(self) in memo: - return memo[id(self)] - - def factory_fn(): - # if self is materialized, return self - new_tensor = self.materialize() if type(self) is LazyTensor else self - copied = new_tensor.detach().clone() - if new_tensor.requires_grad: - copied.requires_grad_() - return copied - - target = LazyTensor(factory_fn, meta_data=self._meta_data) - - memo[id(self)] = target - return target - - @property - def data(self): - return self - - @data.setter - def data(self, other: 'LazyTensor'): - """This is sightly different from oringinal `data` setter. - - E.g.: - >>> a = torch.randn(3, 3) # a is a Tensor - >>> b = torch.rand(2, 2) - >>> a.data = b - >>> b.add_(1) # this will affect a - >>> x = torch.randn(3, 3) # x is a LazyTensor - >>> y = torch.rand(2, 2) # y is a LazyTensor - >>> x.data = y - >>> y.add_(1) # this will not affect x - - """ - if other is self: - return - - self._op_buffer.append(other._factory_method) - - def replace(x): - if x is other: - return self - return x - - for func, args, kwargs in other._op_buffer: - self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs))) - - def tolist(self) -> list: - # Though self.__class__ is modified to torch.Tensor, in C++ side, it is still a subclass of torch.Tensor - # And subclass of torch.Tensor does not have tolist() method - t = self._materialize_data() - return t.tolist() - - def __hash__(self): - return id(self) - - -class LazyInitContext: - """Context manager for lazy initialization. Enables initializing the model without allocating real memory. - - Usage: - 1. The model is initialized, but no real memory is allocated. - >>> ctx = LazyInitContext() - >>> with ctx: - >>> model = MyModel().cuda() - - 2. The model is initialized with ``MetaTensor`` as weights, but still no real memory is allocated. - >>> with ctx.traceable(model): - >>> gm = symbolic_trace(model, meta_args=meta_args) - >>> # Solve the execution strategy and apply the strategy to the model - >>> strategy = StrategyAndSpec() - - 3. The model is initialized with ``torch.Tensor`` as weights, and real memory is allocated. (single device) - >>> model = ctx.materialize(model) - - 3. The model is initialized with sharded ``torch.Tensor`` as weights, and real memory is allocated. (distributed scenario) - >>> model = apply_strategy_to_all_params(model, strategy) - >>> model = ctx.distribute(model) - - Warnings: - This API is still experimental and further modifications can be made to it. - For example: - 1. Quantization strategies can be applied before allocating real memory. - 2. Lazy initialization seems slower than normal initialization. - """ - _replaced: bool = False - - def __init__(self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor): - self.overrides = {} - self.tensor_cls = tensor_cls - - def __enter__(self): - if LazyInitContext._replaced: - raise RuntimeError(f'LazyInitContext is not reentrant') - LazyInitContext._replaced = True - - def wrap_factory_method(target): - # factory functions (eg. torch.empty()) - def wrapper(*args, **kwargs): - return self.tensor_cls(target, *args, **kwargs) - - return wrapper, target - - def wrap_factory_like_method(orig_target, target): - # factory_like functions (eg. torch.empty_like()) - def wrapper(*args, **kwargs): - orig_t = args[0] - return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs) - - return wrapper, target - - def wrap_legacy_constructor(target, dtype): - # legacy constructor (e.g. torch.LongTensor()) - def wrapper(*args, **kwargs): - if len(args) == 1 and isinstance(args[0], torch.Tensor): - # (Tensor other) - return args[0] - elif len(args) == 1: - # (object data, *, torch.device device) - kwargs = {**kwargs, 'dtype': dtype} - replaced, orig = self.overrides['tensor'] - return replaced(*args, **kwargs) - elif _is_int_tuple(args): - # (tuple of ints size, *, torch.device device) - kwargs = {**kwargs, 'dtype': dtype} - replaced, orig = self.overrides['empty'] - return replaced(*args, **kwargs) - else: - raise TypeError( - f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)' - ) - - return wrapper, target - - def wrap_no_meta_factory(target): - # factory functions which don't support meta tensor backend - def wrapper(*args, **kwargs): - tensor = target(*args, **kwargs) - return self.tensor_cls(lambda: None, concrete_data=tensor) - - return wrapper, target - - self.overrides = { - target: wrap_factory_method(getattr(torch, target)) - for target in _NORMAL_FACTORY - if callable(getattr(torch, target, None)) - } - - self.overrides.update({ - target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like')) - for target in _NORMAL_FACTORY - if callable(getattr(torch, target + '_like', None)) - }) - - self.overrides.update({ - target: wrap_legacy_constructor(getattr(torch, target), dtype) - for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() - if callable(getattr(torch, target, None)) - }) - - self.overrides.update({ - target: wrap_no_meta_factory(getattr(torch, target)) - for target in _NO_META_FACTORY - if callable(getattr(torch, target, None)) - }) - - for name, (wrapper, orig) in self.overrides.items(): - setattr(torch, name, wrapper) - - def __exit__(self, exc_type, exc_val, exc_tb): - LazyInitContext._replaced = False - for name, (wrapper, orig) in self.overrides.items(): - setattr(torch, name, orig) - - @staticmethod - def materialize(module: nn.Module, verbose: bool = False) -> nn.Module: - """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. - - Args: - module (nn.Module): Target ``nn.Module`` - verbose (bool): Whether to print lazy initialization rate. Defaults to False. - """ - - def apply_fn(name: str, p: LazyTensor): - p.materialize() - - return _apply_to_lazy_module(module, apply_fn, verbose) - - @staticmethod - def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module: - """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. - - Args: - module (nn.Module): Target ``nn.Module`` - layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout. - verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False. - """ - - def apply_fn(name: str, p: LazyTensor): - p.distribute(layout_dict[name]) - - return _apply_to_lazy_module(module, apply_fn, verbose) - - -def _apply_to_lazy_module(module: nn.Module, - apply_fn: Callable[[str, torch.Tensor], None], - verbose: bool = False) -> nn.Module: - if verbose: - # verbose info - param_cnt = 0 - param_lazy_cnt = 0 - buf_cnt = 0 - buf_lazy_cnt = 0 - total_numel = 0 - non_lazy_numel = 0 - - for name, p in module.named_parameters(): - if verbose: - param_cnt += 1 - total_numel += p.numel() - if getattr(p, '_materialized_data', False) is None: - # if no _materialized_data attr, the tensor is not lazy - param_lazy_cnt += 1 - else: - non_lazy_numel += p.numel() - if isinstance(p, LazyTensor): - apply_fn(name, p) - - for name, buf in module.named_buffers(): - if verbose: - buf_cnt += 1 - total_numel += buf.numel() - if getattr(buf, "_materialized_data", False) is None: - # if no _materialized_data attr, the tensor is not lazy - buf_lazy_cnt += 1 - else: - non_lazy_numel += buf.numel() - if isinstance(buf, LazyTensor): - apply_fn(name, buf) - - if verbose: - non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0 - _print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}') - _print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}') - _print_rank_0( - f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%') - - return module - - -def _print_rank_0(*args, **kwargs): - if not dist.is_initialized() or dist.get_rank() == 0: - print(*args, **kwargs) - - -def _is_int_tuple(args) -> bool: - if not isinstance(args, tuple): - return False - for x in args: - if not isinstance(x, int): - return False - return True diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py deleted file mode 100644 index cf05f966089d16884166469cff299e6991192097..0000000000000000000000000000000000000000 --- a/colossalai/utils/model/lazy_init_context.py +++ /dev/null @@ -1,242 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 - -import inspect -import types -from typing import Callable, List - -import torch -import torch.nn as nn - -from colossalai.tensor import ColoParameter, ColoTensor -from colossalai.utils.model.utils import substitute_init_recursively - - -class LazyInitContext(): - """ - A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor - initialization functions for lazy initialization - - Note: - This API is only experimental and subject to future changes. - - Usage: - with LazyInitContext() as ctx: - model = nn.Linear(10, 10) - model.weight.zero_() - - # make sure the weight is a meta tensor - assert model.weight.is_meta - - # initialize weights - ctx.lazy_init_parameters(model) - - # make sure the weight is not a meta tensor - # and initialized correctly - assert not model.weight.is_meta and torch.all(model.weight == 0) - - Args: - to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This - argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet. - extra_torch_tensor_func (List[str]): extra torch tensor functions related - to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default. - """ - - tensor_set_value_func = ['zero_', 'fill_'] - - def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None): - # TODO: hijack the torch constructor functions as well - self._to_meta = to_meta - self._intercepted_nn_init_func_cache = {} - self._nn_init_methods = self._get_nn_init_methods() - self._torch_mod_cls = torch.nn.modules.module.Module - - if extra_torch_tensor_func: - # use tuple to remove duplicates - self._torch_tensor_funcs = tuple(self.tensor_set_value_func + extra_torch_tensor_func) - else: - self._torch_tensor_funcs = self.tensor_set_value_func - - @property - def to_meta(self): - return self._to_meta - - def _cache_init_func(self, func): - """ - This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions - so that the function call is cached instead of being executed. - """ - - def wrapped_init_func(tensor, *args, **kwargs): - if tensor not in self._intercepted_nn_init_func_cache: - self._intercepted_nn_init_func_cache[tensor] = [] - self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs)) - - return wrapped_init_func - - def _get_nn_init_methods(self): - """ - This method looks for all available functions in the ``torch.nn.init`` - module. - """ - nn_init_method_names = dir(torch.nn.init) - nn_init_methods = [] - - # look for all methods in ``torch.nn.init`` module - for name in nn_init_method_names: - nn_init_methods.append((name, getattr(torch.nn.init, name))) - - def _is_init_method(item): - name, func = item - - if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')): - return False - else: - return True - - # remove methods which are not init functions - nn_init_methods = list(filter(_is_init_method, nn_init_methods)) - return nn_init_methods - - def _wrap_module_init(self, func): - """ - This method wraps the calls to the `__init__` of ``torch.nn.Module`` and replaces - the argument device with value 'meta' so that all modules are created as meta tensors. - """ - has_device = 'device' in inspect.signature(func).parameters - - def layer_lazy_init(module, *args, **kwargs): - # if this module contains device argument - # we set it to meta to initialize as meta backend - if has_device: - kwargs['device'] = 'meta' - func(module, *args, **kwargs) - - # if device is not found, we intialize it and convert to meta - if not has_device: - module.to('meta') - - return layer_lazy_init - - def _get_tmp_origin_func_ref(self, name): - """ - Generate a function name for consistency during caching and retrieving. - """ - return f'_orig_{name}' - - def _patch_nn_init_funcs(self): - # patch nn.init functions - for name, func in self._nn_init_methods: - setattr(torch.nn.init, name, self._cache_init_func(func)) - - def _unpatch_nn_init_funcs(self): - # unpatch nn.init functions - for name, func in self._nn_init_methods: - setattr(torch.nn.init, name, func) - - def _patch_submodule_init(self): - # patch classes __init__ methods - def _activate_wrap_init(cls): - cls.__orig_init__ = cls.__init__ - cls.__init__ = self._wrap_module_init(cls.__init__) - - substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init, set()) - - def _unpatch_submodule_init(self): - - def _recover_orig_init(cls): - cls.__init__ = cls.__orig_init__ - - substitute_init_recursively(self._torch_mod_cls, _recover_orig_init, set()) - - def _patch_torch_tensor_funcs(self): - # patch tensor value-setting functions - for func_name in self._torch_tensor_funcs: - origin_func_name = self._get_tmp_origin_func_ref(func_name) - origin_func = getattr(torch.Tensor, func_name) - setattr(torch.Tensor, origin_func_name, origin_func) - setattr(torch.Tensor, func_name, self._cache_init_func(origin_func)) - - def _unpatch_torch_tensor_funcs(self): - for func_name in self._torch_tensor_funcs: - origin_func_name = self._get_tmp_origin_func_ref(func_name) - origin_func = getattr(torch.Tensor, origin_func_name) - setattr(torch.Tensor, func_name, origin_func) - - def __enter__(self): - self._patch_torch_tensor_funcs() - self._patch_nn_init_funcs() - - if self._to_meta: - self._patch_submodule_init() - return self - - def __exit__(self, *args, **kwargs): - if self._to_meta: - self._unpatch_submodule_init() - self._unpatch_nn_init_funcs() - self._unpatch_torch_tensor_funcs() - - def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'): - """ - Initialize the weights of the meta-tensor model. - - Args: - model (`torch.nn.Module`): the model instantiated under the context. - device (str): the device on which weights are initialized - - """ - - def _init_recursively(module: nn.Module): - # recursively initialize the module - for mod in module.children(): - _init_recursively(mod) - - # initialize and shard tensors directly attached to the current module - for name, param in module.named_parameters(recurse=False): - _init_and_shard(module, name, param) - - for name, buf in module.named_buffers(recurse=False): - _init_and_shard(module, name, buf) - - @torch.no_grad() - def _init_and_shard(module, name, tensor): - # check whether the tensor is a buffer or parameter - is_param = isinstance(tensor, nn.parameter.Parameter) - - # get sharding spec - dist_spec = getattr(tensor, 'dist_spec', None) - pg = getattr(tensor, 'pg', None) - comp_spec = getattr(tensor, 'comp_spec', None) - - # convert the tensor from meta to materialized one - if tensor.is_meta: - materialized_tensor = torch.empty_like(tensor, device=device) - # if this tensor is a meta tensor, it must have an init function - assert tensor in self._intercepted_nn_init_func_cache - else: - materialized_tensor = tensor - - # apply init function - if tensor in self._intercepted_nn_init_func_cache: - init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1] - init_func(materialized_tensor, *args, **kwargs) - - # convert it to ColoTensor or ColoParameter - if is_param: - tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad) - else: - tensor = ColoTensor.from_torch_tensor(materialized_tensor) - - # override the original tensor - with torch.no_grad(): - setattr(module, name, tensor) - - # apply sharding - if dist_spec: - tensor.process_group = pg - tensor.set_tensor_spec(dist_spec, comp_spec) - - _init_recursively(model) - - return model diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py index f49607376439f6e48797a45c43afceb3fdd27224..4eee4fbc0eee161943218e4a57f2488b8f08acfa 100644 --- a/colossalai/utils/model/utils.py +++ b/colossalai/utils/model/utils.py @@ -27,19 +27,18 @@ def call_to_str(base, *args, **kwargs): Returns: str: A string representation of base(*args, **kwargs) """ - name = f'{base}(' + name = f"{base}(" if args: - name += ', '.join(repr(arg) for arg in args) + name += ", ".join(repr(arg) for arg in args) if kwargs: - name += ', ' + name += ", " if kwargs: - name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items()) - name += ')' + name += ", ".join(f"{key}={repr(arg)}" for key, arg in kwargs.items()) + name += ")" return name class InsertPostInitMethodToModuleSubClasses(object): - def __init__(self, default_dtype: Optional[torch.dtype] = None): self._old_default_dtype = None self._default_dtype = default_dtype @@ -53,7 +52,6 @@ class InsertPostInitMethodToModuleSubClasses(object): torch.set_default_dtype(self._default_dtype) def preprocess_after(f): - @functools.wraps(f) def wrapper(module: torch.nn.Module, *args, **kwargs): f(module, *args, **kwargs) @@ -70,11 +68,11 @@ class InsertPostInitMethodToModuleSubClasses(object): cls.__init__ = preprocess_after(cls.__init__) # Replace .__init__() for all existing subclasses of torch.nn.Module - # Excution self._post_init_method after the default init function. + # Execution self._post_init_method after the default init function. substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set()) # holding on to the current __init__subclass__ for exit - torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__) + torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__ # Replace .__init__() for future subclasses of torch.nn.Module torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) @@ -82,12 +80,11 @@ class InsertPostInitMethodToModuleSubClasses(object): return self def __exit__(self, exc_type, exc_value, traceback): - if self._default_dtype is not None: torch.set_default_dtype(self._old_default_dtype) def _disable_class(cls): - if not hasattr(cls, '_old_init'): + if not hasattr(cls, "_old_init"): raise AttributeError( f"_old_init is not found in the {cls.__name__}, please make sure that you have imported {cls.__name__} before entering the context." ) @@ -97,7 +94,7 @@ class InsertPostInitMethodToModuleSubClasses(object): substitute_init_recursively(torch.nn.modules.module.Module, _disable_class, set()) # Replace .__init__() for future subclasses of torch.nn.Module - torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass) + torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass self._post_context_exec() # Now that we cleaned up the metaclass injection, raise the exception. diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py index 86d04c11958b0f03389cf5ae6e12da4e487fca98..1b75448bdd3c6d909cd943dab8c250d63efc3bf9 100644 --- a/colossalai/utils/moe.py +++ b/colossalai/utils/moe.py @@ -1,52 +1,53 @@ -import torch.nn as nn -import torch.distributed as dist -from colossalai.core import global_context as gpc -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.context import ParallelMode -from .common import is_using_ddp -from typing import Dict, List - - -def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]: - """Returns a parameter dictionary, the key of which is the expert parallel - size of every parameter. Since the parameters in data parallelism is replicated - in each GPU, we set their ep_size to 1. - - Args: - model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict. - """ - epsize_param_dict = dict() - for param in model.parameters(): - if not hasattr(param, 'moe_info'): - ep_size = 1 # set ep_size to 1 for dp parameters - else: - ep_size = param.moe_info.ep_size - if ep_size not in epsize_param_dict: - epsize_param_dict[ep_size] = [] - epsize_param_dict[ep_size].append(param) - - return epsize_param_dict - - -def sync_moe_model_param(model: nn.Module): - """Make sure model parameters are consistent in MoE parallel context. - - Args: - model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. - """ - if is_using_ddp(): - - param_dict = get_moe_epsize_param_dict(model) - - # synchronize the parameters whose dp_group is the whole world - if 1 in param_dict: - src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] - for param in param_dict[1]: - dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA)) - - for ep_size in param_dict: - # When ep_size = world_size, communication is not needed - if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: - src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group) - for param in param_dict[ep_size]: - dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group) +from typing import Dict, List + +import torch.distributed as dist +import torch.nn as nn + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import is_using_ddp + + +def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]: + """Returns a parameter dictionary, the key of which is the expert parallel + size of every parameter. Since the parameters in data parallelism is replicated + in each GPU, we set their ep_size to 1. + + Args: + model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict. + """ + epsize_param_dict = dict() + for param in model.parameters(): + if not hasattr(param, "moe_info"): + ep_size = 1 # set ep_size to 1 for dp parameters + else: + ep_size = param.moe_info.ep_size + if ep_size not in epsize_param_dict: + epsize_param_dict[ep_size] = [] + epsize_param_dict[ep_size].append(param) + + return epsize_param_dict + + +def sync_moe_model_param(model: nn.Module): + """Make sure model parameters are consistent in MoE parallel context. + + Args: + model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. + """ + if is_using_ddp(): + param_dict = get_moe_epsize_param_dict(model) + + # synchronize the parameters whose dp_group is the whole world + if 1 in param_dict: + src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] + for param in param_dict[1]: + dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA)) + + for ep_size in param_dict: + # When ep_size = world_size, communication is not needed + if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: + src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group) + for param in param_dict[ep_size]: + dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group) diff --git a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py index 2b6de5fe1f3c810acaf7877a6db05286cd47af3c..750c2a32da34d08d0b802b0072d039fb48b7df16 100644 --- a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py +++ b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py @@ -25,7 +25,9 @@ class MultiTensorApply(object): raise RuntimeError( "Attempted to call MultiTensorApply method, but MultiTensorApply " "is not available, possibly because Apex was installed without " - "--cpp_ext --cuda_ext. Original import error message:", MultiTensorApply.import_err) + "--cpp_ext --cuda_ext. Original import error message:", + MultiTensorApply.import_err, + ) def __call__(self, op, noop_flag_buffer, tensor_lists, *args): self.check_avail() diff --git a/colossalai/utils/profiler/legacy/__init__.py b/colossalai/utils/profiler/legacy/__init__.py deleted file mode 100644 index 849c7fca305315c267dcc1bcbb014d44225d739f..0000000000000000000000000000000000000000 --- a/colossalai/utils/profiler/legacy/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .comm_profiler import CommProfiler -from .pcie_profiler import PcieProfiler -from .prof_utils import ProfilerContext, BaseProfiler -from .mem_profiler import MemProfiler - -__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext'] diff --git a/colossalai/utils/profiler/legacy/prof_utils.py b/colossalai/utils/profiler/legacy/prof_utils.py deleted file mode 100644 index 87ad644a7ecc989e56ecc581ab2a3ba8d609b94f..0000000000000000000000000000000000000000 --- a/colossalai/utils/profiler/legacy/prof_utils.py +++ /dev/null @@ -1,131 +0,0 @@ -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Union, List -from colossalai.core import global_context as gpc - - -# copied from high version pytorch to support low version -def _format_time(time_us): - """Defines how to format time in FunctionEvent""" - US_IN_SECOND = 1000.0 * 1000.0 - US_IN_MS = 1000.0 - if time_us >= US_IN_SECOND: - return '{:.3f}s'.format(time_us / US_IN_SECOND) - if time_us >= US_IN_MS: - return '{:.3f}ms'.format(time_us / US_IN_MS) - return '{:.3f}us'.format(time_us) - - -# copied from high version pytorch to support low version -def _format_memory(nbytes): - """Returns a formatted memory size string""" - KB = 1024 - MB = 1024 * KB - GB = 1024 * MB - if (abs(nbytes) >= GB): - return '{:.2f} GB'.format(nbytes * 1.0 / GB) - elif (abs(nbytes) >= MB): - return '{:.2f} MB'.format(nbytes * 1.0 / MB) - elif (abs(nbytes) >= KB): - return '{:.2f} KB'.format(nbytes * 1.0 / KB) - else: - return str(nbytes) + ' B' - - -def _format_bandwidth(volme: float or int, time_us: int): - sec_div_mb = (1000.0 / 1024.0)**2 - mb_per_sec = volme / time_us * sec_div_mb - - if mb_per_sec >= 1024.0: - return '{:.3f} GB/s'.format(mb_per_sec / 1024.0) - else: - return '{:.3f} MB/s'.format(mb_per_sec) - - -class BaseProfiler(ABC): - - def __init__(self, profiler_name: str, priority: int): - self.name = profiler_name - self.priority = priority - - @abstractmethod - def enable(self): - pass - - @abstractmethod - def disable(self): - pass - - @abstractmethod - def to_tensorboard(self, writer): - pass - - @abstractmethod - def to_file(self, filename: Path): - pass - - @abstractmethod - def show(self): - pass - - -class ProfilerContext(object): - """Profiler context manager - - Usage:: - - world_size = 4 - inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device()) - outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device()) - outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0)) - - cc_prof = CommProfiler() - - with ProfilerContext([cc_prof]) as prof: - op = dist.all_reduce(inputs, async_op=True) - dist.all_gather(outputs_list, inputs) - op.wait() - dist.reduce_scatter(inputs, outputs_list) - dist.broadcast(inputs, 0) - dist.reduce(inputs, 0) - - prof.show() - """ - - def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True): - self.enable = enable - self.profilers = sorted(profilers, key=lambda prof: prof.priority) - - def __enter__(self): - if self.enable: - for prof in self.profilers: - prof.enable() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.enable: - for prof in self.profilers: - prof.disable() - - def to_tensorboard(self, writer): - from torch.utils.tensorboard import SummaryWriter - - assert isinstance(writer, SummaryWriter), \ - f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.' - - for prof in self.profilers: - prof.to_tensorboard(writer) - - def to_file(self, log_dir: Union[str, Path]): - if isinstance(log_dir, str): - log_dir = Path(log_dir) - - if not log_dir.exists(): - log_dir.mkdir(parents=True, exist_ok=True) - for prof in self.profilers: - log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log') - prof.to_file(log_file) - - def show(self): - for prof in self.profilers: - prof.show() diff --git a/colossalai/utils/rank_recorder/README.md b/colossalai/utils/rank_recorder/README.md index e30a925d2a9291de9d8eeb01119237d24a1bc38c..cad6c1fddd712d10cd47169294c059806ec22e9c 100644 --- a/colossalai/utils/rank_recorder/README.md +++ b/colossalai/utils/rank_recorder/README.md @@ -1,7 +1,7 @@ # Rank Recorder -This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualise the json file easily. +This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualize the json file easily. -Before using the tool, you should ensure dist.is_initialized() return true before exit of program. +Before using the tool, you should ensure dist.is_initialized() return true before exit of program. ## Usage @@ -20,7 +20,7 @@ with recorder(record_name, current_rank) as r: ``` ## Example -This is a demo to display kernel select in cuda and visualise the cost of several procedures in each rank. +This is a demo to display kernel select in cuda and visualize the cost of several procedures in each rank. ```python import time @@ -58,10 +58,10 @@ def worker(rank): with recorder("calc_1(x100)", rank) as r: calc(100, 100) - + with recorder("calc_2(x400)", rank) as r: calc(400, 400) - + with recorder("calc_2(x200)", rank) as r: calc(200, 200) @@ -69,4 +69,4 @@ if __name__ == "__main__": mp.spawn(worker, nprocs=WORLD_SIZE) ``` -run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder. \ No newline at end of file +run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder. diff --git a/colossalai/utils/rank_recorder/__init__.py b/colossalai/utils/rank_recorder/__init__.py index 1274d0e7dbc5277a60fb4d6fcd16c972b85c305c..1d347075a8ce47e1c5faaccf9f7db6853f20c2d0 100644 --- a/colossalai/utils/rank_recorder/__init__.py +++ b/colossalai/utils/rank_recorder/__init__.py @@ -1,3 +1,3 @@ from colossalai.utils.rank_recorder.rank_recorder import recorder -__all__ = ["recorder"] \ No newline at end of file +__all__ = ["recorder"] diff --git a/colossalai/utils/rank_recorder/rank_recorder.py b/colossalai/utils/rank_recorder/rank_recorder.py index c088ceeb2e87727ca9f4e5ac4a71d0721405f4a1..1cb9169125a12d86d3f49783c09c577d7e194c10 100644 --- a/colossalai/utils/rank_recorder/rank_recorder.py +++ b/colossalai/utils/rank_recorder/rank_recorder.py @@ -1,18 +1,15 @@ -import time -from typing import List, Dict +import atexit import json import os -import time import shutil -import atexit +import time +from typing import Dict, List +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt import torch import torch.distributed as dist -import json -import matplotlib.pyplot as plt -import matplotlib.colors as mcolors - cmap = list(mcolors.TABLEAU_COLORS.values()) LOG_FOLDER = "record.log" @@ -20,7 +17,6 @@ MAX_WAIT_TIME = 20 class Event: - def __init__(self, start: int, end: int, name: str, rank: int) -> None: self.start = start self.end = end @@ -29,16 +25,15 @@ class Event: class Recorder: - def __init__(self) -> None: self.rank_to_history: Dict[int, List[Event]] = {} self.base_time = time.time() self.temp_event = None - self.export_format = 'png' - self.export_name = 'test' + self.export_format = "png" + self.export_name = "test" self.dpi = 500 - self.theme = 'dark_background' + self.theme = "dark_background" self.figure_width = 30 self.figure_height = 10 self.legend_fontsize = 16 @@ -84,18 +79,18 @@ class Recorder: def dump_record(self): rank = dist.get_rank() rank_to_history = self.rank_to_history - records = {'base_time': self.base_time, 'content': {}} + records = {"base_time": self.base_time, "content": {}} for record_rank in rank_to_history: history = rank_to_history[record_rank] recs = [] for event in history: - rec = {'start': event.start, 'end': event.end, 'name': event.name} + rec = {"start": event.start, "end": event.end, "name": event.name} recs.append(rec) - records['content'][record_rank] = recs + records["content"][record_rank] = recs - dump_name = f'{rank}.json' + dump_name = f"{rank}.json" dump_path = os.path.join(LOG_FOLDER, dump_name) - with open(dump_path, 'w', encoding='utf-8') as f: + with open(dump_path, "w", encoding="utf-8") as f: json.dump(records, f, ensure_ascii=False) def merge_recode(self): @@ -117,24 +112,22 @@ class Recorder: logs_path = [os.path.join(LOG_FOLDER, file) for file in os.listdir(LOG_FOLDER)] recoders = {} for path in logs_path: - with open(path, 'r', encoding='utf-8') as f: + with open(path, "r", encoding="utf-8") as f: recs = json.load(f) - for record_rank in recs['content']: - history = recs['content'][record_rank] + for record_rank in recs["content"]: + history = recs["content"][record_rank] recoders[record_rank] = [] for rec in history: - recoders[record_rank].append({ - 'start': rec['start'] - base_time, - 'end': rec['end'] - base_time, - 'name': rec['name'] - }) + recoders[record_rank].append( + {"start": rec["start"] - base_time, "end": rec["end"] - base_time, "name": rec["name"]} + ) shutil.rmtree(LOG_FOLDER) - with open(self.export_name + '.json', 'w', encoding='utf-8') as f: + with open(self.export_name + ".json", "w", encoding="utf-8") as f: json.dump(recoders, f, ensure_ascii=False) - def visualise_record(self): - with open(self.export_name + '.json', 'r', encoding='utf-8') as f: + def visualize_record(self): + with open(self.export_name + ".json", "r", encoding="utf-8") as f: records = json.load(f) records = dict(records) ranks = list(sorted(records.keys())) @@ -147,9 +140,9 @@ class Recorder: for rank in ranks: rank_records = records[rank] for rec in rank_records: - s = rec['start'] - e = rec['end'] - name = rec['name'] + s = rec["start"] + e = rec["end"] + name = rec["name"] if name not in name_list: name_list[name] = len(name_list) bar = plt.barh(rank, width=e - s, height=self.bar_height, left=s, color=cmap[name_list[name]]) @@ -157,8 +150,8 @@ class Recorder: plots[name] = bar plt.legend(list(plots.values()), list(plots.keys()), loc="upper left", fontsize=self.legend_fontsize) - plt.yticks(ticks=ranks, labels=[f'Device:{rank}' for rank in ranks], fontsize=self.device_fontsize) - plt.grid(axis='x') + plt.yticks(ticks=ranks, labels=[f"Device:{rank}" for rank in ranks], fontsize=self.device_fontsize) + plt.grid(axis="x") plt.savefig("{}.{}".format(self.export_name, self.export_format)) def exit_worker(self): @@ -171,7 +164,7 @@ class Recorder: if rank == 1: # take the base time of rank 0 as standard self.merge_recode() - self.visualise_record() + self.visualize_record() recorder = Recorder() diff --git a/colossalai/utils/tensor_detector/__init__.py b/colossalai/utils/tensor_detector/__init__.py index cafc19b67c5c3b5f585e3dbef49f6bdbb85d5755..c6c68aa4009bb7c6a2ccebaa45824e0e069bcbb8 100644 --- a/colossalai/utils/tensor_detector/__init__.py +++ b/colossalai/utils/tensor_detector/__init__.py @@ -1 +1 @@ -from .tensor_detector import TensorDetector +from .tensor_detector import TensorDetector diff --git a/colossalai/utils/tensor_detector/readme.md b/colossalai/utils/tensor_detector/readme.md index 840dc8f4eca648f2d8b4ffc286b430003169216e..455eae18116ae684e7ba2d53e2cd8e71b17cda0b 100644 --- a/colossalai/utils/tensor_detector/readme.md +++ b/colossalai/utils/tensor_detector/readme.md @@ -14,7 +14,7 @@ class MLP(nn.Module): super().__init__() self.mlp = nn.Sequential(nn.Linear(64, 8), nn.ReLU(), - nn.Linear(8, 32)) + nn.Linear(8, 32)) def forward(self, x): return self.mlp(x) ``` @@ -46,7 +46,7 @@ detector.detect() I have made some comments on the right of the output for your understanding. -Note that the total `Mem` of all the tensors and parameters is not equal to `Total GPU Memery Allocated`. PyTorch's memory management is really complicated, and for models of a large scale, it's impossible to figure out clearly. +Note that the total `Mem` of all the tensors and parameters is not equal to `Total GPU Memory Allocated`. PyTorch's memory management is really complicated, and for models of a large scale, it's impossible to figure out clearly. **The order of print is not equal to the order the tensor creates, but they are really close.** @@ -61,7 +61,7 @@ Note that the total `Mem` of all the tensors and parameters is not equal to `Tot + mlp.2.bias cuda:0 (32,) True torch.float32 128 B ------------------------------------------------------------------------------------------------------------ Detect Location: "test_tensor_detector.py" line 27 -Totle GPU Memery Allocated on cuda:0 is 4.5 KB +Total GPU Memory Allocated on cuda:0 is 4.5 KB ------------------------------------------------------------------------------------------------------------ @@ -72,7 +72,7 @@ Totle GPU Memery Allocated on cuda:0 is 4.5 KB + Tensor cuda:0 (32,) True torch.float32 128 B # output ------------------------------------------------------------------------------------------------------------ Detect Location: "test_tensor_detector.py" line 30 -Totle GPU Memery Allocated on cuda:0 is 5.5 KB +Total GPU Memory Allocated on cuda:0 is 5.5 KB ------------------------------------------------------------------------------------------------------------ @@ -82,7 +82,7 @@ Totle GPU Memery Allocated on cuda:0 is 5.5 KB + Tensor cuda:0 () True torch.float32 4 B # loss ------------------------------------------------------------------------------------------------------------ Detect Location: "test_tensor_detector.py" line 32 -Totle GPU Memery Allocated on cuda:0 is 6.0 KB +Total GPU Memory Allocated on cuda:0 is 6.0 KB ------------------------------------------------------------------------------------------------------------ @@ -103,7 +103,7 @@ Totle GPU Memery Allocated on cuda:0 is 6.0 KB - Tensor cuda:0 (8,) True torch.float32 32 B # deleted activation ------------------------------------------------------------------------------------------------------------ Detect Location: "test_tensor_detector.py" line 34 -Totle GPU Memery Allocated on cuda:0 is 10.0 KB +Total GPU Memory Allocated on cuda:0 is 10.0 KB ------------------------------------------------------------------------------------------------------------ @@ -117,7 +117,7 @@ Totle GPU Memery Allocated on cuda:0 is 10.0 KB + Tensor cuda:0 (32,) False torch.float32 128 B ------------------------------------------------------------------------------------------------------------ Detect Location: "test_tensor_detector.py" line 36 -Totle GPU Memery Allocated on cuda:0 is 14.0 KB +Total GPU Memory Allocated on cuda:0 is 14.0 KB ------------------------------------------------------------------------------------------------------------ ``` @@ -125,4 +125,3 @@ Totle GPU Memery Allocated on cuda:0 is 14.0 KB This tool was inspired by https://github.com/Stonesjtu/pytorch_memlab/blob/master/pytorch_memlab/mem_reporter.py and https://github.com/Oldpan/Pytorch-Memory-Utils - diff --git a/colossalai/utils/tensor_detector/tensor_detector.py b/colossalai/utils/tensor_detector/tensor_detector.py index a8186f76834c1eec3fdb28c7b6e12dfd6e65260f..38cf094b8dd0aa7e011e1633d731b1615f1bb02c 100644 --- a/colossalai/utils/tensor_detector/tensor_detector.py +++ b/colossalai/utils/tensor_detector/tensor_detector.py @@ -1,21 +1,19 @@ import gc import inspect +from collections import defaultdict +from typing import Optional + import torch import torch.nn as nn -from typing import Optional -from collections import defaultdict LINE_WIDTH = 108 -LINE = '-' * LINE_WIDTH + '\n' - +LINE = "-" * LINE_WIDTH + "\n" -class TensorDetector(): - def __init__(self, - show_info: bool = True, - log: str = None, - include_cpu: bool = False, - module: Optional[nn.Module] = None): +class TensorDetector: + def __init__( + self, show_info: bool = True, log: str = None, include_cpu: bool = False, module: Optional[nn.Module] = None + ): """This class is a detector to detect tensor on different devices. Args: @@ -55,42 +53,41 @@ class TensorDetector(): return self.mem_format(memory_size) def mem_format(self, real_memory_size): - # format the tensor memory into a reasonal magnitude + # format the tensor memory into a reasonable magnitude if real_memory_size >= 2**30: - return str(real_memory_size / (2**30)) + ' GB' + return str(real_memory_size / (2**30)) + " GB" if real_memory_size >= 2**20: - return str(real_memory_size / (2**20)) + ' MB' + return str(real_memory_size / (2**20)) + " MB" if real_memory_size >= 2**10: - return str(real_memory_size / (2**10)) + ' KB' - return str(real_memory_size) + ' B' + return str(real_memory_size / (2**10)) + " KB" + return str(real_memory_size) + " B" def collect_tensors_state(self): for obj in gc.get_objects(): if torch.is_tensor(obj): # skip cpu tensor when include_cpu is false and the tensor we have collected before - if (not self.include_cpu) and obj.device == torch.device('cpu'): + if (not self.include_cpu) and obj.device == torch.device("cpu"): continue self.detected.append(id(obj)) - # skip paramters we had added in __init__ when module is an instance of nn.Module for the first epoch + # skip parameters we had added in __init__ when module is an instance of nn.Module for the first epoch if id(obj) not in self.tensor_info: - name = type(obj).__name__ # after backward, we want to update the records, to show you the change - if isinstance(self.module, nn.Module) and name == 'Parameter': + if isinstance(self.module, nn.Module) and name == "Parameter": if obj.grad is not None: # with grad attached for par_name, param in self.module.named_parameters(): if param.requires_grad and param.grad.equal(obj.grad): - name = par_name + ' (with grad)' + name = par_name + " (with grad)" else: # with no grad attached - # there will be no new paramters created during running + # there will be no new parameters created during running # so it must be in saved_tensor_info continue # we can also marked common tensors as tensor(with grad) - if name == 'Tensor' and (obj.is_leaf or obj.retains_grad): + if name == "Tensor" and (obj.is_leaf or obj.retains_grad): if obj.grad is not None: - name = name + ' (with grad)' + name = name + " (with grad)" # in fact, common tensor have no grad # unless you set retain_grad() if id(obj) in self.saved_tensor_info.keys() and name == self.saved_tensor_info[id(obj)][0]: @@ -111,10 +108,10 @@ class TensorDetector(): self.devices.append(obj.device) def print_tensors_state(self): - template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}' + template_format = "{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}" self.info += LINE - self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem') - self.info += '\n' + self.info += template_format.format(" ", "Tensor", "device", "shape", "grad", "dtype", "Mem") + self.info += "\n" self.info += LINE # if a tensor updates this turn, and was recorded before @@ -124,24 +121,30 @@ class TensorDetector(): minus = outdated + minus if len(self.order) > 0: for tensor_id in self.order: - self.info += template_format.format('+', str(self.tensor_info[tensor_id][0]), - str(self.tensor_info[tensor_id][1]), - str(tuple(self.tensor_info[tensor_id][2])), - str(self.tensor_info[tensor_id][3]), - str(self.tensor_info[tensor_id][4]), - str(self.tensor_info[tensor_id][5])) - self.info += '\n' + self.info += template_format.format( + "+", + str(self.tensor_info[tensor_id][0]), + str(self.tensor_info[tensor_id][1]), + str(tuple(self.tensor_info[tensor_id][2])), + str(self.tensor_info[tensor_id][3]), + str(self.tensor_info[tensor_id][4]), + str(self.tensor_info[tensor_id][5]), + ) + self.info += "\n" if len(self.order) > 0 and len(minus) > 0: - self.info += '\n' + self.info += "\n" if len(minus) > 0: for tensor_id in minus: - self.info += template_format.format('-', str(self.saved_tensor_info[tensor_id][0]), - str(self.saved_tensor_info[tensor_id][1]), - str(tuple(self.saved_tensor_info[tensor_id][2])), - str(self.saved_tensor_info[tensor_id][3]), - str(self.saved_tensor_info[tensor_id][4]), - str(self.saved_tensor_info[tensor_id][5])) - self.info += '\n' + self.info += template_format.format( + "-", + str(self.saved_tensor_info[tensor_id][0]), + str(self.saved_tensor_info[tensor_id][1]), + str(tuple(self.saved_tensor_info[tensor_id][2])), + str(self.saved_tensor_info[tensor_id][3]), + str(self.saved_tensor_info[tensor_id][4]), + str(self.saved_tensor_info[tensor_id][5]), + ) + self.info += "\n" # deleted the updated tensor self.saved_tensor_info.pop(tensor_id) @@ -152,16 +155,16 @@ class TensorDetector(): self.info += LINE self.info += f"Detect Location: {locate_msg}\n" for device in self.devices: - if device == torch.device('cpu'): + if device == torch.device("cpu"): continue gpu_mem_alloc = self.mem_format(torch.cuda.memory_allocated(device)) - self.info += f"Totle GPU Memery Allocated on {device} is {gpu_mem_alloc}\n" + self.info += f"Total GPU Memory Allocated on {device} is {gpu_mem_alloc}\n" self.info += LINE - self.info += '\n\n' + self.info += "\n\n" if self.show_info: print(self.info) if self.log is not None: - with open(self.log + '.log', 'a') as f: + with open(self.log + ".log", "a") as f: f.write(self.info) def detect(self, include_cpu=False): diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py index 4b61f4a5ef1148b74962f349cd24d852da899737..2f61817f0461d5e238a1a4c749d16955752830ad 100644 --- a/colossalai/utils/timer.py +++ b/colossalai/utils/timer.py @@ -2,12 +2,12 @@ # -*- encoding: utf-8 -*- import time from typing import Tuple + from .cuda import synchronize class Timer: - """A timer object which helps to log the execution times, and provides different tools to assess the times. - """ + """A timer object which helps to log the execution times, and provides different tools to assess the times.""" def __init__(self): self._started = False @@ -25,16 +25,14 @@ class Timer: return time.time() def start(self): - """Firstly synchronize cuda, reset the clock and then start the timer. - """ + """Firstly synchronize cuda, reset the clock and then start the timer.""" self._elapsed = 0 synchronize() self._start_time = time.time() self._started = True def lap(self): - """lap time and return elapsed time - """ + """lap time and return elapsed time""" return self.current_time - self._start_time def stop(self, keep_in_history: bool = False): @@ -80,12 +78,11 @@ class Timer: Note: Use it only when timer is not in progress """ - assert not self._started, 'Timer is still in progress' + assert not self._started, "Timer is still in progress" return self._elapsed def reset(self): - """Clear up the timer and its history - """ + """Clear up the timer and its history""" self._history = [] self._started = False self._elapsed = 0 diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 3465079e4fbbf4dbd3460cddea68d4b818216957..90d0f8de191655e21c45aa7506d0b9d5aab34e78 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -2,8 +2,7 @@ from .gemini import ( ColoInitContext, GeminiAdamOptimizer, GeminiDDP, - ZeroDDP, - ZeroOptimizer, + GeminiOptimizer, get_static_torch_model, post_process_colo_init_ctx, ) @@ -11,6 +10,13 @@ from .low_level import LowLevelZeroOptimizer from .wrapper import zero_model_wrapper, zero_optim_wrapper __all__ = [ - 'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper', - 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model' + "GeminiDDP", + "GeminiOptimizer", + "GeminiAdamOptimizer", + "zero_model_wrapper", + "zero_optim_wrapper", + "LowLevelZeroOptimizer", + "ColoInitContext", + "post_process_colo_init_ctx", + "get_static_torch_model", ] diff --git a/colossalai/zero/gemini/__init__.py b/colossalai/zero/gemini/__init__.py index 60f85ca2f540497fb9ba8e11c55b1239a8cb45d6..358d5c7fd2895ed122dcb76db0a28c42ab0fd844 100644 --- a/colossalai/zero/gemini/__init__.py +++ b/colossalai/zero/gemini/__init__.py @@ -1,11 +1,20 @@ from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration from .colo_init_context import ColoInitContext, post_process_colo_init_ctx -from .gemini_ddp import GeminiDDP, ZeroDDP +from .gemini_ddp import GeminiDDP from .gemini_mgr import GeminiManager -from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer +from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer from .utils import get_static_torch_model __all__ = [ - 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP', - 'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx' + "GeminiManager", + "TensorInfo", + "TensorState", + "ChunkManager", + "search_chunk_configuration", + "GeminiDDP", + "get_static_torch_model", + "GeminiAdamOptimizer", + "GeminiOptimizer", + "ColoInitContext", + "post_process_colo_init_ctx", ] diff --git a/colossalai/zero/gemini/chunk/__init__.py b/colossalai/zero/gemini/chunk/__init__.py index 6914d2dbef4581dbf37610cfc7589a2c5be77406..91906f68ad25c49a7e82e0b4c8576249fa0de001 100644 --- a/colossalai/zero/gemini/chunk/__init__.py +++ b/colossalai/zero/gemini/chunk/__init__.py @@ -3,4 +3,4 @@ from .manager import ChunkManager from .search_utils import classify_params_by_dp_degree, search_chunk_configuration from .utils import init_chunk_manager -__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager'] +__all__ = ["Chunk", "ChunkManager", "classify_params_by_dp_degree", "search_chunk_configuration", "init_chunk_manager"] diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index a7682eaf62e97c618d97897e27a88b02378a33b1..bbef9013c20b45307e64efa484deb44f4029f5c3 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -4,8 +4,8 @@ from typing import Dict, List, Optional import torch import torch.distributed as dist +from torch.distributed import ProcessGroup -from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.utils import get_current_device @@ -17,12 +17,17 @@ class TensorState(Enum): READY_FOR_REDUCE = 4 -STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), - (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE, - TensorState.HOLD), - (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), - (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, - TensorState.HOLD)) +STATE_TRANS = ( + (TensorState.FREE, TensorState.HOLD), + (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), + (TensorState.HOLD, TensorState.COMPUTE), + (TensorState.COMPUTE, TensorState.HOLD), + (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), + (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), + (TensorState.READY_FOR_REDUCE, TensorState.HOLD), +) @dataclass @@ -53,14 +58,16 @@ def alloc_storage(tensor: torch.Tensor) -> None: class Chunk: _total_number = 0 - def __init__(self, - chunk_size: int, - process_group: ColoProcessGroup, - dtype: torch.dtype, - init_device: Optional[torch.device] = None, - cpu_shard_init: bool = False, - keep_gathered: bool = False, - pin_memory: bool = False) -> None: + def __init__( + self, + chunk_size: int, + process_group: ProcessGroup, + dtype: torch.dtype, + init_device: Optional[torch.device] = None, + cpu_shard_init: bool = False, + keep_gathered: bool = False, + pin_memory: bool = False, + ) -> None: """ Chunk: A container owning a piece of contiguous memory space for tensors Here we use all-gather operation to gather the whole chunk. @@ -69,7 +76,7 @@ class Chunk: Args: chunk_size (int): the number of elements in the chunk - process_group (ColoProcessGroup): the process group of this chunk + process_group (ProcessGroup): the process group of this chunk dtype (torch.dtype): the data type of the chunk init_device (torch.device): optional, During the chunk construction process, where the tensor is stored. The default value is None, which is the current GPU @@ -83,7 +90,7 @@ class Chunk: self.chunk_size = chunk_size self.utilized_size = 0 - self.torch_pg = process_group.dp_process_group() + self.torch_pg = process_group self.pg_size = dist.get_world_size(self.torch_pg) self.pg_rank = dist.get_rank(self.torch_pg) @@ -99,9 +106,9 @@ class Chunk: device = init_device or get_current_device() # chunk_temp is a global chunk, which only exists during building the chunks. - self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero + self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero - self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA + self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA # cuda local chunk, which is sharded on GPUs self.cuda_shard = None @@ -134,7 +141,7 @@ class Chunk: # they are treated the same as that of the parameters in DDP during training. self.keep_gathered = keep_gathered if self.keep_gathered: - pin_memory = False # since this chunk is gathered, it doesn't need to pin + pin_memory = False # since this chunk is gathered, it doesn't need to pin # if pin_memory is True, we allocate a piece of CPU pin-memory # for it all the time @@ -160,7 +167,7 @@ class Chunk: if self.chunk_temp is not None: # this chunk is not closed - if self.chunk_temp.device.type == 'cuda': + if self.chunk_temp.device.type == "cuda": cuda_memory += self.chunk_mem else: cpu_memory += self.chunk_mem @@ -180,11 +187,11 @@ class Chunk: return self.chunk_temp.device.type else: if self.is_gathered: - return 'cuda' + return "cuda" elif self.cuda_shard is not None: - return 'cuda' + return "cuda" else: - return 'cpu' + return "cpu" @property def payload(self) -> torch.Tensor: @@ -217,8 +224,10 @@ class Chunk: if self.keep_gathered: return False else: - return self.tensor_state_cnter[TensorState.HOLD] + \ - self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors + return ( + self.tensor_state_cnter[TensorState.HOLD] + self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] + == self.num_tensors + ) @property def can_reduce(self): @@ -226,27 +235,25 @@ class Chunk: @property def has_inf_or_nan(self) -> bool: - """Check if the chunk has inf or nan values on CUDA. - """ + """Check if the chunk has inf or nan values on CUDA.""" if self.is_gathered: - valid_tensor = self.cuda_global_chunk[:self.utilized_size] + valid_tensor = self.cuda_global_chunk[: self.utilized_size] else: - assert self.cuda_shard is not None # only check on CUDA - valid_tensor = self.cuda_shard[:self.valid_end] + assert self.cuda_shard is not None # only check on CUDA + valid_tensor = self.cuda_shard[: self.valid_end] return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() def set_l2_norm(self) -> None: - """Record l2 norm of this chunks on CUDA. - """ + """Record l2 norm of this chunks on CUDA.""" assert self.l2_norm is None, "you are calculating the l2 norm twice" if self.is_gathered: - valid_tensor = self.cuda_global_chunk[:self.utilized_size] + valid_tensor = self.cuda_global_chunk[: self.utilized_size] else: - assert self.cuda_shard is not None # calculate on CUDA - valid_tensor = self.cuda_shard[:self.valid_end] + assert self.cuda_shard is not None # calculate on CUDA + valid_tensor = self.cuda_shard[: self.valid_end] chunk_l2_norm = valid_tensor.data.float().norm(2) - self.l2_norm = chunk_l2_norm.item()**2 + self.l2_norm = chunk_l2_norm.item() ** 2 def append_tensor(self, tensor: torch.Tensor): """Add a tensor to the chunk. @@ -263,9 +270,9 @@ class Chunk: if new_utilized_size > self.chunk_size: raise ChunkFullError - self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten()) + self.chunk_temp[self.utilized_size : new_utilized_size].copy_(tensor.data.flatten()) assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" - tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape) + tensor.data = self.chunk_temp[self.utilized_size : new_utilized_size].view(tensor.shape) # record all the information about the tensor self.num_tensors += 1 @@ -275,8 +282,7 @@ class Chunk: self.utilized_size = new_utilized_size def close_chunk(self): - """Close the chunk. Any tensor can't be appended to a closed chunk later. - """ + """Close the chunk. Any tensor can't be appended to a closed chunk later.""" # sanity check assert self.chunk_temp is not None @@ -286,7 +292,7 @@ class Chunk: elif self.utilized_size < self.shard_end: self.valid_end = self.utilized_size - self.shard_begin - if self.chunk_temp.device.type == 'cpu': + if self.chunk_temp.device.type == "cpu": self.cuda_global_chunk = self.chunk_temp.to(get_current_device()) self.__update_tensors_ptr() else: @@ -298,12 +304,12 @@ class Chunk: if self.keep_gathered: return - if self.pin_memory or self.shard_device.type == 'cpu': + if self.pin_memory or self.shard_device.type == "cpu": self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory) self.cpu_shard.copy_(self.cuda_shard) - self.cpu_vis_flag = True # cpu_shard has been visited + self.cpu_vis_flag = True # cpu_shard has been visited - if self.shard_device.type == 'cpu': + if self.shard_device.type == "cpu": self.cuda_shard = None def shard_move(self, device: torch.device, force_copy: bool = False): @@ -318,12 +324,12 @@ class Chunk: # when the current chunk is not synchronized with the optimizer # just use another way for the movement if not self.optim_sync_flag: - assert device.type == 'cuda', "each chunk should first be moved to CUDA" + assert device.type == "cuda", "each chunk should first be moved to CUDA" self.__paired_shard_move() self.optim_sync_flag = True return - if device.type == 'cuda': + if device.type == "cuda": assert device == get_current_device(), "can't move chunk to another device" if self.cuda_shard: @@ -333,7 +339,7 @@ class Chunk: if not self.pin_memory: self.cpu_shard = None - elif device.type == 'cpu': + elif device.type == "cpu": if self.cuda_shard is None: return @@ -350,8 +356,7 @@ class Chunk: raise NotImplementedError def access_chunk(self): - """Make the chunk usable for the parameters inside it. It's an operation done in CUDA. - """ + """Make the chunk usable for the parameters inside it. It's an operation done in CUDA.""" # sanity check assert self.chunk_temp is None @@ -360,8 +365,7 @@ class Chunk: self.__update_tensors_ptr() def release_chunk(self): - """Release the usable chunk. It's an operation done in CUDA. - """ + """Release the usable chunk. It's an operation done in CUDA.""" # sanity check assert self.chunk_temp is None @@ -369,8 +373,7 @@ class Chunk: self.__scatter() def reduce(self): - """Reduce scatter all the gradients. It's an operation done in CUDA. - """ + """Reduce scatter all the gradients. It's an operation done in CUDA.""" # sanity check assert self.is_gathered @@ -416,27 +419,25 @@ class Chunk: Copy data slice to the memory space indexed by the input tensor in the chunk. Args: - tensor (torch.Tensor): the tensor used to retrive meta information + tensor (torch.Tensor): the tensor used to retrieve meta information data_slice (torch.Tensor): the tensor to be copied to the chunk """ # sanity check assert self.is_gathered tensor_info = self.tensors_info[tensor] - self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten()) - tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten()) + tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) def get_valid_length(self) -> int: - """Get the valid length of the chunk's payload. - """ + """Get the valid length of the chunk's payload.""" if self.keep_gathered: return self.utilized_size else: return self.valid_end - def init_pair(self, friend_chunk: 'Chunk') -> None: - """Initialize the paired chunk. - """ + def init_pair(self, friend_chunk: "Chunk") -> None: + """Initialize the paired chunk.""" if self.paired_chunk is None and friend_chunk.paired_chunk is None: self.paired_chunk = friend_chunk friend_chunk.paired_chunk = self @@ -445,8 +446,7 @@ class Chunk: assert friend_chunk.paired_chunk is self def optim_update(self) -> None: - """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. - """ + """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.""" # sanity check assert self.paired_chunk is not None @@ -455,15 +455,15 @@ class Chunk: assert friend_chunk.is_gathered is True self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk) self.optim_sync_flag = True - elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda': + elif friend_chunk.device_type == "cuda" and self.device_type == "cuda": self.cuda_shard.copy_(friend_chunk.cuda_shard) self.optim_sync_flag = True self.cpu_vis_flag = False else: # optim_sync_flag is set to False # see shard_move function for more details - assert friend_chunk.device_type == 'cpu' - assert self.device_type == 'cpu' + assert friend_chunk.device_type == "cpu" + assert self.device_type == "cpu" self.optim_sync_flag = False self.cpu_vis_flag = False @@ -492,7 +492,7 @@ class Chunk: self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device) - self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end]) + self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin : self.shard_end]) free_storage(self.cuda_global_chunk) self.is_gathered = False @@ -518,7 +518,7 @@ class Chunk: assert type(self.cuda_global_chunk) == torch.Tensor for tensor, tensor_info in self.tensors_info.items(): - tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState): self.tensor_state_cnter[tensor_info.state] -= 1 @@ -539,38 +539,41 @@ class Chunk: def __repr__(self, detailed: bool = True): output = [ "Chunk Information:\n", - "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype, - self.pg_size), + "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format( + self.chunk_size, self.dtype, self.pg_size + ), "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format( - self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size) + self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size + ), ] - def print_tensor(tensor, prefix=''): - output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, - tensor.device)) + def print_tensor(tensor, prefix=""): + output.append( + "{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, tensor.device) + ) if self.chunk_temp is not None: output.append("\tchunk temp:\n") - print_tensor(tensor=self.chunk_temp, prefix='\t\t') + print_tensor(tensor=self.chunk_temp, prefix="\t\t") if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0: output.append("\tchunk total:\n") - print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t') + print_tensor(tensor=self.cuda_global_chunk, prefix="\t\t") if self.cuda_shard is not None: output.append("\tcuda shard:\n") - print_tensor(tensor=self.cuda_shard, prefix='\t\t') + print_tensor(tensor=self.cuda_shard, prefix="\t\t") if self.cpu_shard is not None: output.append("\tcpu shard:\n") - print_tensor(tensor=self.cpu_shard, prefix='\t\t') + print_tensor(tensor=self.cpu_shard, prefix="\t\t") memory_info = self.memory_usage - output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu'])) + output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info["cuda"], memory_info["cpu"])) if detailed: output.append("\ttensor state monitor:\n") for st in TensorState: output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st])) - return ''.join(output) + return "".join(output) diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index d85df0b00476f88c4bf29943d88c9c9ea99ebbe8..957e41b02d49f79585445f3d66fc2f8245ea44eb 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -2,8 +2,9 @@ from collections import deque from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup -from colossalai.tensor import ColoTensor from colossalai.utils import get_current_device from .chunk import Chunk, ChunkFullError, TensorState @@ -19,26 +20,28 @@ class ChunkManager: """ def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None: - self.device = init_device or get_current_device() self.dp_degree_chunk_size_dict: Dict[int, int] = dict() self.kwargs_config = chunk_configuration for k, v in self.kwargs_config.items(): - self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size') - v['init_device'] = self.device + self.dp_degree_chunk_size_dict[k] = v.pop("chunk_size") + v["init_device"] = self.device - self.chunk_groups: Dict[str, Deque] = dict() + self.chunk_groups: Dict[str, Deque[Chunk]] = dict() self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() self.accessed_chunks: Set[Chunk] = set() self.accessed_mem: int = 0 - self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} - - def register_tensor(self, - tensor: ColoTensor, - group_type: str, - config_key: int, - cpu_offload: bool = False, - pin_memory: bool = False) -> None: + self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0} + + def register_tensor( + self, + tensor: torch.Tensor, + group_type: str, + config_key: int, + process_group: ProcessGroup, + cpu_offload: bool = False, + pin_memory: bool = False, + ) -> None: """ Register a tensor to the chunk manager. Then, the tensor should be accessed by `get_chunks`. @@ -51,7 +54,7 @@ class ChunkManager: pin_memory: whether the chunk is pinned in the cpu memory """ assert tensor not in self.tensor_chunk_map - assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager" + assert isinstance(tensor, torch.Tensor), "Please feed Tensor to this ChunkManager" assert config_key in self.dp_degree_chunk_size_dict chunk_size = self.dp_degree_chunk_size_dict[config_key] @@ -73,12 +76,12 @@ class ChunkManager: if tensor.numel() > chunk_size: chunk_size = tensor.numel() - dp_size = tensor.get_dp_world_size() + dp_size = dist.get_world_size(process_group) chunk_size = chunk_size + (-chunk_size % dp_size) chunk = Chunk( chunk_size=chunk_size, - process_group=tensor.process_group, + process_group=process_group, dtype=tensor.dtype, cpu_shard_init=cpu_offload, pin_memory=pin_memory, @@ -92,53 +95,47 @@ class ChunkManager: self.tensor_chunk_map[tensor] = chunk_group[-1] def close_all_groups(self): - """Close all the chunks of all groups. - """ + """Close all the chunks of all groups.""" for group_name in self.chunk_groups: self.__close_one_chunk(self.chunk_groups[group_name][-1]) def access_chunk(self, chunk: Chunk) -> None: - """Make the chunk can be used for calculation. - """ + """Make the chunk can be used for calculation.""" if chunk in self.accessed_chunks: return - self.__sub_memroy_usage(chunk.memory_usage) - if chunk.device_type == 'cpu': + self.__sub_memory_usage(chunk.memory_usage) + if chunk.device_type == "cpu": chunk.shard_move(get_current_device()) self.__add_accessed_chunk(chunk) self.__add_memory_usage(chunk.memory_usage) def release_chunk(self, chunk: Chunk) -> None: - """Scatter the chunk in CUDA. - """ + """Scatter the chunk in CUDA.""" if chunk not in self.accessed_chunks: return if chunk.can_release: - self.__sub_memroy_usage(chunk.memory_usage) + self.__sub_memory_usage(chunk.memory_usage) self.__sub_accessed_chunk(chunk) self.__add_memory_usage(chunk.memory_usage) def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: - """Move the shard of the chunk to the target device. - """ + """Move the shard of the chunk to the target device.""" if not chunk.can_move or chunk.device_type == device.type: return - self.__sub_memroy_usage(chunk.memory_usage) + self.__sub_memory_usage(chunk.memory_usage) chunk.shard_move(device, force_copy) self.__add_memory_usage(chunk.memory_usage) def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: - """Transit tensor state according to pre-defined state machine. - """ + """Transit tensor state according to pre-defined state machine.""" chunk = self.tensor_chunk_map[tensor] chunk.tensor_trans_state(tensor, state) def reduce_chunk(self, chunk: Chunk) -> bool: - """Reduce or all reduce the chunk. - """ + """Reduce or all reduce the chunk.""" if not chunk.can_reduce: return False - self.__sub_memroy_usage(chunk.memory_usage) + self.__sub_memory_usage(chunk.memory_usage) chunk.reduce() self.__sub_accessed_chunk(chunk) self.__add_memory_usage(chunk.memory_usage) @@ -157,7 +154,7 @@ class ChunkManager: Copy data to the chunk. Args: - tensor (torch.Tensor): the tensor used to retrive meta information + tensor (torch.Tensor): the tensor used to retrieve meta information data (torch.Tensor): the tensor to be copied to the chunk """ chunk = self.tensor_chunk_map[tensor] @@ -211,28 +208,27 @@ class ChunkManager: def __repr__(self) -> str: msg = [ - 'Chunk Manager Information:\n', - 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' + "Chunk Manager Information:\n", + "Total memory: " + ", ".join([f"{k}={v}B" for k, v in self.total_mem.items()]) + "\n", ] for group_name, group in self.chunk_groups.items(): - msg.append(f'Group {group_name}:\n') + msg.append(f"Group {group_name}:\n") for i, chunk in enumerate(group): - msg.append(f'[{i}] {chunk}\n') - return ''.join(msg) + msg.append(f"[{i}] {chunk}\n") + return "".join(msg) - def __get_chunk_group(self, group_name: str) -> Deque: - """Register a chunk group. - """ + def __get_chunk_group(self, group_name: str) -> Deque[Chunk]: + """Register a chunk group.""" if group_name not in self.chunk_groups: self.chunk_groups[group_name] = deque() return self.chunk_groups[group_name] def __close_one_chunk(self, chunk: Chunk): - self.__sub_memroy_usage(chunk.memory_usage) + self.__sub_memory_usage(chunk.memory_usage) chunk.close_chunk() self.__add_memory_usage(chunk.memory_usage) - def __sub_memroy_usage(self, usage: Dict[str, int]): + def __sub_memory_usage(self, usage: Dict[str, int]): for k, v in usage.items(): self.total_mem[k] -= v diff --git a/colossalai/zero/gemini/chunk/search_utils.py b/colossalai/zero/gemini/chunk/search_utils.py index da58e038c8792bdc616250d0a4e512817438c34d..24d8537bad904980d7600aca6072279baba1d8d7 100644 --- a/colossalai/zero/gemini/chunk/search_utils.py +++ b/colossalai/zero/gemini/chunk/search_utils.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple import numpy as np import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup from colossalai.tensor import ColoParameter from colossalai.utils import is_ddp_ignored @@ -59,7 +60,7 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: return left + acc -def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int: +def _tensor_numel(local_param: ColoParameter) -> int: """_tensor_numel Get the number of elements of a tensor. @@ -71,21 +72,19 @@ def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int: Returns: int: the number of elements. """ - if strict_ddp_flag and type(local_param) is ColoParameter: - return local_param.numel_global() - else: - # if local_param is not ColoParameter, we assume it's replicated - return local_param.numel() + # TODO(ver217): support dtensor here + return local_param.numel() -def classify_params_by_dp_degree(param_order: OrderedParamGenerator, - strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]: +def classify_params_by_dp_degree( + param_order: OrderedParamGenerator, process_group: ProcessGroup +) -> Dict[int, List[ColoParameter]]: """classify_params_by_dp_degree Classify the parameters by their dp degree Args: - param_order (OrderedParamGenerator): the order of param be visied + param_order (OrderedParamGenerator): the order of param be vised strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. Defaults to False. Returns: @@ -97,13 +96,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator, # assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" if is_ddp_ignored(param): continue - - if strict_ddp_flag or type(param) is not ColoParameter: - # if model is not initialized with ColoInitContext, we assume it's replicated - # TODO(ver217): integrate DTensor - param_key = dist.get_world_size() - else: - param_key = param.process_group.dp_world_size() + param_key = dist.get_world_size(process_group) if param_key not in params_dict: params_dict[param_key] = [] @@ -113,22 +106,24 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator, def search_chunk_configuration( - model: nn.Module, - search_range_mb: float, - search_interval_byte: int, # hidden size is the best value for the interval - min_chunk_size_mb: float = 32, - filter_exlarge_params: bool = True, - strict_ddp_flag: bool = False, - memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]: + model: nn.Module, + search_range_m: float, + search_interval: int, # hidden size is the best value for the interval + min_chunk_size_m: float = 32, + filter_exlarge_params: bool = True, + strict_ddp_flag: bool = False, + process_group: Optional[ProcessGroup] = None, + memstas: Optional[MemStats] = None, +) -> Tuple[Dict, int, int]: """search_chunk_configuration Search the chunk configuration for a model. Args: model (nn.Module): torch module - search_range_mb (float): searching range in mega byte. - search_interval_byte (int): searching interval in byte. - min_chunk_size_mb (float, optional): the minimum size of a distributed chunk. + search_range_m (float): searching range divided by 2^20. + search_interval (int): searching interval. + min_chunk_size_m (float, optional): the minimum size of a distributed chunk, divided by 2^20.. filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True. strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. all parameters keep replicated in this mode. @@ -145,11 +140,11 @@ def search_chunk_configuration( for p in model.parameters(): param_order.append(p) - search_range_byte = round(search_range_mb * 1024**2) - min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) - assert search_range_byte >= 0 + search_range = round(search_range_m * 1024**2) + min_chunk_size = round(min_chunk_size_m * 1024**2) + assert search_range >= 0 - params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag) + params_dict = classify_params_by_dp_degree(param_order, process_group) size_lcm = np.lcm.reduce(list(params_dict.keys())) config_dict: Dict[int, Dict] = dict() total_param_size = 0 @@ -157,12 +152,12 @@ def search_chunk_configuration( size_dict: Dict[int, List[int]] = dict() for dp_degree in params_dict: params_list = params_dict[dp_degree] - size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list] + size_list = [_tensor_numel(p) for p in params_list] group_acc_size = sum(size_list) total_param_size += group_acc_size # let small parameters keep gathered in CUDA all the time - if group_acc_size < min_chunk_size_byte: + if group_acc_size < min_chunk_size: config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True) else: size_dict[dp_degree] = size_list @@ -170,15 +165,15 @@ def search_chunk_configuration( if filter_exlarge_params: _filter_exlarge_params(model, size_dict) - max_size = min_chunk_size_byte + max_size = min_chunk_size for key in size_dict: max_size = max(max_size, max(size_dict[key])) - start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte) + start_size = int(math.ceil(max_size / search_interval) * search_interval) - min_chunk_waste = float('+inf') + min_chunk_waste = float("+inf") best_chunk_size = start_size - for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): + for chunk_size in range(start_size, start_size + search_range + 1, search_interval): temp_waste = 0 for key in size_dict: temp_waste += _get_unused_byte(size_dict[key], chunk_size) diff --git a/colossalai/zero/gemini/chunk/utils.py b/colossalai/zero/gemini/chunk/utils.py index 71242dcd6d498e537155abbb3d1f882d2b71da66..7a2ea360650bfff56ce8b25185e022f6f2bc3dc5 100644 --- a/colossalai/zero/gemini/chunk/utils.py +++ b/colossalai/zero/gemini/chunk/utils.py @@ -5,8 +5,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from colossalai.utils import is_ddp_ignored - from .manager import ChunkManager from .search_utils import search_chunk_configuration @@ -17,16 +15,18 @@ def safe_div(a, b): return a / b -def init_chunk_manager(model: nn.Module, - init_device: Optional[torch.device] = None, - hidden_dim: Optional[int] = None, - verbose: bool = False, - **kwargs) -> ChunkManager: +def init_chunk_manager( + model: nn.Module, + init_device: Optional[torch.device] = None, + hidden_dim: Optional[int] = None, + verbose: bool = False, + **kwargs, +) -> ChunkManager: if hidden_dim: - search_interval_byte = hidden_dim + search_interval = hidden_dim else: - search_interval_byte = 1024 # defaults to 1kb - kwargs["search_interval_byte"] = search_interval_byte + search_interval = 1024 # defaults to 1024 + kwargs["search_interval"] = search_interval dist.barrier() begin = time() @@ -36,16 +36,18 @@ def init_chunk_manager(model: nn.Module, dist.barrier() end = time() span_s = end - begin - mb_size = 1024**2 - total_size /= mb_size - wasted_size /= mb_size + mega_unit = 1024**2 + total_size /= mega_unit + wasted_size /= mega_unit if verbose and dist.get_rank() == 0: - print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s), - "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size), - "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)), - sep='', - flush=True) + print( + "searching chunk configuration is completed in {:.2f} s.\n".format(span_s), + "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size), + "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)), + sep="", + flush=True, + ) dist.barrier() chunk_manager = ChunkManager(config_dict, init_device) diff --git a/colossalai/zero/gemini/colo_init_context.py b/colossalai/zero/gemini/colo_init_context.py index 75f8576ca477977e03e054eb195c49f5e0048c5f..ab2ff8f920aa73dde00b1c7b8de14528de14f4f3 100644 --- a/colossalai/zero/gemini/colo_init_context.py +++ b/colossalai/zero/gemini/colo_init_context.py @@ -1,9 +1,10 @@ -from typing import Any, Dict, Iterator, Optional, Tuple, Union +from typing import Any, Iterator, Optional, Tuple, Union import torch from torch import nn -from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup +from colossalai.legacy.tensor import ProcessGroup +from colossalai.tensor import ColoParameter, ColoTensor from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses # find named_params includes replica @@ -11,7 +12,7 @@ from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses def _named_params_with_replica( module: nn.Module, - prefix: str = '', + prefix: str = "", recurse: bool = True, ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]: modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)] @@ -20,16 +21,17 @@ def _named_params_with_replica( for name, val in mod._parameters.items(): if val is None: continue - name = mod_prefix + ('.' if mod_prefix else '') + name + name = mod_prefix + ("." if mod_prefix else "") + name yield name, val -def _convert_to_coloparam(param: torch.nn.Parameter, - device: torch.device, - dtype=torch.float, - default_pg: Optional[ProcessGroup] = None, - default_dist_spec: Optional[Any] = None) -> ColoParameter: - +def _convert_to_coloparam( + param: torch.nn.Parameter, + device: torch.device, + dtype=torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec: Optional[Any] = None, +) -> ColoParameter: if type(param) is ColoParameter: return param # detaching tensor is necessary for optimizers. @@ -65,12 +67,13 @@ def ColoModulize(module): class ColoInitContext(InsertPostInitMethodToModuleSubClasses): - - def __init__(self, - device: torch.device = torch.device('cpu'), - dtype: torch.dtype = torch.float, - default_pg: Optional[ProcessGroup] = None, - default_dist_spec=None): + def __init__( + self, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None, + ): """ Args: device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu'). @@ -87,7 +90,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): self._default_dist_spec = default_dist_spec def _register_colo_modules(self): - from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module + from colossalai.legacy.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module + register_colo_module(torch.nn.Linear, ColoLinear()) register_colo_module(torch.nn.Embedding, ColoEmbedding()) @@ -104,25 +108,25 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): if type(param) is ColoParameter: continue - split = name.rfind('.') - if split >= 0: # param in submodule + split = name.rfind(".") + if split >= 0: # param in submodule module_name = name[:split] - param_name = name[split + 1:] + param_name = name[split + 1 :] else: - module_name = '' # param in current module + module_name = "" # param in current module param_name = name name_list.append((module_name, param_name)) - replaced_tensors = dict( - ) # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference + replaced_tensors = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference for module_name, param_name in name_list: submodule = module.get_submodule(module_name) param = submodule.get_parameter(param_name) if param in replaced_tensors: colo_param = replaced_tensors[param] else: - colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg, - self._default_dist_spec) + colo_param = _convert_to_coloparam( + param, self._device, self._dtype, self._default_pg, self._default_dist_spec + ) replaced_tensors[param] = colo_param delattr(submodule, param_name) setattr(submodule, param_name, colo_param) @@ -135,11 +139,11 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): for param in module.parameters(): param_number += 1 - meta_param_number += (param.device.type == 'meta') + meta_param_number += param.device.type == "meta" for buffer in module.buffers(): buffer_number += 1 - meta_buffer_number += (buffer.device.type == 'meta') + meta_buffer_number += buffer.device.type == "meta" if meta_param_number > 0 and meta_param_number != param_number: raise ValueError("Meta parameters and valued parameters can not be in the same model") @@ -151,11 +155,13 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): buffer.data = buffer.data.to(device=self._device) -def post_process_colo_init_ctx(model: torch.nn.Module, - device: torch.device = torch.device('cpu'), - dtype: torch.dtype = torch.float, - default_pg: Optional[ProcessGroup] = None, - default_dist_spec=None): +def post_process_colo_init_ctx( + model: torch.nn.Module, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None, +): """post_process_colo_init_ctx This function is called after `ColoInitContext`. @@ -177,8 +183,8 @@ def post_process_colo_init_ctx(model: torch.nn.Module, # print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter") torch_params.append((n, p)) - for (n, param) in torch_params: - name_list = n.split('.') + for n, param in torch_params: + name_list = n.split(".") module = model for i in range(len(name_list) - 1): module = module._modules[name_list[i]] diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 8a001b114e9a89fd65c45a5e2466b3d31627affc..0ba9e53cfcd6b920673ee65adbac3fb94578df36 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -2,21 +2,21 @@ import itertools from collections import OrderedDict from contextlib import nullcontext from functools import partial -from typing import Dict, Iterator, List, Optional, Union +from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union import torch import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group -from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.checkpoint_io.utils import StateDictSharder +from colossalai.interface import ModelWrapper +from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger -from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage -from colossalai.tensor import ProcessGroup as ColoProcessGroup -from colossalai.tensor import ReplicaSpec -from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec +from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.utils import get_current_device, is_ddp_ignored -from colossalai.utils.model.experimental import LazyTensor +from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager from .gemini_hook import GeminiZeROHook @@ -27,17 +27,16 @@ from .utils import get_temp_total_chunk_on_cuda try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" __all__ = [ - 'ZeroDDP', - 'GeminiDDP', + "GeminiDDP", ] -class ZeroDDP(ColoDDP): - """ZeRO DDP for ColoTensor. - Warning: Nested ZeroDDP is not supported now. +class GeminiDDP(ModelWrapper): + """ZeRO DDP. + Warning: Nested GeminiDDP is not supported now. It is designed to be used with ChunkManager and GeminiManager. For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``. @@ -51,26 +50,70 @@ class ZeroDDP(ColoDDP): strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated. Defaults to False. Users can set it to True, when they clearly know that they only need DDP. scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference. + mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16. """ - def __init__(self, - module: torch.nn.Module, - gemini_manager: GeminiManager, - pin_memory: bool = False, - force_outputs_fp32: bool = False, - strict_ddp_mode: bool = False, - scatter_after_inference: bool = True) -> None: - self.gemini_manager = gemini_manager - self.chunk_manager: ChunkManager = gemini_manager.chunk_manager + def __init__( + self, + module: torch.nn.Module, + chunk_config_dict: Optional[dict] = None, + chunk_init_device: torch.device = torch.device("cpu"), + placement_policy: str = "static", + shard_param_frac: float = 1.0, # only for static placement + offload_optim_frac: float = 0.0, # only for static placement + offload_param_frac: float = 0.0, # only for static placement + warmup_non_model_data_ratio: float = 0.8, # only for auto placement + steady_cuda_cap_ratio: float = 0.9, # only for auto placement + search_range_m: int = 32, # chunk search options + hidden_dim: Optional[int] = None, # chunk search options + min_chunk_size_m: float = 32, # chunk search options + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False, + scatter_after_inference: bool = True, + mixed_precision: torch.dtype = torch.float16, + process_group: Optional[ProcessGroup] = None, + memstats: Optional[MemStats] = None, # genimi memory stats + verbose: bool = False, + ) -> None: + assert mixed_precision in (torch.float16, torch.bfloat16) + if chunk_config_dict is not None: + self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device) + else: + # some ugly hotfix for the compatibility with Lightning + if search_range_m is None: + search_range_m = 32 + self.chunk_manager = init_chunk_manager( + model=module, + init_device=chunk_init_device, + hidden_dim=hidden_dim, + search_range_m=search_range_m, + min_chunk_size_m=min_chunk_size_m, + strict_ddp_flag=strict_ddp_mode, + process_group=process_group, + verbose=verbose, + ) + self.gemini_manager = GeminiManager( + placement_policy, + self.chunk_manager, + memstats, + shard_param_frac=shard_param_frac, + offload_optim_frac=offload_optim_frac, + offload_param_frac=offload_param_frac, + warmup_non_model_data_ratio=warmup_non_model_data_ratio, + steady_cuda_cap_ratio=steady_cuda_cap_ratio, + ) self.force_outputs_fp32 = force_outputs_fp32 - self.param_op_hook = GeminiZeROHook(gemini_manager) - self.fp32_params: List[ColoTensor] = list() + self.param_op_hook = GeminiZeROHook(self.gemini_manager) + self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.overflow_counter = 0 self.grads_device: Dict[torch.Tensor, torch.device] = dict() self.param2name: Dict[nn.Parameter, str] = dict() self.name2param: Dict[str, nn.Parameter] = dict() self.scatter_after_inference = scatter_after_inference + self.mixed_precision = mixed_precision + self.dp_process_group = process_group or _get_default_group() self._logger = get_dist_logger() @@ -84,23 +127,97 @@ class ZeroDDP(ColoDDP): for p in module.parameters(): param_order.append(p) - self._init_chunks(param_order=param_order, - strict_ddp_mode=strict_ddp_mode, - cpu_offload=self.gemini_manager.policy_name != 'cuda', - pin_memory=pin_memory) - for name, param in module.named_parameters(): self.param2name[param] = name for m_name, m_var in module.named_modules(): for p_name, p_var in m_var.named_parameters(recurse=False): - param_name = m_name + '.' + p_name if m_name else p_name + param_name = m_name + "." + p_name if m_name else p_name self.name2param[param_name] = p_var - super().__init__(module, process_group=ColoProcessGroup()) + + self._init_chunks( + param_order=param_order, + strict_ddp_mode=strict_ddp_mode, + cpu_offload=self.gemini_manager.policy_name != "cuda", + pin_memory=pin_memory, + ) + super().__init__(module) + self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module) self._cast_buffers() + # register grad hook + for p in module.parameters(): + if is_ddp_ignored(p): + continue + if p.requires_grad: + p.register_hook(partial(self.grad_handle, p)) - def _post_forward(self): - """This function is only triggered for inference. + def parameters(self, recurse: bool = True): + return self.module.parameters(recurse) + + def named_parameters(self, prefix: str = "", recurse: bool = True): + return self.module.named_parameters(prefix, recurse) + + def named_buffers(self, prefix: str = "", recurse: bool = True): + return self.module.named_buffers(prefix, recurse) + + def named_children(self): + return self.module.named_children() + + def named_modules( + self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): + return self.module.named_modules(memo, prefix, remove_duplicate) + + @staticmethod + def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None: + """Sets parameters to be ignored by DDP. + This method must be called before initializing ColoDDP. + + Example: + >>> params_to_ignore = [] + >>> for p in module.parameters(): + >>> if should_ignore(p): + >>> params_to_ignore.append(p) + >>> ColoDDP.set_params_to_ignore(params_to_ignore) + >>> module = ColoDDP(module) + + Args: + params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored. + """ + for p in params_to_ignore: + p._ddp_to_ignore = True + + def _get_non_persistent_buffers_set( + self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): + r""" + Args: + memo: a memo to store the set of modules already added to the result + prefix: a prefix that will be added to the name of the module + remove_duplicate: whether to remove the duplicated module instances in the result + or not """ + + if memo is None: + memo = set() + self_non_persistent_set = set() + if module not in memo: + if remove_duplicate: + memo.add(module) + self_non_persistent_set = set( + map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set) + ) + for name, sub_module in module._modules.items(): + if sub_module is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + child_non_persistent_set = self._get_non_persistent_buffers_set( + sub_module, memo, submodule_prefix, remove_duplicate + ) + self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set) + return self_non_persistent_set + + def _post_forward(self): + """This function is only triggered for inference.""" access_list = list(self.chunk_manager.accessed_chunks) # we need to scatter all accessed chunks and move them to their original places for chunk in access_list: @@ -117,10 +234,11 @@ class ZeroDDP(ColoDDP): # check whether we are in a inference mode grad_flag = torch.is_grad_enabled() if not grad_flag: - assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup( + assert ( + not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup() ), "You should run a completed iteration as your warmup iter" - args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) + args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision) self.module.zero_grad(set_to_none=True) if not grad_flag: outputs = self._inference_forward(*args, **kwargs) @@ -134,8 +252,7 @@ class ZeroDDP(ColoDDP): return outputs def _inference_forward(self, *args, **kwargs): - """This function is only triggered for inference. - """ + """This function is only triggered for inference.""" fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook) if not self.scatter_after_inference: # gather all chunks @@ -171,12 +288,14 @@ class ZeroDDP(ColoDDP): if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"): error_params.append(self.param2name[param]) error_str = "\n\t".join(error_params) - raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.", - "The most possible reason is that the model is not compatible with ZeroDDP.\n", - f"{error_str}") + raise RuntimeError( + "ZERO DDP error: the synchronization of gradients doesn't exit properly.", + "The most possible reason is that the model is not compatible with GeminiDDP.\n", + f"{error_str}", + ) self._setup_grads_ptr() self._logger.debug( - f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' + f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}" ) self.gemini_manager.post_iter() @@ -192,13 +311,16 @@ class ZeroDDP(ColoDDP): self._post_backward() def grad_handle(self, p, grad): + setattr(p, "_gemini_reduced", True) empty_grad = torch.empty_like(grad) free_storage(empty_grad) with torch._C.DisableTorchFunction(): chunk = self.chunk_manager.get_chunk(p) if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD: - raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " - "Some unsupported torch function is operated upon this parameter.") + raise RuntimeError( + f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " + "Some unsupported torch function is operated upon this parameter." + ) self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) chunk.copy_tensor_to_chunk_slice(p, grad) reduced = self.chunk_manager.reduce_chunk(chunk) @@ -222,12 +344,9 @@ class ZeroDDP(ColoDDP): for tensor in chunk.get_tensors(): self.grads_device[tensor] = device - def state_dict(self, - destination=None, - prefix='', - keep_vars=False, - only_rank_0: bool = True, - dtype: torch.dtype = torch.float16): + def state_dict( + self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16 + ): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. @@ -274,7 +393,7 @@ class ZeroDDP(ColoDDP): record_tensor = torch.empty([0]) record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) if record_flag: - record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() + record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cpu() assert tensor not in chunk_to_save_data chunk_to_save_data[tensor] = record_tensor @@ -282,8 +401,9 @@ class ZeroDDP(ColoDDP): del temp_chunk return chunk_to_save_data - def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool, - dtype: torch.dtype) -> Dict: + def _get_param_to_save_data( + self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype + ) -> Dict: """ get param content from chunks. @@ -342,11 +462,13 @@ class ZeroDDP(ColoDDP): destination[prefix + name] = buf if keep_vars else buf.detach() # save extra states extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + if ( + getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): destination[extra_state_key] = self.get_extra_state() - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True): r"""Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned @@ -374,32 +496,38 @@ class ZeroDDP(ColoDDP): error_msgs: List[str] = [] # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) + metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: # mypy isn't aware that "_metadata" exists in state_dict - state_dict._metadata = metadata # type: ignore[attr-defined] + state_dict._metadata = metadata # type: ignore[attr-defined] - prefix = '' + prefix = "" local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) if strict: if len(unexpected_keys) > 0: error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join( - '"{}"'.format(k) for k in unexpected_keys))) + 0, + "Unexpected key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in unexpected_keys) + ), + ) if len(missing_keys) > 0: error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) + 0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)) + ) if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs)) + ) return _IncompatibleKeys(missing_keys, unexpected_keys) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): r"""Copies parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this @@ -447,19 +575,21 @@ class ZeroDDP(ColoDDP): input_param = input_param[0] if input_param.shape != dest_tensor.shape: # local shape should match the one in checkpoint - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format(state_key, input_param.shape, - dest_tensor.shape)) + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(state_key, input_param.shape, dest_tensor.shape) + ) return try: with torch.no_grad(): copy_func(input_param) except Exception as ex: - error_msgs.append('While copying the parameter named "{}", ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(state_key, dest_tensor.size(), - input_param.size(), ex.args)) + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(state_key, dest_tensor.size(), input_param.size(), ex.args) + ) elif strict: missing_keys.append(state_key) @@ -483,30 +613,32 @@ class ZeroDDP(ColoDDP): for tensor, tensor_info in chunk.tensors_info.items(): parameter_name = fp32_to_name[tensor] - parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end] + parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) if chunk.is_gathered: chunk.cuda_global_chunk.copy_(temp_chunk) elif chunk.cuda_shard is not None: - chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end]) else: - chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end]) del temp_chunk for chunk_32 in chunk_list: chunk_16 = chunk_32.paired_chunk assert chunk_16 is not None - chunk_16.optim_update() + chunk_16.payload.copy_(chunk_32.payload) for name, buf in persistent_buffers.items(): if buf is not None: load(name, buf, buf.copy_) extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "set_extra_state", - torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state: + if ( + getattr(self.__class__, "set_extra_state", torch.nn.Module.set_extra_state) + is not torch.nn.Module.set_extra_state + ): if extra_state_key in state_dict: self.set_extra_state(state_dict[extra_state_key]) elif strict: @@ -517,64 +649,61 @@ class ZeroDDP(ColoDDP): if strict: for key in state_dict.keys(): if key.startswith(prefix) and key != extra_state_key: - input_name = key[len(prefix):] + input_name = key[len(prefix) :] if input_name not in local_state: unexpected_keys.append(key) def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool): - ddp_pg = ColoProcessGroup() + dp_world_size = dist.get_world_size(self.dp_process_group) for p in param_order.generate(): self._preprocess_param(p) assert type(p) is ColoParameter - # gather sharded parameters in the strict ddp mode - if strict_ddp_mode: - if not p.is_replicate(): - p.set_dist_spec(ReplicaSpec()) - p.set_process_group(pg=ddp_pg) - # ignore the parameters with no gradient if not p.requires_grad: self.set_params_to_ignore([p]) # move ignored parameters to CUDA if is_ddp_ignored(p): - p.data = p.data.to(device=get_current_device(), dtype=torch.float16) + p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision) continue # create a fp32 parameter - fp32_data = p.data.float() - fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) + fp32_p = p.data.float() # create a fp16 parameter - p.data = p.data.half() + p.data = p.data.to(self.mixed_precision) # register the fp16 parameter and fp32 parameter in the chunk manager - dp_world_size = p.process_group.dp_world_size() - self.chunk_manager.register_tensor(tensor=p, - group_type='fp16_param', - config_key=dp_world_size, - cpu_offload=cpu_offload, - pin_memory=pin_memory) - self.chunk_manager.register_tensor(tensor=fp32_p, - group_type='fp32_param', - config_key=dp_world_size, - cpu_offload=cpu_offload, - pin_memory=pin_memory) + self.chunk_manager.register_tensor( + tensor=p, + group_type="fp16_param", + config_key=dp_world_size, + process_group=self.dp_process_group, + cpu_offload=cpu_offload, + pin_memory=pin_memory, + ) + self.chunk_manager.register_tensor( + tensor=fp32_p, + group_type="fp32_param", + config_key=dp_world_size, + process_group=self.dp_process_group, + cpu_offload=cpu_offload, + pin_memory=pin_memory, + ) self.fp16_params.append(p) self.fp32_params.append(fp32_p) - self.grads_device[p] = self.gemini_manager.default_device self.chunk_manager.close_all_groups() + self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device) + # move master weights to corresponding device and setup paired chunks for p, fp32_p in zip(self.fp16_params, self.fp32_params): chunk_16 = self.chunk_manager.get_chunk(p) chunk_32 = self.chunk_manager.get_chunk(fp32_p) chunk_32.init_pair(chunk_16) - - # keep gathered chunks are in CUDA - if chunk_16.keep_gathered: - self.grads_device[p] = get_current_device() + if chunk_32.device_type != self.grads_device[p].type: + self.chunk_manager.move_chunk(chunk_32, self.grads_device[p]) def _cast_buffers(self): for buffer in self.module.buffers(): @@ -582,9 +711,9 @@ class ZeroDDP(ColoDDP): buffer.materialize() buffer.data = buffer.cuda() if torch.is_floating_point(buffer): - buffer.data = buffer.half() + buffer.data = buffer.to(self.mixed_precision) - def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None: + def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, "LazyTensor"]) -> None: """Convert parameter to ColoParameter in-place. Args: p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted @@ -599,12 +728,14 @@ class ZeroDDP(ColoDDP): p.__class__ = ColoParameter p.__init__(p, requires_grad=requires_grad) - def state_dict_shard(self, - prefix: str = '', - keep_vars: bool = False, - max_shard_size: int = 1024, - only_rank_0: bool = True, - dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]: + def state_dict_shard( + self, + prefix: str = "", + keep_vars: bool = False, + max_shard_size: int = 1024, + only_rank_0: bool = True, + dtype: torch.dtype = torch.float16, + ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. Both parameters and persistent buffers (e.g. running averages) are included. @@ -622,7 +753,7 @@ class ZeroDDP(ColoDDP): Yields: Iterator[OrderedDict]: A generator of state dict shard """ - sharder = _StateDictSharder(max_shard_size) + sharder = StateDictSharder(max_shard_size) # get the mapping between copies and fp16 parameters fp16_to_fp32 = dict() @@ -644,9 +775,9 @@ class ZeroDDP(ColoDDP): gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) gathered_param = gathered_param_buffer.pop(fp32_param) - block = sharder.append(prefix + name, gathered_param) + block, block_size = sharder.append_param(prefix + name, gathered_param) if block is not None: - yield block + yield block, block_size del fp16_to_fp32 del gathered_param_buffer @@ -655,93 +786,18 @@ class ZeroDDP(ColoDDP): for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block = sharder.append(prefix + name, buffer) + block, block_size = sharder.append_param(prefix + name, buffer) if block is not None: - yield block + yield block, block_size # save extra states extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + if ( + getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): extra_state = self.get_extra_state() - block = sharder.append(extra_state_key, extra_state) + block, block_size = sharder.append_param(extra_state_key, extra_state) if block is not None: - yield block - - yield sharder.current_block - - -class _StateDictSharder: - - def __init__(self, max_shard_size: int) -> None: - self.max_shard_size = max_shard_size - self.current_block = OrderedDict() - self.current_block_size = 0 - - def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]: - tensor_size = calculate_tensor_size(tensor) - ret_block = None - if self.current_block_size + tensor_size > self.max_shard_size: - ret_block = self.current_block - self.current_block = OrderedDict() - self.current_block_size = 0 - self.current_block[name] = tensor - self.current_block_size += tensor_size - return ret_block - - -class GeminiDDP(ZeroDDP): - - def __init__(self, - module: torch.nn.Module, - device: torch.device, - placement_policy: str = "cpu", - pin_memory: bool = False, - force_outputs_fp32: bool = False, - strict_ddp_mode: bool = False, - scatter_after_inference: bool = True, - search_range_mb: int = 32, - hidden_dim: Optional[int] = None, - min_chunk_size_mb: float = 32, - memstats: Optional[MemStats] = None, - verbose: bool = False) -> None: - """ - A torch.Module wrapper using ZeRO-DP and Gemini. - ZeRO is for parallel. Gemini is for memory management. - WARNING: The class will modify the module inline! + yield block, block_size - Example: - model is initialized under the context of ColoInitContext - >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda") - >>> logits = model(x) - >>> loss = criterion(logits, labels) - >>> model.backward(loss) - - Args: - module (torch.nn.Module): the model to be wrapped. - device (torch.device): device to place the model. - placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". - pin_memory (bool, optional): use pin memory on CPU. Defaults to False. - force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. - search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. - hidden_dim (int, optional): the hidden dimension of DNN. - Users can provide this argument to speed up searching. - If users do not know this argument before training, it is ok. We will use a default value 1024. - min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. - If the aggregate size of parameters is still smaller than the minimum chunk size, - all parameters will be compacted into one small chunk. - memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. - """ - # some ugly hotfix for the compatibility with Lightning - if search_range_mb is None: - search_range_mb = 32 - - chunk_manager = init_chunk_manager(model=module, - init_device=device, - hidden_dim=hidden_dim, - search_range_mb=search_range_mb, - min_chunk_size_mb=min_chunk_size_mb, - strict_ddp_flag=strict_ddp_mode, - verbose=verbose) - gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) - super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode, - scatter_after_inference) + yield sharder.current_block, sharder.current_block_size diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index dbc2924858e6371de61fafe0160ace91b4ff182f..480a14511b6928956a135200d6996d6883a10cab 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -17,7 +17,6 @@ class TrainingPhase(Enum): class GeminiZeROHook(ColoParamOpHook): - def __init__(self, gemini_manager: GeminiManager) -> None: super().__init__() self._gemini_manager = gemini_manager @@ -40,7 +39,11 @@ class GeminiZeROHook(ColoParamOpHook): def post_op(self, params): params = [p for p in params if not is_ddp_ignored(p)] for p in params: - tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD + tensor_state = ( + TensorState.HOLD + if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad + else TensorState.HOLD_AFTER_BWD + ) self._chunk_manager.trans_tensor_state(p, tensor_state) def pre_forward(self, params: List[torch.Tensor]) -> None: diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index c38e6eff840dd8f9cb2681b2524609432bc47a57..f7ff3f6cdd86ddf9eed13ffc75690f1fa8374e14 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -1,6 +1,6 @@ import functools from time import time -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch @@ -26,8 +26,13 @@ class GeminiManager: memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration. """ - def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: - + def __init__( + self, + placement_policy: str, + chunk_manager: ChunkManager, + memstats: Optional[MemStats] = None, + **placement_kwargs, + ) -> None: assert placement_policy in PlacementPolicyFactory.get_policy_names() self.policy_name = placement_policy policy_cls = PlacementPolicyFactory.create(placement_policy) @@ -35,9 +40,10 @@ class GeminiManager: self._premade_memstats_ = memstats is not None self._memstats = memstats - self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager, - self._memstats) if policy_cls.need_mem_stats else None - self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector) + self._mem_stats_collector = ( + ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None + ) + self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 @@ -58,7 +64,7 @@ class GeminiManager: @property def need_warmup(self) -> bool: - return self.policy_name in ('auto', 'const') + return self.policy_name in ("auto", "const") def is_warmup(self): return self._warmup @@ -81,15 +87,14 @@ class GeminiManager: self._mem_stats_collector.start_collection() def post_iter(self): - """This function must be called when each iteration finishes - """ + """This function must be called when each iteration finishes""" if self._mem_stats_collector and self._warmup: self._mem_stats_collector.finish_collection() self._warmup = False self.reset_attributes() def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: - """ Adjust the layout of stateful tensors according to the information provided + """Adjust the layout of stateful tensors according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. """ # find stateful tensor in state COMPUTE @@ -98,11 +103,13 @@ class GeminiManager: cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks) self._layout_time += time() - start - vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list, - cuda_demand=cuda_demand, - warmup=self._warmup, - compute_list=self._compute_list, - compute_idx=self._compute_idx) + vol, evict_time = self._placement_policy.evict_tensors( + can_evict_chunks=hold_cuda_tensor_list, + cuda_demand=cuda_demand, + warmup=self._warmup, + compute_list=self._compute_list, + compute_idx=self._compute_idx, + ) self._d2h_volume += vol self._evict_time += evict_time @@ -114,12 +121,12 @@ class GeminiManager: start = time() cuda_demand = 0 for chunk in chunks: - if chunk.device_type == 'cuda': + if chunk.device_type == "cuda": if chunk.is_gathered: pass else: cuda_demand += chunk.chunk_mem - chunk.shard_mem - elif chunk.device_type == 'cpu': + elif chunk.device_type == "cpu": cuda_demand += chunk.chunk_mem else: raise RuntimeError @@ -133,10 +140,6 @@ class GeminiManager: if self._warmup and self._placement_policy.need_mem_stats: self._compute_list.append(chunks) - @property - def default_device(self): - return self._placement_policy.get_default_device() - def sample_overall_data(self): if self._mem_stats_collector: self._mem_stats_collector.sample_overall_data() @@ -159,6 +162,7 @@ class GeminiManager: def is_cuda_margin_mem_avail(self) -> bool: return self._placement_policy.need_mem_stats - @staticmethod - def get_default_device(policy_name: str) -> torch.device: - return PlacementPolicyFactory.get_default_device(policy_name) + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: + self._placement_policy.setup_grads_device(params, grads_device_map) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 71c4f65cb8d2942e9d0022628f84234c1842c11b..1aece99541b9f2d49629bc82b81f07ea2ad6f40e 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -1,37 +1,59 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch +import copy import math import warnings -from enum import Enum -from typing import Any, Dict, Set, Tuple +from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union import torch import torch.distributed as dist +from packaging.version import Version from torch.nn import Parameter from torch.optim import Optimizer -from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin +from colossalai.checkpoint_io.utils import StateDictSharder +from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam +from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam from colossalai.utils import disposable, get_current_device, is_ddp_ignored from .chunk import Chunk, ChunkManager -from .gemini_ddp import ZeroDDP +from .gemini_ddp import GeminiDDP -__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer'] +__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"] _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} -class OptimState(Enum): - SCALED = 0 - UNSCALED = 1 +class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): + def __init__( + self, + module: GeminiDDP, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__( + initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + ) + self.module = module + + def check_local_overflow(self) -> bool: + return self.module.overflow_counter > 0 + + def pre_zero_grad(self) -> None: + self.module.overflow_counter = 0 -class ZeroOptimizer(ColossalaiOptimizer): - """A wrapper for optimizer. ``ZeroDDP`` and ``ZeroOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3). +class GeminiOptimizer(OptimizerWrapper): + """A wrapper for optimizer. ``GeminiDDP`` and ``GeminiOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3). Note: - You must use ``ZeroDDP`` with ``ZeroOptimizer``. + You must use ``GeminiDDP`` with ``GeminiOptimizer``. Note: Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`, @@ -39,7 +61,7 @@ class ZeroOptimizer(ColossalaiOptimizer): Args: optim (Optimizer): An Optimizer instance. - module (ZeroDDP): A ``ZeroDDP`` instance. + module (GeminiDDP): A ``GeminiDDP`` instance. gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) which will be used when using hybrid CPU optimizer. This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". @@ -51,51 +73,60 @@ class ZeroOptimizer(ColossalaiOptimizer): growth_interval (float, optional): Growth_interval used by DynamicGradScaler. Defaults to 1000. hysteresis (float, optional): Hysteresis used by DynamicGradScaler. Defaults to 2. max_scale (int, optional): Max_scale used by DynamicGradScaler. Defaults to 2**32. - clipping_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0. + max_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0. norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0) - is supported in ZeroOptimizer. Defaults to 2.0. + is supported in GeminiOptimizer. Defaults to 2.0. verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False. """ - def __init__(self, - optim: Optimizer, - module: ZeroDDP, - gpu_margin_mem_ratio: float = 0.0, - initial_scale: float = 2**32, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - clipping_norm: float = 0.0, - norm_type: float = 2.0, - verbose: bool = False, - **defaults: Any): + def __init__( + self, + optim: Optimizer, + module: GeminiDDP, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + verbose: bool = False, + **defaults: Any, + ): super().__init__(optim) - assert isinstance(module, ZeroDDP) - assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \ - f"{_AVAIL_OPTIM_LIST}" + assert isinstance(module, GeminiDDP) + assert type(optim) in _AVAIL_OPTIM_LIST, ( + "You should use an optimizer in the available list:\n" f"{_AVAIL_OPTIM_LIST}" + ) self.module = module self.gemini_manager = module.gemini_manager self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager - self.optim_state = OptimState.UNSCALED self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() self.param_to_chunk32: Dict[Parameter, Chunk] = dict() self.chunk16_set: Set[Chunk] = set() - self.clipping_flag = clipping_norm > 0.0 - self.max_norm = clipping_norm + self.clipping_flag = max_norm > 0.0 + self.max_norm = max_norm self.verbose = verbose + self.param_groups_backup = list() + + # Mapping from integer id to real/fake param tensor, used for checkpointing. + self.id_to_real_params: Dict[int, Parameter] = dict() + self.id_to_fake_params: Dict[int, Parameter] = dict() if self.clipping_flag: - assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" + assert norm_type == 2.0, "GeminiOptimizer only supports L2 norm now" ddp_param_list = [] for name, param in module.named_parameters(): if is_ddp_ignored(param): if param.requires_grad: - warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! " - "You should handle its optimizer update by yourself!") + warnings.warn( + f"Parameter `{name}` is ignored by DDP but requires gradient! " + "You should handle its optimizer update by yourself!" + ) else: ddp_param_list.append(param) @@ -107,24 +138,34 @@ class ZeroOptimizer(ColossalaiOptimizer): self.__init__optimizer() - # Grad scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) - self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + if module.mixed_precision is torch.float16: + self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin( + module, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) + elif module.mixed_precision is torch.bfloat16: + self.mix_precision_mixin = BF16MixedPrecisionMixin() + else: + raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}") + self._logger = get_dist_logger() self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) - assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' + assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f"gpu_margin_mem_ratio must >=0.0 and <=1.0" # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors, # and it must set `num_fp32_shards_per_param` correctly - self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr( - optim, 'num_fp32_shards_per_param', 0) >= 2 + self._should_move_fp32_params_h2d: bool = ( + self.gemini_manager.is_cuda_margin_mem_avail + and self.gpu_margin_mem_ratio > 0.0 + and getattr(optim, "num_fp32_shards_per_param", 0) >= 2 + ) if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail: self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0]) @@ -132,7 +173,7 @@ class ZeroOptimizer(ColossalaiOptimizer): def _set_grad_ptr(self): for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: chunk32 = self.param_to_chunk32[fake_param] begin, end = self.param_to_range[fake_param] chunk16 = chunk32.paired_chunk @@ -144,22 +185,13 @@ class ZeroOptimizer(ColossalaiOptimizer): def _update_fp16_params(self): none_tensor = torch.empty([0]) for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: assert fake_param.grad is None fake_param.data = none_tensor.to(fake_param.device) for chunk16 in self.chunk16_set: chunk16.optim_update() - def _check_overflow(self): - # clear previous overflow record - self._found_overflow.fill_(self.module.overflow_counter) - - # all-reduce across global group - dist.all_reduce(self._found_overflow) - - return self._found_overflow.item() > 0 - def _clear_global_norm(self) -> None: for c16 in self.chunk16_set: c16.l2_norm = None @@ -178,7 +210,7 @@ class ZeroOptimizer(ColossalaiOptimizer): group_to_norm[c16.torch_pg] = 0.0 group_to_norm[c16.torch_pg] += c16.l2_norm - c16.l2_norm = None # clear l2 norm + c16.l2_norm = None # clear l2 norm comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) for group, part_norm in group_to_norm.items(): @@ -190,51 +222,35 @@ class ZeroOptimizer(ColossalaiOptimizer): return global_norm def _get_combined_scale(self): - loss_scale = 1 + div_scale = self.mix_precision_mixin.get_grad_div_scale() - if self.optim_state == OptimState.SCALED: - loss_scale = self.loss_scale - self.optim_state = OptimState.UNSCALED - - combined_scale = loss_scale if self.clipping_flag: total_norm = self._calc_global_norm() - clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm + clip = ((total_norm / div_scale) + 1e-6) / self.max_norm if clip > 1: - combined_scale = clip * loss_scale - - if combined_scale == 1: - return -1 - else: - return combined_scale + div_scale = clip * div_scale - @property - def loss_scale(self): - return self.grad_scaler.scale.item() + return -1 if div_scale == 1.0 else div_scale def zero_grad(self, *args, **kwargs): - self.module.overflow_counter = 0 + self.mix_precision_mixin.pre_zero_grad() return self.optim.zero_grad(set_to_none=True) def step(self, *args, **kwargs): self._maybe_move_fp32_params() self._set_grad_ptr() - found_inf = self._check_overflow() - if found_inf: - self.optim_state = OptimState.UNSCALED # no need to unscale grad - self.grad_scaler.update(found_inf) # update gradient scaler + if self.mix_precision_mixin.should_skip_step(): if self.verbose: - self._logger.info(f'Found overflow. Skip step') - self._clear_global_norm() # clear recorded norm - self.zero_grad() # reset all gradients + self._logger.info(f"Found overflow. Skip step") + self._clear_global_norm() # clear recorded norm + self.zero_grad() # reset all gradients self._update_fp16_params() return # get combined scale. combined scale = loss scale * clipping norm # so that gradient = gradient / combined scale combined_scale = self._get_combined_scale() - self.grad_scaler.update(found_inf) ret = self.optim.step(div_scale=combined_scale, *args, **kwargs) self._register_states() @@ -246,8 +262,7 @@ class ZeroOptimizer(ColossalaiOptimizer): raise NotImplementedError def backward(self, loss: torch.Tensor): - loss = self.loss_scale * loss - self.optim_state = OptimState.SCALED + loss = self.mix_precision_mixin.pre_backward(loss) self.module.backward(loss) def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): @@ -255,7 +270,7 @@ class ZeroOptimizer(ColossalaiOptimizer): # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - self.optim_state = OptimState.SCALED + grad = self.mix_precision_mixin.pre_backward_by_grad(grad) self.module.backward_by_grad(tensor, grad) def _maybe_move_fp32_params(self): @@ -266,11 +281,11 @@ class ZeroOptimizer(ColossalaiOptimizer): fp32_params_used_cuda_margin_mem = 0 for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: chunk32 = self.param_to_chunk32[fake_param] chunk16 = chunk32.paired_chunk - if chunk32.device_type == 'cuda': + if chunk32.device_type == "cuda": continue if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: @@ -281,9 +296,9 @@ class ZeroOptimizer(ColossalaiOptimizer): fp32_params_used_cuda_margin_mem += chunk32.payload_mem for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: chunk32 = self.param_to_chunk32[fake_param] - if chunk32.device_type == 'cuda': + if chunk32.device_type == "cuda": state = self.optim.state[fake_param] for k, v in state.items(): if isinstance(v, torch.Tensor): @@ -291,14 +306,13 @@ class ZeroOptimizer(ColossalaiOptimizer): def _register_states_(self): for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: state = self.optim.state[p] for val in state.values(): if isinstance(val, torch.Tensor): self.chunk_manager.add_extern_static_tensor(val) def __init__optimizer(self): - def get_range_pair(local_chunk: Chunk, local_param: Parameter): param_info = local_chunk.tensors_info[local_param] if local_chunk.keep_gathered: @@ -307,29 +321,424 @@ class ZeroOptimizer(ColossalaiOptimizer): end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) return begin, end + param_id = -1 for group in self.optim.param_groups: fake_params_list = list() - - for param in group['params']: + group_backup = {k: v for k, v in group.items() if k != "params"} + group_ids = [] + for param in group["params"]: + # Record the mapping of id to current param. + param_id += 1 + self.id_to_real_params[param_id] = param + group_ids.append(param_id) + + # If current param is controlled by current process, add it to fake_param. if is_ddp_ignored(param): continue chunk16 = self.chunk_manager.get_chunk(param) range_pair = get_range_pair(chunk16, param) if range_pair[0] >= range_pair[1]: continue - grad_device = self.module.grads_device[param] fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device)) self.param_to_chunk32[fake_param] = chunk16.paired_chunk self.param_to_range[fake_param] = range_pair - + self.id_to_fake_params[param_id] = fake_param fake_params_list.append(fake_param) - group['params'] = fake_params_list - - -class GeminiAdamOptimizer(ZeroOptimizer): - + # Update self.optim.param_groups as well as backup group. + group["params"] = fake_params_list + group_backup["params"] = group_ids + self.param_groups_backup.append(group_backup) + + def get_offsets(self, param_id: int) -> tuple: + """ + Args: + param_id(int): The id of parameter. + + Returns: + chunk_offset(int): Offset of parameter inside the chunk. + shard_offset(int): Offset of its optimizer state shard + relative to the whole optimizer state. + shard_size(int): Length of parameter shard owned by current process. + """ + + if param_id not in self.id_to_fake_params: + return -1, -1, -1 + fake_param = self.id_to_fake_params[param_id] + chunk = self.param_to_chunk32[fake_param].paired_chunk + param = self.id_to_real_params[param_id] + param_info = chunk.tensors_info[param] + + begin_in_chunk, end_in_chunk = self.param_to_range[fake_param] + chunk_offset = begin_in_chunk + if chunk.keep_gathered: + shard_offset = 0 + else: + shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset + shard_size = end_in_chunk - begin_in_chunk + assert chunk_offset >= 0 and shard_offset >= 0 + return chunk_offset, shard_offset, shard_size + + def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: + """ + Args: + param_id (int): id of the parameter whose state is to be gathered at master rank. + only_rank_0(bool): if True, states will be collected only on master rank, otherwise collected on every rank. + + Returns: + collected_states(dict): the gathered optimzier state of parameter with given id + if this method is called by master rank, otherwise an empty dict. + + This method can work only when called by all processes simultaneously. + """ + + # Get param & chunk & process group. + param = self.id_to_real_params[param_id] + fake_param = self.id_to_fake_params.get(param_id, None) + chunk = self.chunk_manager.get_chunk(param) + process_group = chunk.torch_pg + rank = dist.get_rank(process_group) + master_rank = 0 + collected_states = {} + + # Fetch names of states through all_gather. + local_state_names = None + if fake_param is not None: + local_state_names = list(self.optim.state[fake_param].keys()) + gathered_state_names = [None for _ in range(dist.get_world_size(process_group))] + dist.barrier() + dist.all_gather_object(gathered_state_names, local_state_names) + state_names = None + for names in gathered_state_names: + if names is not None: + # Assume different devices share the same set of state names if they have. + state_names = copy.deepcopy(names) + break + + # Directly return if this parameter doesn't have optimizer states. + # e.g. parameter freezed/layer dropped + if state_names is None: + return collected_states + + # Boolean variable is_collector indicates that whether the current rank + # needs to gather the whole optimizer states. + # Only master rank is collector when only_rank_0 is True. + # Every rank is collector when only_rank_0 is False. + is_collector = (rank == master_rank) or (not only_rank_0) + + # If the chunk is kept gathered, + # the parameteres are treated the same as that of those in strict DDP during training. + # So states can be directly fetched from current device. + if chunk.keep_gathered: + assert param_id in self.id_to_fake_params + if is_collector: + states = self.optim.state[fake_param] + for state_name in state_names: + if state_name == "step": + # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32. + collected_states[state_name] = torch.tensor( + states["step"], dtype=torch.float32, requires_grad=False + ).cpu() + else: + state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() + collected_states[state_name] = torch.reshape(state_tensor, param.shape) + return collected_states + + # Check whether the param with given id is managed by current process. + own_param = param_id in self.id_to_fake_params + + # Collector gets prepared for state collecting. + if is_collector: + for state_name in state_names: + if state_name == "step": + # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32. + collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu() + else: + collected_states[state_name] = torch.zeros( + param.numel(), dtype=torch.float32, requires_grad=False + ).cpu() + + # Materials for gathering, including compacted state tensors, and the offset of shard inside each state. + compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None + _, shard_offset, shard_size = self.get_offsets(param_id) + + # Collectors gather state shards through all_gathering. + gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))] + + dist.barrier() + dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size]) + + if is_collector: + for state_shard in gathered_state_shards: + compacted_states = state_shard[0] + shard_offset = state_shard[1] + shard_size = state_shard[2] + if compacted_states is None: + continue + self.load_from_compacted_states( + compacted_states, collected_states, state_names, shard_offset, shard_size + ) + + # Reshape tensors + if is_collector: + for state_name, state_tensor in collected_states.items(): + if state_tensor.numel() == param.numel(): + collected_states[state_name] = torch.reshape(state_tensor, param.shape) + + return collected_states + + def pack_optimizer_states_to_tensor( + self, + param_id: int, + state_names: list, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + With param id given, pack its optimizer states into a compact tensor and return. + """ + if param_id not in self.id_to_fake_params: + return None + + fake_param = self.id_to_fake_params[param_id] + param_range = self.param_to_range[fake_param] + states = self.optim.state[fake_param] + shard_size = param_range[1] - param_range[0] + compacted_size = 0 + for name in state_names: + if name == "step": + compacted_size += 1 + else: + compacted_size += shard_size + compacted_states = torch.zeros(compacted_size, dtype=dtype, device=device, requires_grad=False) + + next_state_offset = 0 + for state_name, state_tensor in states.items(): + # State 'step' needs special operation. + if state_name == "step": + if isinstance(state_tensor, torch.Tensor): + compacted_states[next_state_offset] = state_tensor[0].item() + else: + assert isinstance(state_tensor, int) + compacted_states[next_state_offset] = state_tensor + next_state_offset += 1 + else: + assert state_tensor.numel() == shard_size + compacted_states[next_state_offset : next_state_offset + shard_size].copy_(state_tensor) + next_state_offset += shard_size + + return compacted_states + + def load_from_compacted_states( + self, + compacted_states: torch.Tensor, + collected_states: dict, + state_names: list, + shard_start: int, + shard_size: int, + ): + """ + Given a tensor carrying compacted optimizer states, + update these states to collected_states. + """ + shard_end = shard_start + shard_size + next_state_offset = 0 + + for state_name in state_names: + if state_name == "step": + collected_states["step"].data = torch.tensor( + compacted_states[next_state_offset].item(), dtype=torch.float32, requires_grad=False + ).cpu() + next_state_offset += 1 + else: + target_segment = collected_states[state_name][shard_start:shard_end] + target_segment.copy_(compacted_states[next_state_offset : next_state_offset + shard_size]) + next_state_offset += shard_size + + def get_param_groups_for_saving(self) -> list: + """ + Return the param_groups in Pytorch format when saving to checkpoint. + """ + + param_groups = copy.deepcopy(self.param_groups_backup) + + # To be compatible with pytorch checkpointing, + # store extra hyperparameters used by pytorch Adam optimizer. + torch_special_hyperparameters = { + "amsgrad": False, + "maximize": False, + "foreach": None, + "capturable": False, + "differentiable": False, + "fused": False, + } + + for group in param_groups: + for k, v in torch_special_hyperparameters.items(): + if k not in group: + group[k] = v + + return param_groups + + def state_dict(self, only_rank_0: bool = True) -> dict: + """ + Args: + only_rank_0 (bool): a boolean value indicating whether the state_dict is collected + only on rank 0, dafault to True. + + Returns: + The complete state of the optimizer as a :class:`dict`. + It contains two entries: + + * state - a dict holding current optimization state. Its content + differs between optimizer classes. + * param_groups - a list containing all parameter groups where each + parameter group is a dict. + + Warning: This method will gather and return the whole optimizer state_dict, + so it should be called only when memory resources are abundant. + """ + state_dict = {} + state_dict["param_groups"] = self.get_param_groups_for_saving() + + # Collect optimizer states. + state_dict["state"] = dict() + for param_id in self.id_to_real_params.keys(): + dist.barrier() + state_dict["state"][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) + return state_dict + + def load_param_groups(self, saved_param_groups: list): + """ + Load saved_param_groups into + self.param_groups and self.param_groups_backup + """ + self.param_groups_backup = copy.deepcopy(saved_param_groups) + + # discard the older param_groups + self.optim.param_groups = [] + + for group in saved_param_groups: + fake_params_list = list() + updated_group = {k: v for k, v in group.items() if k != "params"} + for param_id in group["params"]: + if param_id not in self.id_to_fake_params: + continue + fake_param = self.id_to_fake_params[param_id] + fake_params_list.append(fake_param) + updated_group["params"] = fake_params_list + self.optim.param_groups.append(updated_group) + + def load_single_param_states(self, param_id: int, saved_states: dict): + """ + Load saved optimizer states into parameter with given id. + """ + + def cast(param, state_range, value, key=None): + """ + Make a copy of the needed segment of value and cast it to device of param. + """ + assert isinstance(value, torch.Tensor) + ret_val = value + if key == "step": + assert value.numel() == 1 + ret_val = int(value.item()) + else: + state_start, state_end = state_range + ret_val = torch.zeros( + state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False + ) + ret_val.copy_(value.flatten()[state_start:state_end]) + return ret_val + + assert param_id in self.id_to_fake_params + fake_param = self.id_to_fake_params[param_id] + _, state_offset, param_size = self.get_offsets(param_id) + state_range = (state_offset, state_offset + param_size) + + # Copy states assigned to param (and cast tensors to appropriate types). + updated_states = dict() + for k, v in saved_states.items(): + updated_states[k] = cast(fake_param, state_range, v, k) + del v # clean loaded states + self.optim.state[fake_param].update(updated_states) + + def load_param_states(self, param_states: dict): + """Loads param states from a state_dict. The param_states can be complete or sharded. + During loading, filter out the part of states not considered by current process. + + Args: + param_states (dict): A mapping from param_id to its states. + """ + for param_id, states in param_states.items(): + if param_id in self.id_to_fake_params: + self.load_single_param_states(param_id, states) + + def optimizer_loading_epilogue(self): + # Epilogue when loading state_dict to pytorch optimizer. + if Version(torch.__version__) >= Version("2.0.0"): + self.optim._patch_step_function() # To support multiprocessing pickle/unpickle + else: + self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle. + self.optim.defaults.setdefault("differentiable", False) + + def load_state_dict(self, state_dict: dict): + """Loads optimizer state from complete optimizer state_dict. + During loading, filter out the part of states not considered by current process. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + assert "param_groups" in state_dict + assert "state" in state_dict + self.load_param_groups(state_dict["param_groups"]) + self.load_param_states(state_dict["state"]) + self.optimizer_loading_epilogue() + + def state_shard( + self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True + ) -> Iterator[Tuple[OrderedDict, int]]: + """Returns dictionaries containing shards of optimizer states one by one. + The max size of each dictionary shard is specified by ``max_shard_size``. + + Args: + prefix (str, optional): the prefix for states. Default to ''. + max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024. + only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected + only on rank 0, dafault to True. + + Yields: + Iterator[OrderedDict]: A generator of state dict shard of optimizer states. + """ + + sharder = StateDictSharder(max_shard_size) + for param_id in self.id_to_real_params.keys(): + dist.barrier() + state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) + + block, block_size = sharder.append_optim_state(param_id, state) + if block is not None: + yield block, block_size + + yield sharder.current_block, sharder.current_block_size + + def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: + raise NotImplementedError("Gemini does not support clip_grad_by_value") + + def clip_grad_by_norm( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2, + error_if_nonfinite: bool = False, + *args, + **kwargs, + ) -> torch.Tensor: + warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm") + + +class GeminiAdamOptimizer(GeminiOptimizer): def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: optimizer = HybridAdam(model.parameters(), **defaults) super().__init__(optimizer, model, **defaults) diff --git a/colossalai/zero/gemini/memory_tracer/__init__.py b/colossalai/zero/gemini/memory_tracer/__init__.py index 02c9d5754ec9a34c11531111a8fd6ca5e6698c96..cb7f626ff446b27c977748b142fd15689190a608 100644 --- a/colossalai/zero/gemini/memory_tracer/__init__.py +++ b/colossalai/zero/gemini/memory_tracer/__init__.py @@ -1,11 +1,14 @@ -from .param_runtime_order import OrderedParamGenerator # isort:skip -from .memory_stats import MemStats # isort:skip -from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip -from .memstats_collector import MemStatsCollector # isort:skip -from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip -from .static_memstats_collector import StaticMemStatsCollector # isort:skip +from .param_runtime_order import OrderedParamGenerator # isort:skip +from .memory_stats import MemStats # isort:skip +from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip +from .memstats_collector import MemStatsCollector # isort:skip +from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip __all__ = [ - 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', - 'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator' + "AsyncMemoryMonitor", + "SyncCudaMemoryMonitor", + "MemStatsCollector", + "ChunkMemStatsCollector", + "MemStats", + "OrderedParamGenerator", ] diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py index f5eb05b4f22ac71d2bb61550d64992affd67e411..b5e40a817e58ebaf8749420c1b24b79c71eb95af 100644 --- a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py @@ -1,7 +1,6 @@ from typing import Optional from colossalai.utils import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.chunk import ChunkManager from .memory_stats import MemStats @@ -9,7 +8,6 @@ from .memstats_collector import MemStatsCollector class ChunkMemStatsCollector(MemStatsCollector): - def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: """ @@ -25,12 +23,14 @@ class ChunkMemStatsCollector(MemStatsCollector): # override def record_model_data_volume(self) -> None: """ - record model data volumn on cuda and cpu. + record model data volume on cuda and cpu. """ if self._start_flag and not self.use_outside_memstats: - cuda_mem = self._chunk_manager.total_mem['cuda'] + cuda_mem = self._chunk_manager.total_mem["cuda"] self._memstats.record_max_cuda_model_data(cuda_mem) @property def cuda_margin_mem(self) -> float: + from colossalai.legacy.utils.memory import colo_device_memory_capacity + return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py index f8d99dbce7a43a8089dd1ebddd9bce6979a17f40..513a6326d5f15a56305d7ecbd3201fa616abdf66 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_monitor.py +++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py @@ -5,7 +5,7 @@ from time import sleep, time import torch -from colossalai.utils import colo_device_memory_used, get_current_device +from colossalai.utils import get_current_device class MemoryMonitor: @@ -45,7 +45,7 @@ class MemoryMonitor: class AsyncMemoryMonitor(MemoryMonitor): """ - An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU + An Async Memory Monitor running during computing. Sampling memory usage of the current GPU at interval of `1/(10**power)` sec. The idea comes from Runtime Memory Tracer of PatrickStar @@ -67,7 +67,7 @@ class AsyncMemoryMonitor(MemoryMonitor): async_mem_monitor.save('log.pkl') Args: - power (int, optional): the power of time interva. Defaults to 10. + power (int, optional): the power of time interval. Defaults to 10. .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management: https://arxiv.org/abs/2108.05818 @@ -110,6 +110,8 @@ class AsyncMemoryMonitor(MemoryMonitor): return max_usage def _measure_usage(self): + from colossalai.legacy.utils import colo_device_memory_used + max_usage = 0 while self.keep_measuring: max_usage = max( diff --git a/colossalai/zero/gemini/memory_tracer/memory_stats.py b/colossalai/zero/gemini/memory_tracer/memory_stats.py index 9a45034ee27e2cc0aeea2b100cc69fc2e2df71e7..1c141169f04576a699cbd3cbe1bed57a63176da1 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_stats.py +++ b/colossalai/zero/gemini/memory_tracer/memory_stats.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch @@ -6,10 +6,9 @@ from .param_runtime_order import OrderedParamGenerator class MemStats(object): - def __init__(self) -> None: """ - Store the non model data statistics used for Gemini and ZeroOptimizer. + Store the non model data statistics used for Gemini and GeminiOptimizer. """ # (preop_step, List[param]) self._step_param_dict = dict() @@ -59,7 +58,7 @@ class MemStats(object): time step. Args: - param_list (List[torch.nn.Parameter]): a list of torch paramters. + param_list (List[torch.nn.Parameter]): a list of torch parameters. """ for p in param_list: if p not in self._param_step_dict: @@ -92,17 +91,17 @@ class MemStats(object): return self._param_runtime_order def non_model_data_list(self, device_type: str) -> List[int]: - if device_type == 'cuda': + if device_type == "cuda": return self._non_model_data_cuda_list - elif device_type == 'cpu': + elif device_type == "cpu": return self._non_model_data_cpu_list else: raise TypeError def max_non_model_data(self, device_type: str) -> float: - if device_type == 'cuda': + if device_type == "cuda": return max(self._non_model_data_cuda_list) - elif device_type == 'cpu': + elif device_type == "cpu": return max(self._non_model_data_cpu_list) else: raise TypeError diff --git a/colossalai/zero/gemini/memory_tracer/memstats_collector.py b/colossalai/zero/gemini/memory_tracer/memstats_collector.py index 0694be48550aac735b3456e30edf4a0ddcd24e26..e4459831109a0b67216e1aea34fd93897cf3411d 100644 --- a/colossalai/zero/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/memstats_collector.py @@ -40,11 +40,12 @@ class MemStatsCollector: Returns: int: max non model data memory usage of current sampling period """ - assert not self._start_flag, 'Cannot get mem stats info during collection phase.' - assert self._step_total > 0, 'Cannot get mem stats info before collection phase.' - assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \ - f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\ + assert not self._start_flag, "Cannot get mem stats info during collection phase." + assert self._step_total > 0, "Cannot get mem stats info before collection phase." + assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, ( + f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, " f"step total {self._step_total}" + ) next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx] self._step_idx = (self._step_idx + 1) % self._step_total return next_non_model_data @@ -60,9 +61,9 @@ class MemStatsCollector: def finish_collection(self): self.sample_overall_data() # self._step_total = len(self._sampling_time) - self._step_total = len(self._memstats.non_model_data_list('cuda')) + self._step_total = len(self._memstats.non_model_data_list("cuda")) self._start_flag = False - print(f'finish_collection {self._step_total}') + print(f"finish_collection {self._step_total}") # deprecated def record_model_data_volume(self) -> None: @@ -70,10 +71,10 @@ class MemStatsCollector: Sampling model data statistics. """ if self._start_flag and not self.use_outside_memstats: - from colossalai.zero.legacy.gemini import StatefulTensor + from colossalai.legacy.zero.gemini import StatefulTensor # The following code work for ZeroInitContext, which is deprecated in v0.1.12 - cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] + cuda_mem = StatefulTensor.GST_MGR.total_mem["cuda"] self._memstats.record_max_cuda_model_data(cuda_mem) def sample_overall_data(self) -> None: diff --git a/colossalai/zero/gemini/memory_tracer/param_runtime_order.py b/colossalai/zero/gemini/memory_tracer/param_runtime_order.py index 638c0533ce926b6629906d8b113161345017295d..670edb9ec0d2a0e56b5536ac461804a634748b93 100644 --- a/colossalai/zero/gemini/memory_tracer/param_runtime_order.py +++ b/colossalai/zero/gemini/memory_tracer/param_runtime_order.py @@ -4,7 +4,6 @@ import torch class ParamGenerator(ABC): - def append(self, param: torch.nn.Parameter): pass diff --git a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py index 0c9eac8b63e3662c10e358c6653fdd524b7d5e6e..b0d258824d2bda3e92ed8a84512869eb0d1cd8aa 100644 --- a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py +++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py @@ -1,19 +1,19 @@ import torch.nn -from colossalai.nn.parallel.data_parallel import _cast_float -from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import ( +from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import ( GradMemStats, GradMemTracerHook, ParamMemTracerHook, ) +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.utils import _cast_float from .memory_stats import MemStats -__all__ = ['RuntimeMemTracer'] +__all__ = ["RuntimeMemTracer"] -class RuntimeMemTracer(): +class RuntimeMemTracer: """RuntimeMemTracer for the module training using ColoParameter. Trace non-model memory usage during fwd+bwd process. diff --git a/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py index b8f9a095f4224a500cbf630f227b9c257d2db3b2..2a1a3745f81c6f3574ae655ab8e16cac00504228 100644 --- a/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py @@ -15,9 +15,9 @@ from .chunk_memstats_collector import ChunkMemStatsCollector class ModuleInfos: - - def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str, - parent_module: torch.nn.Module): + def __init__( + self, module: torch.nn.Module, module_name: str, module_full_name: str, parent_module: torch.nn.Module + ): self.module = module self.module_name = module_name self.module_full_name = module_full_name @@ -35,14 +35,13 @@ class StaticMemStatsCollector(ChunkMemStatsCollector): self.module_info_list = [] def init_mem_stats(self, *inputs): - self.register_opnodes_recursively(self.module) self.refactor_module() self.module = self.module.cpu() self.module.train() - data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs] + data = [MetaTensor(torch.rand(inp.shape, device="meta"), fake_device="cpu") for inp in inputs] gm = symbolic_trace(self.module) interp = MetaInfoProp(gm) interp.propagate(*data) @@ -87,12 +86,13 @@ class StaticMemStatsCollector(ChunkMemStatsCollector): for modInfo in self.module_info_list: modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module) - def register_opnodes_recursively(self, - module: torch.nn.Module, - name: str = "", - full_name: str = "", - parent_module: Optional[torch.nn.Module] = None): - + def register_opnodes_recursively( + self, + module: torch.nn.Module, + name: str = "", + full_name: str = "", + parent_module: Optional[torch.nn.Module] = None, + ): assert isinstance(module, torch.nn.Module) for child_name, child in module.named_children(): diff --git a/colossalai/zero/gemini/memory_tracer/utils.py b/colossalai/zero/gemini/memory_tracer/utils.py index 6962c058110e245a6bbd2470d75c137b54202aae..9faf81af63d73b85f3c76703cdb35e8b19f7e9ce 100644 --- a/colossalai/zero/gemini/memory_tracer/utils.py +++ b/colossalai/zero/gemini/memory_tracer/utils.py @@ -7,14 +7,14 @@ def colo_model_optimizer_usage(optim) -> Tuple[int, int]: """Trace the optimizer memory usage Args: - optim (ShardedOptimV2): an instance of ShardedOptimver + optim (ShardedOptimV2): an instance of ShardedOptimizer Returns: Tuple[int, int]: cuda/cpu memory usage in Byte """ if optim is None: return 0, 0 - assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()" + assert hasattr(optim, "get_memory_usage"), f"{type(optim)} has no attr get_memory_usage()" return optim.get_memory_usage() @@ -35,16 +35,16 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]: return 0, 0 assert isinstance(t, torch.Tensor) _cpu_mem_usage, _cuda_mem_usage = 0, 0 - if t.device.type == 'cpu': + if t.device.type == "cpu": _cpu_mem_usage += t.numel() * t.element_size() - elif t.device.type == 'cuda': + elif t.device.type == "cuda": _cuda_mem_usage += t.numel() * t.element_size() return _cuda_mem_usage, _cpu_mem_usage cuda_mem_usage = 0 cpu_mem_usage = 0 for param in model.parameters(): - if hasattr(param, 'colo_attr'): + if hasattr(param, "colo_attr"): t_cuda, t_cpu = param.colo_attr.get_memory_usage() cuda_mem_usage += t_cuda cpu_mem_usage += t_cpu diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 84a868872f887a571858a44901dbd82d022dd18d..8a74eb587b83fee3c7d8062e258f47e45dfaf7d7 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -1,12 +1,14 @@ import functools +import warnings from abc import ABC, abstractmethod from time import time from typing import Dict, List, Optional, Tuple, Type import torch +from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.utils import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector @@ -15,9 +17,9 @@ from .memory_tracer import ChunkMemStatsCollector class PlacementPolicy(ABC): need_mem_stats: bool = False - def __init__(self, - chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + def __init__( + self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs + ) -> None: self.chunk_manager = chunk_manager self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector @@ -25,65 +27,102 @@ class PlacementPolicy(ABC): def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: raise NotImplementedError - @staticmethod - def get_default_device() -> torch.device: - return torch.device('cpu') - + @abstractmethod + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: + raise NotImplementedError -class CPUPlacementPolicy(PlacementPolicy): - def __init__(self, - chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: +class StaticPlacementPolicy(PlacementPolicy): + def __init__( + self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + shard_param_frac: float = 1.0, + offload_optim_frac: float = 0.0, + offload_param_frac: float = 0.0, + **kwargs, + ) -> None: super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0): + warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0") + offload_param_frac = 0.0 + self.shard_param_frac = shard_param_frac + self.offload_optim_frac = offload_optim_frac + self.offload_param_frac = offload_param_frac + # these should be initialized in setup_grads_device + self.keep_gathered_chunk_mem = 0.0 + self.keep_cuda_chunk_mem = 0.0 def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: - volume = 0 - start = time() + can_shard_chunk_mem = sum(chunk.chunk_mem for chunk in can_evict_chunks) + can_offload_chunk_mem = can_shard_chunk_mem for chunk in can_evict_chunks: + if can_shard_chunk_mem <= self.keep_gathered_chunk_mem: + break self.chunk_manager.release_chunk(chunk) - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) - volume += chunk.chunk_mem - return volume, time() - start - - -class CUDAPlacementPolicy(PlacementPolicy): - - def __init__(self, - chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: - assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' - super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) - - def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: - return 0, 0 - - @staticmethod - def get_default_device() -> torch.device: - return get_current_device() + # real saved mem is chunk_mem - shard_mem, for simplicity we use chunk_mem + can_shard_chunk_mem -= chunk.chunk_mem + for chunk in can_evict_chunks: + if can_offload_chunk_mem <= self.keep_cuda_chunk_mem: + break + self.chunk_manager.move_chunk(chunk, torch.device("cpu")) + # real saved mem is shard_mem, for simplicity we use chunk_mem + can_offload_chunk_mem -= chunk.chunk_mem + return 0, 0.0 + + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: + total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params) + + offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac + offloaded_optim_chunk_mem = 0 + chunks = set(self.chunk_manager.get_chunk(p) for p in params) + for chunk in chunks: + params = chunk.get_tensors() + # init offload optim settings + # keep gathered chunks are in CUDA + if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem: + device = get_current_device() + else: + device = torch.device("cpu") + # real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here + offloaded_optim_chunk_mem += chunk.chunk_mem + for p in params: + grads_device_map[p] = device + self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac) + self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac) class AutoPlacementPolicy(PlacementPolicy): - need_mem_stats: bool = True - # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase - # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() - # and AutoPlacementPolicy.set_steady_cuda_cap_ratio() - _warmup_non_model_data_ratio: float = 0.8 - _steady_cuda_cap_ratio: float = 0.9 - - def __init__(self, - chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: - super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) - def evict_tensors(self, - can_evict_chunks: List[Chunk], - cuda_demand: int = 0, - warmup: bool = True, - compute_list: Optional[List[Tuple[Chunk, ...]]] = None, - compute_idx: int = 0, - **kwargs) -> Tuple[int, float]: + def __init__( + self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + warmup_non_model_data_ratio: float = 0.8, + steady_cuda_cap_ratio: float = 0.9, + **kwargs, + ) -> None: + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase + # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() + # and AutoPlacementPolicy.set_steady_cuda_cap_ratio() + self._warmup_non_model_data_ratio = warmup_non_model_data_ratio + self._steady_cuda_cap_ratio = steady_cuda_cap_ratio + + def evict_tensors( + self, + can_evict_chunks: List[Chunk], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: Optional[List[Tuple[Chunk, ...]]] = None, + compute_idx: int = 0, + **kwargs, + ) -> Tuple[int, float]: """ Evict tensors from CUDA device. @@ -102,14 +141,14 @@ class AutoPlacementPolicy(PlacementPolicy): """ start = time() cuda_capacity = colo_device_memory_capacity(get_current_device()) - used_cuda_model_data = self.chunk_manager.total_mem['cuda'] + used_cuda_model_data = self.chunk_manager.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. - max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio + max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio else: # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. - max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda') - cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") + cuda_capacity *= self._steady_cuda_cap_ratio total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data freed_cuda_model_data = 0 @@ -127,11 +166,13 @@ class AutoPlacementPolicy(PlacementPolicy): break self.chunk_manager.release_chunk(chunk) - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + self.chunk_manager.move_chunk(chunk, torch.device("cpu")) freed_cuda_model_data += chunk.chunk_mem if freed_cuda_model_data < to_free_cuda_model_data: - raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! " - f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}") + raise RuntimeError( + f"Adjust layout failed! No enough CUDA memory! " + f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" + ) return freed_cuda_model_data, time() - start @staticmethod @@ -145,89 +186,23 @@ class AutoPlacementPolicy(PlacementPolicy): next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) return [t for (t, idx) in next_compute_idx] - @staticmethod - def set_warmup_non_model_data_ratio(ratio: float) -> None: - ratio = float(ratio) - assert 0.0 < ratio < 1.0 - AutoPlacementPolicy._warmup_non_model_data_ratio = ratio - - @staticmethod - def set_steady_cuda_cap_ratio(ratio: float) -> None: - ratio = float(ratio) - assert 0.0 < ratio < 1.0 - AutoPlacementPolicy._steady_cuda_cap_ratio = ratio - - -class ConstPlacementPolicy(PlacementPolicy): - - need_mem_stats: bool = False - _accessed_memory_boundary = 512 * 1024**2 - - def __init__(self, - chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: - super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) - - def evict_tensors(self, - can_evict_chunks: List[Chunk], - cuda_demand: int = 0, - warmup: bool = True, - compute_list: Optional[List[Tuple[Chunk, ...]]] = None, - compute_idx: int = 0, - **kwargs) -> Tuple[int, float]: - """ - See the docstrings in the class `AutoPlacementPolicy`. - """ - start = time() - used_accessed_memory = self.chunk_manager.accessed_mem - avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory - freed_accessed_memory = 0 - - if avail_accessed_memory < cuda_demand: - to_free_memory = cuda_demand - avail_accessed_memory - to_free_chunks = can_evict_chunks - - if not warmup: - # sort all chunks - to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list)) - - for chunk in to_free_chunks: - if freed_accessed_memory >= to_free_memory: - break - - self.chunk_manager.release_chunk(chunk) - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) - freed_accessed_memory += chunk.chunk_mem - - if freed_accessed_memory < to_free_memory: - raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! " - f"Need {to_free_memory}, freed {freed_accessed_memory}") - return freed_accessed_memory, time() - start - - @staticmethod - @functools.lru_cache(maxsize=None) - def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list: - next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks} - for i in range(len(compute_list) - 1, compute_idx, -1): - for chunk in compute_list[i]: - if chunk in next_compute_idx: - next_compute_idx[chunk] = i - next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) - return [t for (t, idx) in next_compute_idx] - - @staticmethod - def set_const_memory_boundary(cuda_memory_mb: int) -> None: - boundary = int(cuda_memory_mb * 1024**2) - assert boundary > 0 - ConstPlacementPolicy._accessed_memory_boundary = boundary + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: + for p in params: + chunk = self.chunk_manager.get_chunk(p) + # init offload optim settings + # keep gathered chunks are in CUDA + if chunk.keep_gathered: + grads_device_map[p] = get_current_device() + else: + grads_device_map[p] = torch.device("cpu") class PlacementPolicyFactory: policies: Dict[str, Type[PlacementPolicy]] = { - 'cpu': CPUPlacementPolicy, - 'cuda': CUDAPlacementPolicy, - 'auto': AutoPlacementPolicy, - 'const': ConstPlacementPolicy + "auto": AutoPlacementPolicy, + "static": StaticPlacementPolicy, } @staticmethod @@ -239,8 +214,3 @@ class PlacementPolicyFactory: @staticmethod def get_policy_names(): return tuple(PlacementPolicyFactory.policies.keys()) - - @staticmethod - def get_default_device(policy_name: str) -> torch.device: - policy_cls = PlacementPolicyFactory.create(policy_name) - return policy_cls.get_default_device() diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py index e52b5b836b0bec568e8a74d369969c0dcec0763a..264099d22de2971cc2ae1b90b7f050033247a2db 100644 --- a/colossalai/zero/gemini/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -27,16 +27,15 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk): return total_temp -def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''): - """Get a dfs module list of the given module. Its order is same as the order of creations of modules. - """ +def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ""): + """Get a dfs module list of the given module. Its order is same as the order of creations of modules.""" if memo is None: memo = set() if module not in memo: for name, submodule in module._modules.items(): if submodule is None: continue - submodule_prefix = prefix + ('.' if prefix else '') + name + submodule_prefix = prefix + ("." if prefix else "") + name for m in _get_dfs_module_list(submodule, memo, submodule_prefix): yield m @@ -60,41 +59,43 @@ def _get_shallow_copy_model(model: nn.Module): return old_to_new[model] -def get_static_torch_model(zero_ddp_model, - device=torch.device("cpu"), - dtype=torch.float32, - only_rank_0=True) -> torch.nn.Module: - """Get a static torch.nn.Module model from the given ZeroDDP module. - You should notice that the original ZeroDDP model is not modified. +def get_static_torch_model( + zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True +) -> torch.nn.Module: + """Get a static torch.nn.Module model from the given GeminiDDP module. + You should notice that the original GeminiDDP model is not modified. Thus, you can use the original model in further training. But you should not use the returned torch model to train, this can cause unexpected errors. Args: - zero_ddp_model (ZeroDDP): a zero ddp model + zero_ddp_model (GeminiDDP): a zero ddp model device (torch.device): the device of the final torch model dtype (torch.dtype): the dtype of the final torch model - only_rank_0 (bool): if True, only rank0 has the coverted torch model + only_rank_0 (bool): if True, only rank0 has the converted torch model Returns: torch.nn.Module: a static torch model used for saving checkpoints or numeric checks """ - from colossalai.zero.gemini.gemini_ddp import ZeroDDP - assert isinstance(zero_ddp_model, ZeroDDP) + from colossalai.zero.gemini.gemini_ddp import GeminiDDP + + assert isinstance(zero_ddp_model, GeminiDDP) state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0) colo_model = zero_ddp_model.module torch_model = _get_shallow_copy_model(colo_model) if not only_rank_0 or dist.get_rank() == 0: - for (name, colo_module), (_, torch_module) in \ - zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)): + for (name, colo_module), (_, torch_module) in zip( + _get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model) + ): # clean the parameter list of the new torch module torch_module._parameters = OrderedDict() for sufix_param_name, param in colo_module.named_parameters(recurse=False): # get the full name of the parameter - full_param_name = name + ('.' if name else '') + sufix_param_name - assert full_param_name in state_dict, \ - f"Can not find parameter `{full_param_name}` in the GeminiDDP module" + full_param_name = name + ("." if name else "") + sufix_param_name + assert ( + full_param_name in state_dict + ), f"Can not find parameter `{full_param_name}` in the GeminiDDP module" state_param = state_dict[full_param_name] torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype)) diff --git a/colossalai/zero/legacy/__init__.py b/colossalai/zero/legacy/__init__.py deleted file mode 100644 index 3783d38e61b27cb97e3709a0db882acfa5667821..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Tuple - -import torch -import torch.nn as nn - -from colossalai.logging import get_dist_logger - -from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator -from .shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from .sharded_model import ShardedModelV2 -from .sharded_optim import ShardedOptimizerV2 - - -def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, - optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: - """ - A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading - - :param model: Your model object - :type model: :class:`torch.nn.Module` - :param optimizer_config: Your optimizer object - :type optimizer_config: :class:`dict` - - :return: (model, optimizer) - :rtype: Tuple - """ - - logger = get_dist_logger('convert_to_zero_v2') - - logger.info(f'optimizer_config is {optimizer_config}', ranks=[0]) - if optimizer_config is None: - optimizer_config = dict() - logger.info(f'model_config is {model_config}', ranks=[0]) - if model_config is None: - model_config = dict() - - zero_model = ShardedModelV2(model, **model_config) - zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config) - return zero_model, zero_optimizer - - -__all__ = [ - 'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context', - 'no_shard_zero_decrator', 'TensorShardStrategy', 'BucketTensorShardStrategy' -] diff --git a/colossalai/zero/legacy/gemini/__init__.py b/colossalai/zero/legacy/gemini/__init__.py deleted file mode 100644 index 754ae9bc004431a793a01e1ff074fd8a1a972236..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/gemini/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .ophooks import BaseOpHook, register_ophooks_recursively -from .stateful_tensor import StatefulTensor -from .stateful_tensor_mgr import StatefulTensorMgr -from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy - -__all__ = [ - 'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy', - 'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook' -] diff --git a/colossalai/zero/legacy/gemini/gemini_context.py b/colossalai/zero/legacy/gemini/gemini_context.py deleted file mode 100644 index 9a7da6b80fbaddc43074d3599bdd0fd18548f94b..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/gemini/gemini_context.py +++ /dev/null @@ -1,48 +0,0 @@ -from enum import EnumMeta - - -class GeminiMemoryManager(object): - - def __init__(self, states_cls: EnumMeta): - super().__init__() - self.states_cls = states_cls - self._cnter = 0 # the counter of instances - - self.total_mem = dict() - self.state_mem = dict() - self.state_mem['cpu'] = dict() - self.state_mem['cuda'] = dict() - - self.reset() - - @property - def total_number(self): - return self._cnter - - def reset(self): - self._cnter = 0 # the counter of instances - - self.total_mem['cpu'] = 0 # memory occupation of instances in cpu - self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda - - # memory conditions for all states - for state in self.states_cls: - self.state_mem['cpu'][state] = 0 - self.state_mem['cuda'][state] = 0 - - def register_new_instance(self): - self._cnter += 1 - - def delete_instance(self): - self._cnter -= 1 - - def print_info(self): - print(f"Total number: {self.total_number}", - f"Total CPU memory occupation: {self.total_mem['cpu']}", - f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", - sep='\n') - - for state in self.states_cls: - print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}", - f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", - sep='\n') diff --git a/colossalai/zero/legacy/gemini/ophooks/utils.py b/colossalai/zero/legacy/gemini/ophooks/utils.py deleted file mode 100644 index 84e8298c1d5186c5c292a68126499a674f31a593..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/gemini/ophooks/utils.py +++ /dev/null @@ -1,142 +0,0 @@ -# this code is inspired by the DeepSpeed library and implemented with our own design from scratch -from abc import ABC, abstractmethod -from typing import Callable, List, Optional - -import torch - - -class BaseOpHook(ABC): - """This class allows users to add customized operations - before and after the execution of a PyTorch submodule""" - - def __init__(self): - pass - - @abstractmethod - def pre_fwd_exec(self, module: torch.nn.Module, *args): - pass - - @abstractmethod - def post_fwd_exec(self, module: torch.nn.Module, *args): - pass - - @abstractmethod - def pre_bwd_exec(self, module: torch.nn.Module, input, output): - pass - - @abstractmethod - def post_bwd_exec(self, module: torch.nn.Module, input): - pass - - @abstractmethod - def post_iter(self): - pass - - -# apply torch.autograd.Function that calls a backward_function to tensors in output -def _apply_to_tensors_only(module, functional, backward_function, outputs): - if type(outputs) is tuple: - touched_outputs = [] - for output in outputs: - touched_output = _apply_to_tensors_only(module, functional, backward_function, output) - touched_outputs.append(touched_output) - return tuple(touched_outputs) - elif type(outputs) is torch.Tensor: - return functional.apply(module, backward_function, outputs) - else: - return outputs - - -class PreBackwardFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, module, pre_backward_function, outputs): - ctx.module = module - ctx.pre_backward_function = pre_backward_function - module.applied_pre_backward = False - outputs = outputs.detach() - return outputs - - @staticmethod - def backward(ctx, *args): - ctx.pre_backward_function(ctx.module) - return (None, None) + args - - -class PostBackwardFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, module, pre_backward_function, output): - ctx.module = module - output = output.detach() - ctx.pre_backward_function = pre_backward_function - return output - - @staticmethod - def backward(ctx, *args): - """ - Args: - activation_grad of the next layer. - Returns: - grad of the input activation. - """ - ctx.pre_backward_function(ctx.module) - return (None, None) + args - - -def register_ophooks_recursively(module: torch.nn.Module, - ophook_list: List[BaseOpHook], - name: str = "", - filter_fn: Optional[Callable] = None): - r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD.""" - assert isinstance(module, torch.nn.Module) - assert isinstance(ophook_list, (list, tuple)) - assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0' - for hook in ophook_list: - assert (isinstance(hook, BaseOpHook)) - - # Add hooks for submodules - for child_name, child in module.named_children(): - register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn) - - # Early return on modules with no parameters. - if len(list(module.parameters(recurse=False))) == 0: - return - - # return from flitered module - if filter_fn is not None and filter_fn(module): - return - - def _pre_forward_module_hook(submodule, *args): - for hook in ophook_list: - assert isinstance(submodule, torch.nn.Module) - hook.pre_fwd_exec(submodule, *args) - - def _post_forward_module_hook(submodule, *args): - for hook in ophook_list: - assert isinstance(submodule, torch.nn.Module) - hook.post_fwd_exec(submodule, *args) - - def _pre_backward_module_hook(submodule, inputs, output): - - def _run_before_backward_function(submodule): - for hook in ophook_list: - assert isinstance(submodule, torch.nn.Module) - hook.pre_bwd_exec(submodule, inputs, output) - - return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output) - - def _post_backward_module_hook(submodule, inputs): - - def _run_after_backward_function(submodule): - for hook in ophook_list: - assert isinstance(submodule, torch.nn.Module) - hook.post_bwd_exec(submodule, inputs) - - return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs) - - module.register_forward_pre_hook(_pre_forward_module_hook) - module.register_forward_hook(_post_forward_module_hook) - - module.register_forward_hook(_pre_backward_module_hook) - module.register_forward_pre_hook(_post_backward_module_hook) diff --git a/colossalai/zero/legacy/init_ctx/__init__.py b/colossalai/zero/legacy/init_ctx/__init__.py deleted file mode 100644 index 0a6f81566a9de2d83561fe7d91f9052244b286b8..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/init_ctx/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator - -__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator'] diff --git a/colossalai/zero/legacy/shard_utils/__init__.py b/colossalai/zero/legacy/shard_utils/__init__.py deleted file mode 100644 index 5e5d63a7e768a470b609ccd185012864752cb432..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/shard_utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .base_shard_strategy import BaseShardStrategy -from .bucket_tensor_shard_strategy import BucketTensorShardStrategy -from .tensor_shard_strategy import TensorShardStrategy - -__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy'] diff --git a/colossalai/zero/legacy/sharded_model/__init__.py b/colossalai/zero/legacy/sharded_model/__init__.py deleted file mode 100644 index 93120bdc34b4f18f48f7975e79d956c90d9ec50c..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/sharded_model/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .sharded_model_v2 import ShardedModelV2 - -__all__ = ['ShardedModelV2'] diff --git a/colossalai/zero/legacy/sharded_model/_utils.py b/colossalai/zero/legacy/sharded_model/_utils.py deleted file mode 100644 index 2bd01531a78f517c5a431b4cb4d7e7af65024613..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/sharded_model/_utils.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Any, Callable, List, Tuple, Union - -import torch -import torch.nn.functional as F - -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor - - -def get_gradient_predivide_factor(world_size: int) -> float: - factor: int = 1 - while world_size % factor == 0 and world_size / factor > factor: - factor *= 2 - return float(factor) - - -def free_storage(data: torch.Tensor) -> None: - """Free underlying storage of a Tensor.""" - if data.storage().size() > 0: - # Since we're modifying the Tensor's Storage directly, make sure the Tensor - # is the sole occupant of the Storage. - assert data.storage_offset() == 0 - data.storage().resize_(0) - - -@torch.no_grad() -def alloc_storage(data: torch.Tensor, size: torch.Size) -> None: - """Allocate storage for a tensor.""" - if data.storage().size() == size.numel(): # no need to reallocate - return - assert data.storage().size() == 0 - data.storage().resize_(size.numel()) - - -def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor: - if isinstance(tensor, StatefulTensor): - tensor = tensor.payload - if torch.is_floating_point(tensor) and tensor.dtype is torch.float32: - return tensor.half() - return tensor - - -def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor: - if isinstance(tensor, StatefulTensor): - tensor = tensor.payload - - if torch.is_floating_point(tensor) and tensor.dtype is torch.float16: - return tensor.float() - return tensor - - -def apply_to_tensors(x: Any, fn: Callable): - if torch.is_tensor(x): - return fn(x) - elif isinstance(x, list): - return [apply_to_tensors(t, fn) for t in x] - elif isinstance(x, tuple): - return tuple(apply_to_tensors(t, fn) for t in x) - elif isinstance(x, dict): - return {key: apply_to_tensors(val, fn) for key, val in x.items()} - else: - return x - - -def cast_float_arguments(fn: Callable, *args: Any, **kwargs: Any) -> Tuple[Any, Any]: - return apply_to_tensors(args, fn), apply_to_tensors(kwargs, fn) - - -def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]: - """Chunk a given Tensor into num_chunks parts and add any necessary padding.""" - chunks = list(torch.flatten(tensor).chunk(num_chunks)) - # torch.chunk may return fewer than num_chunks chunks, pad accordingly. - num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel() - if num_pad_for_partial_chunk > 0: - chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk]) - if len(chunks) < num_chunks: - chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))]) - return chunks diff --git a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py deleted file mode 100644 index b3a83b7418250216c4b275c17fd88cb4ec779641..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py +++ /dev/null @@ -1,572 +0,0 @@ -# this code is inspired by the DeepSpeed library and implemented with our own design from scratch -import functools -import itertools -from collections import OrderedDict -from copy import deepcopy -from typing import Any, Iterator, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter - -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.utils import disposable, get_current_device -from colossalai.utils.memory import colo_device_memory_capacity -from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector -from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively -from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr -from colossalai.zero.legacy.gemini.stateful_tensor import TensorState -from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr -from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory -from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu -from colossalai.zero.legacy.shard_utils import BaseShardStrategy -from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer - -from ._utils import ( - cast_float_arguments, - cast_tensor_to_fp16, - cast_tensor_to_fp32, - chunk_and_pad, - free_storage, - get_gradient_predivide_factor, -) -from .zero_hook import ZeroHook - -try: - from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX -except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' - - -class ShardedModelV2(nn.Module): - """ - A wrapper for the PyTorch module shards the model parameters among multiple GPU memory. - Only `1/#nproc` of parameters, gradients are stored in local CUDA memory, so forward and backward - passes can be executed with limited CUDA memory budget. - - Note: - You must use ``ShardedModelV2`` with ``ShardedOptimizerV2``. - Note: - Make sure you don't use gradient accumulation and your optimizer can work with fp16 gradient and fp32 parameter, - if you enable ``reuse_fp16_shard``. - - Args: - module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`. - shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior. - process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None. - reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group. - Generally, it should be `None`, and it's the same as `process_group`. Defaults to None. - reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25. - fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False. - tensor_placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'. - If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used. - If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used. - If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. - Note that 'auto' policy can only work well when no other processes use CUDA during your training. - Defaults to 'cuda'. - gradient_predivide_factor (Optional[float], optional): Gradient is divived by this value before reduce-scatter. Defaults to 1.0. - reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad. - Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation. - In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad). - We find that PyTorch's optimizers don't support mixed precision, - so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False. - """ - - def __init__(self, - module: nn.Module, - shard_strategy: BaseShardStrategy, - process_group: Optional[ProcessGroup] = None, - reduce_scatter_process_group: Optional[ProcessGroup] = None, - reduce_scatter_bucket_size_mb: int = 25, - fp32_reduce_scatter: bool = False, - tensor_placement_policy: str = 'cuda', - gradient_predivide_factor: Optional[float] = 1.0, - reuse_fp16_shard: bool = False, - *args, - **kwargs): - assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' - super().__init__() - self.logger = get_dist_logger() - - # We force users to use ZeroInitContext - for submodule in module.modules(): - sharded_cnt = 0 - unshard_cnt = 0 - for param in submodule.parameters(recurse=False): - assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.' - if param.colo_attr.param_is_sharded: - sharded_cnt += 1 - else: - unshard_cnt += 1 - assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param' - submodule.param_is_sharded = (sharded_cnt > 0) - - self.sharded_params = [] - self.unshard_params = [] - for param in module.parameters(): - if param.colo_attr.param_is_sharded: - self.sharded_params.append(param) - else: - self.unshard_params.append(param) - - self.module = module - self.process_group = process_group or gpc.get_group(ParallelMode.DATA) - self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group - self.world_size = dist.get_world_size(self.process_group) - self.rank = dist.get_rank(self.process_group) - self.shard_strategy = shard_strategy - - self._use_memory_tracer = tensor_placement_policy == 'auto' - if self._use_memory_tracer: - self._memstats_collector = MemStatsCollector() - self._start_collect_memstats = disposable(self._memstats_collector.start_collection) - self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection) - else: - self._memstats_collector = None - self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create( - tensor_placement_policy)(mem_stats_collector=self._memstats_collector) - - if 'warmup_non_model_data_ratio' in kwargs: - if tensor_placement_policy != 'auto': - self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement') - else: - ratio = kwargs['warmup_non_model_data_ratio'] - self._tensor_placement_policy._warmup_non_model_data_ratio = ratio - self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement') - - self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy) - param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')] - self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list) - - # Register hooks - self._ophook_list = [ - ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group) - ] - register_ophooks_recursively(self.module, self._ophook_list) - self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters())) - self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook) - - self.fp32_reduce_scatter = fp32_reduce_scatter - self._cpu_offload: bool = tensor_placement_policy != 'cuda' - for param in module.parameters(): - # Init `offload_grad` - param.colo_attr.offload_grad = self._cpu_offload - - # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem - # So we use 1.0 as the default gradient_predivide_factor - # However, if you set gradient_predivide_factor to None, we will set - # gradient_predivide_factor to a value >= 1.0 automatically - self.gradient_predivide_factor: float = gradient_predivide_factor if \ - gradient_predivide_factor is not None else \ - get_gradient_predivide_factor(self.world_size) - self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor - - self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() - self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb) - self._require_backward_grad_sync: bool = True - - self._cuda_margin_space = 0 - self.reuse_fp16_shard = reuse_fp16_shard - - # record whether gradients have inf or nan - self.overflow_counter = 0 - - def adjust_stateful_tensor_layout(self) -> None: - self._stateful_tensor_mgr.adjust_layout() - - @property - def use_memory_tracer(self): - return self._use_memory_tracer - - @property - def cuda_margin_space(self): - return self._cuda_margin_space - - @property - def cpu_offload(self): - return self._cpu_offload - - def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None: - """ - dummy memory tracer collected information to a file. - try: - # forward: model(inputs) - # backward: optimizer.backward() - except Exception as e: - model.dump_memory_stats() - exit(0) - """ - if self._use_memory_tracer: - self.logger.error(f'dump memort tracer collected information to a {filename}', ranks=[0]) - if gpc.get_global_rank() == 0: - with open(filename, 'w+') as f: - f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n') - f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n') - f.write('CUDA model data (GB)\n') - f.write('\n') - f.write('CUDA non model data (GB)\n') - f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda'))) - f.write('CPU non model data (GB)\n') - f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu'))) - f.write('\n') - - def _pre_forward_operations(self, *args): - # the operation will affect the memory tracer behavior in ZeroHook - if self._memstats_collector: - self._start_collect_memstats() - - for p in self.module.parameters(): - if hasattr(p, 'colo_attr'): - p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) - - self._stateful_tensor_mgr.start_iter() - - def _post_forward_operations(self): - for p in self.module.parameters(): - if hasattr(p, 'colo_attr'): - p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) - - def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: - self._pre_forward_operations(*args) - args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs) - outputs = self.module(*args, **kwargs) - self._post_forward_operations() - return outputs - - def backward(self, loss): - loss.backward() - self._post_backward_operations() - for ophook in self._ophook_list: - ophook.post_iter() - - def backward_by_grad(self, tensor, grad): - torch.autograd.backward(tensors=tensor, grad_tensors=grad) - self._post_backward_operations() - for ophook in self._ophook_list: - ophook.post_iter() - - def _update_memstats(self): - if self._memstats_collector: - self._finish_collect_memstats() - # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used. - # the way to calculate margin space is based on the assumption that - # model data is fixed in cuda during training. - # cuda margin space can be used to store OS. - self._cuda_margin_space = colo_device_memory_capacity( - get_current_device()) - self._memstats_collector._memstats.max_overall_cuda - - @torch.no_grad() - def _post_backward_operations(self) -> None: - """ - The method includes operations required to be processed after backward - 1. update memory tracer. - 2. flush the gradient in buckets. Reducing partial gradients in each process. - 3. shard tensors not dealed in the zero hook - 4. move sharded param grad payload to param.grad - """ - # 1. update memory tracer. - self._update_memstats() - - # 2. flush the gradient in buckets. Reducing partial gradients in each process. - if self._require_backward_grad_sync: - # Flush any unreduced buckets in the post_backward stream. - with torch.cuda.stream(self.comm_stream): - self.reducer.flush() - torch.cuda.current_stream().wait_stream(self.comm_stream) - self.reducer.free() - - # 3. shard tensors not dealed in the zero hook - tensor_list = [] - for p in self.sharded_params: - if not p.colo_attr.param_is_sharded: - tensor_list.append(p.colo_attr.sharded_data_tensor) - p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD) - p.colo_attr.set_data_none() - self.shard_strategy.shard(tensor_list, self.process_group) - - # 4. set all parameters' grad to None - for p in self.module.parameters(): - if not p.requires_grad: - continue - # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass. - # NOTE() (no-sync)/sync pass: (not conduct)/conduct gradient all reducing between process group. - # If _require_backward_grad_sync is True, - # p.grad remains the accumulated unsharded gradient from prior no-sync passes. - # We also allows to interleave no-sync pass with sync passes, if desired. - if not self._require_backward_grad_sync: - continue - - p.grad = None - - @torch.no_grad() - def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]: - """ - At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the - full gradient for the local batch. The reduce-scatter op will save - a single shard of the summed gradient across all - GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example:: - - before reduce_scatter: - param.grad (GPU #0): [1, 2, 3, 4] - param.grad (GPU #1): [5, 6, 7, 8] - - after reduce_scatter: - param.grad (GPU #0): [6, 8] # 1+5, 2+6 - param.grad (GPU #1): [10, 12] # 3+7, 4+8 - - The local GPU's ``optim.step`` is responsible for updating a single - shard of params, also corresponding to the current GPU's rank. This - alignment is created by `param.colo_attr.grad`, which ensures that - the local optimizer only sees the relevant parameter shard. - """ - if grad is None: - return - assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients' - if not self._require_backward_grad_sync: - return - # used to cheat Pytorch, since we can't return None - empty_grad = torch.empty_like(grad) - free_storage(empty_grad) - # As torch didn't allow modifying grad in hook, we make a copy - grad = grad.clone() - if param.colo_attr.is_replicated: - self._reduce_scatter_handler(param, grad) - else: - self._save_grad(param, grad) - return empty_grad - - def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None: - self.comm_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(self.comm_stream): - if self.fp32_reduce_scatter: - grad.data = grad.data.to(param.dtype) - if self.gradient_predivide_factor > 1.0: - # Average grad by world_size for consistency with PyTorch DDP. - grad.data.div_(self.gradient_predivide_factor) - if self.world_size > 1: - grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size()) - self.reducer.reduce_scatter_async(grad_chunks, - group=self.reduce_scatter_process_group, - callback_fn=functools.partial(self._reduce_scatter_callback, param)) - else: - self._reduce_scatter_callback(param, grad) - torch.cuda.current_stream().wait_stream(self.comm_stream) - - def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None: - assert isinstance(reduced_grad, - torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}" - reduced_grad.data = reduced_grad.data.contiguous().view(-1) - if self.gradient_postdivide_factor > 1: - # Average grad by world_size for consistency with PyTorch DDP. - reduced_grad.data.div_(self.gradient_postdivide_factor) - self._save_grad(param, reduced_grad) - - # FIXME(ver217): refactor the below line when impl eviction policy - def _save_grad(self, param: Parameter, grad: torch.Tensor): - - # record whether we have overflow - self.overflow_counter += torch.isinf(grad).any().item() - self.overflow_counter += torch.isnan(grad).any().item() - - # move gradient to cpu - if param.colo_attr.offload_grad: - colo_model_data_move_to_cpu(grad) - - if self.reuse_fp16_shard: - # make parameters point to gradient - - assert param.colo_attr.saved_grad.is_null( - ), 'Gradien accumulation is not supported when reuse_fp16_shard=True' - - param.colo_attr.grad_payload_reset(grad.data) - # release the memory of param - # we set a false None for parameter's payload - # so we can get parameter's device and dtype later in optimizer - param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype)) - - if param.colo_attr.is_replicated: - param.colo_attr.sharded_data_tensor.is_sharded = True - else: - - fp32_grad = cast_tensor_to_fp32(grad) - - if param.colo_attr.saved_grad.is_null(): - param.colo_attr.grad_payload_reset(fp32_grad) - else: - param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload)) - - # keep saved_grad in HOLD state - param.colo_attr.saved_grad.trans_state(TensorState.HOLD) - - def parameters(self, recurse: bool = True) -> Iterator[Parameter]: - return self.module.parameters(recurse=recurse) - - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: - return self.module.named_parameters(prefix, recurse) - - def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]': - return self._colo_state_dict(destination, - prefix, - keep_vars, - shard_strategy=self.shard_strategy, - state_dict_func=nn.Module.state_dict, - module_to_load=self.module, - sharded_params=self.sharded_params, - process_group=self.process_group) - - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None: - for name, p in self.named_parameters(): - if name in state_dict: - p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype, - device=p.colo_attr.data_payload.device)) - # Force re-shard - p.colo_attr.sharded_data_tensor.is_sharded = False - self.shard_strategy.shard([p.colo_attr.sharded_data_tensor]) - elif strict: - raise RuntimeError(f'Missing key in state_dict: {name}') - - def _colo_state_dict(self, - destination=None, - prefix='', - keep_vars=False, - shard_strategy: Optional[BaseShardStrategy] = None, - state_dict_func=None, - module_to_load=None, - sharded_params=[], - process_group=None) -> 'OrderedDict[str, torch.Tensor]': - if len(sharded_params) == 0: - for param in self.parameters(): - if param.colo_attr.param_is_sharded: - sharded_params.append(param) - if shard_strategy is not None: - shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group) - for p in sharded_params: - p.data = p.colo_attr.data_payload - module_to_load = module_to_load or self - gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars) - gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()} - if shard_strategy is not None: - shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group) - for p in sharded_params: - p.colo_attr.set_data_none() - return gathered_state_dict - - def _colo_load_from_state_dict(self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - shard_strategy=None): - r"""Copies parameters and buffers from :attr:`state_dict` into only - this module, but not its descendants. This is called on every submodule - in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this - module in input :attr:`state_dict` is provided as :attr:`local_metadata`. - For state dicts without metadata, :attr:`local_metadata` is empty. - Subclasses can achieve class-specific backward compatible loading using - the version number at `local_metadata.get("version", None)`. - - .. note:: - :attr:`state_dict` is not the same object as the input - :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So - it can be modified. - - Args: - state_dict (dict): a dict containing parameters and - persistent buffers. - prefix (str): the prefix for parameters and buffers used in this - module - local_metadata (dict): a dict containing the metadata for this module. - See - strict (bool): whether to strictly enforce that the keys in - :attr:`state_dict` with :attr:`prefix` match the names of - parameters and buffers in this module - missing_keys (list of str): if ``strict=True``, add missing keys to - this list - unexpected_keys (list of str): if ``strict=True``, add unexpected - keys to this list - error_msgs (list of str): error messages should be added to this - list, and will be reported together in - :meth:`~torch.nn.Module.load_state_dict` - shard_strategy (Optional[BaseShardStrategy], optional): A shard strategy to manage shard behavior. Defaults to None. - """ - for hook in self._load_state_dict_pre_hooks.values(): - hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} - local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) - local_state = {k: v for k, v in local_name_params if v is not None} - - for name, param in local_state.items(): - key = prefix + name - if key in state_dict: - input_param = state_dict[key] - if hasattr(param, 'colo_attr'): - param.colo_attr.data_payload_reset( - input_param.to(dtype=param.colo_attr.data_payload.dtype, - device=param.colo_attr.data_payload.device)) - if shard_strategy is not None: - # Force re-shard - param.colo_attr.sharded_data_tensor.is_sharded = False - shard_strategy.shard([param.colo_attr.sharded_data_tensor]) - else: - # This is used to avoid copying uninitialized parameters into - # non-lazy modules, since they dont have the hook to do the checks - # in such case, it will error when accessing the .shape attribute. - is_param_lazy = torch.nn.parameter.is_lazy(param) - # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ - if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: - input_param = input_param[0] - - if not is_param_lazy and input_param.shape != param.shape: - # local shape should match the one in checkpoint - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format( - key, input_param.shape, param.shape)) - continue - try: - with torch.no_grad(): - param.copy_(input_param) - except Exception as ex: - error_msgs.append('While copying the parameter named "{}", ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(key, param.size(), input_param.size(), - ex.args)) - elif strict: - missing_keys.append(key) - - extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "set_extra_state", nn.Module.set_extra_state) is not nn.Module.set_extra_state: - if extra_state_key in state_dict: - self.set_extra_state(state_dict[extra_state_key]) - elif strict: - missing_keys.append(extra_state_key) - elif strict and (extra_state_key in state_dict): - unexpected_keys.append(extra_state_key) - - if strict: - for key in state_dict.keys(): - if key.startswith(prefix) and key != extra_state_key: - input_name = key[len(prefix):] - input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child - if input_name not in self._modules and input_name not in local_state: - unexpected_keys.append(key) - - def __getitem__(self, idx: int): - assert isinstance(self.module, nn.ModuleList) - return self.module[idx] - - def __len__(self): - assert isinstance(self.module, nn.ModuleList) - return len(self.module) - - def __iter__(self): - assert isinstance(self.module, nn.ModuleList) - return iter(self.module) diff --git a/colossalai/zero/legacy/sharded_model/utils.py b/colossalai/zero/legacy/sharded_model/utils.py deleted file mode 100644 index 08806e78ea3bf245e80d220e30f35029227bf144..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/sharded_model/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -import copy - -import torch - -from colossalai.zero.legacy.sharded_model import ShardedModelV2 - - -def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module): - """ - copy param of the ShardedModelV2 to other_model. - Note the other_model has to be the same as self. - """ - for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()): - assert hasattr(zero_param, 'colo_attr') - shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded - if shard_flag: - sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor]) - param.data = copy.deepcopy(zero_param.colo_attr.data_payload) - if shard_flag: - sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor]) diff --git a/colossalai/zero/legacy/sharded_optim/__init__.py b/colossalai/zero/legacy/sharded_optim/__init__.py deleted file mode 100644 index b71a70aeffa44d9f4f539bf7ee4d497a55ebd3d5..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/sharded_optim/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .sharded_optim_v2 import ShardedOptimizerV2 - -__all__ = ['ShardedOptimizerV2'] diff --git a/colossalai/zero/legacy/sharded_param/__init__.py b/colossalai/zero/legacy/sharded_param/__init__.py deleted file mode 100644 index 47e2ce2fa0e015c978b726c1fa77bb66fca91c42..0000000000000000000000000000000000000000 --- a/colossalai/zero/legacy/sharded_param/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .sharded_param import ShardedParamV2 -from .sharded_tensor import ShardedTensor - -__all__ = ['ShardedTensor', 'ShardedParamV2'] diff --git a/colossalai/zero/low_level/__init__.py b/colossalai/zero/low_level/__init__.py index ae3c1de3a5bc68bc178240729f51fde0ba8a6222..270a6a6a478650ced1a2c74a83302918bf575014 100644 --- a/colossalai/zero/low_level/__init__.py +++ b/colossalai/zero/low_level/__init__.py @@ -1,3 +1,3 @@ from .low_level_optim import LowLevelZeroOptimizer -__all__ = ['LowLevelZeroOptimizer'] +__all__ = ["LowLevelZeroOptimizer"] diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index afc98e7a7f54a7b4125eb989790c521fad3c86a7..0a15f8ddd718bef558e4958caa2e7a81c7ae3470 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -3,11 +3,9 @@ from typing import Optional import torch import torch.distributed as dist -from torch import inf +from torch import Tensor, inf from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - -from colossalai.tensor import ColoParameter -from colossalai.utils import is_model_parallel_parameter +from torch.distributed import ProcessGroup def flatten(input_): @@ -46,8 +44,8 @@ def shuffle_by_round_robin(tensor_list, num_partitions): for partition_id in range(partitions_count): partition_tensors = partitions[partition_id] for item in partition_tensors: - tensor_index_mapping[item['index']] = len(new_tensor_list) - new_tensor_list.append(item['tensor']) + tensor_index_mapping[item["index"]] = len(new_tensor_list) + new_tensor_list.append(item["tensor"]) return new_tensor_list, tensor_index_mapping @@ -109,11 +107,13 @@ def split_by_dtype(tensor_list): return buckets -def reduce_tensor_dp_group(tensor: torch.Tensor, - dtype: Optional[torch.dtype] = None, - dst_local_rank: Optional[int] = None, - dst_global_rank: Optional[int] = None, - group: Optional[dist.ProcessGroup] = None): +def reduce_tensor_dp_group( + tensor: torch.Tensor, + dtype: Optional[torch.dtype] = None, + dst_local_rank: Optional[int] = None, + dst_global_rank: Optional[int] = None, + group: Optional[dist.ProcessGroup] = None, +): """ Reduce the tensor in the data parallel process group @@ -175,7 +175,7 @@ def has_inf_or_nan(tensor): raise return True else: - if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: + if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum: return True return False @@ -186,34 +186,28 @@ def release_param_grad(tensor_list): def calculate_global_norm_from_list(norm_list): - """ Compute total from a list of norms - """ + """Compute total from a list of norms""" total_norm = 0.0 for norm in norm_list: total_norm += norm**2.0 return math.sqrt(total_norm) -def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): +def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int: """Clips gradient norm of an iterable of parameters. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. - Arguments: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. + added functionality to handle model parallel parameters. + + Args: + gradients (Tensor): The gradients to compute norm + dp_group (ProcessGroup): The process group of ZeRO Data Parallelism + tp_group (ProcessGroup): The process group of Tensor Parallelism + norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2. + Returns: - Total norm of the parameters (viewed as a single vector). + int: The total norm of given gradients """ - if mp_group is None: - mp_rank = 0 - else: - mp_rank = dist.get_rank(mp_group) - norm_type = float(norm_type) if norm_type == inf: total_norm = max(g.data.abs().max() for g in gradients) @@ -221,39 +215,31 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group) # Take max across all GPUs. - if mp_group is not None: + if tp_group is not None: dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX) total_norm = total_norm_cuda[0].item() else: total_norm = 0.0 - # if dist.get_rank() == 0: - # logger.info(f"Total Norm beginning {total_norm}") - - for g, p in zip(gradients, params): - # Pipeline parallelism may replicate parameters. Avoid multi-counting. - tp_param_flag = False - if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()): - tp_param_flag = True - if tp_param_flag or mp_rank == 0: - param_norm = g.data.double().norm(2) - total_norm += param_norm.item()**2 + for g in gradients: + param_norm = g.data.double().norm(norm_type) + total_norm += param_norm.item() ** norm_type # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group) - if mp_group is not None: - dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group) + if tp_group is not None: + dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type) - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: + if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm: total_norm = -1 return total_norm -def sync_param(flat_tensor, tensor_list): +def sync_tensor(flat_tensor, tensor_list): """ Synchronize the flattened tensor and unflattened tensor list. When a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`, @@ -261,7 +247,7 @@ def sync_param(flat_tensor, tensor_list): share the same memory space. This function will update the tensor list so that they point to the same value. - :param flat_tensor: A flat tensor obtained by calling `torch._utils._unflatten_dense_tensors` on a tensor lsit + :param flat_tensor: A flat tensor obtained by calling `torch._utils._unflatten_dense_tensors` on a tensor list :param tensor_list: A list of tensors corresponding to the flattened tensor :type flat_tensor: torch.Tensor :type tensor_list: List[torch.Tensor] diff --git a/colossalai/zero/low_level/bookkeeping/__init__.py b/colossalai/zero/low_level/bookkeeping/__init__.py index 7bcacfabfded39972babff0536cb75b0c2c65506..427973772f9cf107d694590695c3366ff5554685 100644 --- a/colossalai/zero/low_level/bookkeeping/__init__.py +++ b/colossalai/zero/low_level/bookkeeping/__init__.py @@ -3,4 +3,4 @@ from .gradient_store import GradientStore from .parameter_store import ParameterStore from .tensor_bucket import TensorBucket -__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket'] +__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"] diff --git a/colossalai/zero/low_level/bookkeeping/base_store.py b/colossalai/zero/low_level/bookkeeping/base_store.py index 2ebd122464f4ef46976d7b7e7e45323aea322b2d..107d62dcbc0e5cf132073f7272f3fa61b1f98110 100644 --- a/colossalai/zero/low_level/bookkeeping/base_store.py +++ b/colossalai/zero/low_level/bookkeeping/base_store.py @@ -3,7 +3,6 @@ from torch.distributed import ProcessGroup class BaseStore: - def __init__(self, torch_pg: ProcessGroup): self._world_size = dist.get_world_size(group=torch_pg) self._local_rank = dist.get_rank(group=torch_pg) diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index ec322a78bf81a21d9b5e08b6c5bd9c44ca06f3f0..2828d517573da964f6651f0a5a2938e05aa4f8cd 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -1,41 +1,127 @@ +from typing import Dict + +import torch +from torch import Tensor +from torch._utils import _flatten_dense_tensors from torch.distributed import ProcessGroup from .base_store import BaseStore class BucketStore(BaseStore): - def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) - self._params = dict() - self._num_elements_in_bucket = dict() - self.reset() + # init + self.current_group_id = 0 + self._num_elements_in_bucket = 0 + # mapping gardient slices and parameter + self.grad_to_param_mapping = dict() + + self._grad_in_bucket = dict() + self._param_list = [] + self._padding_size = [] + for rank in range(self._world_size): + self._grad_in_bucket[rank] = [] + + # offset_list records number of tensors in the bucket before each reduction + self.offset_list = [0] + + def num_elements_in_bucket(self) -> int: + """Return the total number of elements in bucket + + Returns: + int: the total number of elements in bucket + """ + + return self._num_elements_in_bucket + + def reset_num_elements_in_bucket(self): + """Set the number of elements in bucket to zero.""" + + self._num_elements_in_bucket = 0 + + def add_param_grad(self, group_id: int, param: Tensor, padding_size: int): + """Add a param to bucket and record the padding size of a param for gradient padding + + Args: + group_id (int): The index of a parameter group + param (Tensor): The parameter + padding_size (int): The padding size of the parameter + """ + + self._param_list.append(param) + self._padding_size.append(padding_size) + self._num_elements_in_bucket += param.numel() + padding_size + self.current_group_id = group_id + + # number of tensors in current bucket + self.offset_list[-1] += 1 + + def build_grad_in_bucket(self): + """Orgnize parameters' gradient(padding and split), follows the paramters' splitting method + + Data structure of self._grad_in_bucket: + { + rank0: [grad0_rank0, grad1_rank0, ...] + rank1: [grad0_rank1, grad1_rank1, ...] + } + """ + for param, padding_size in zip(self._param_list, self._padding_size): + grad = param.grad.clone().detach().flatten() + if padding_size > 0: + with torch.no_grad(): + grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size]) + grad_list = grad.split(grad.numel() // self._world_size) + for rank in range(self._world_size): + grad_current_rank = grad_list[rank].clone().detach() + self.grad_to_param_mapping[id(grad_current_rank)] = id(param) + self._grad_in_bucket[rank].append(grad_current_rank) + param.grad = None + + self.offset_list.append(0) + + def get_grad(self) -> Dict: + """Return the dictionary of gradients slices, of which the keys are ranks + + Returns: + Dict: The dictionary of gradients slices + """ + + return self._grad_in_bucket + + def get_flatten_grad(self) -> Tensor: + """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor: + [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....] + + Returns: + Tensor: the flattened gradients slices in the bucket + """ + + flat_grad = [] + for grad_list in self._grad_in_bucket.values(): + flat_grad.append(_flatten_dense_tensors(grad_list)) + flat_grad = _flatten_dense_tensors(flat_grad) + return flat_grad + + def get_param_id_of_grad(self, grad: Tensor) -> int: + """Return the id of a parameter which the gradient slice belongs to - def num_elements_in_bucket(self, reduce_rank: int = None): - return self._num_elements_in_bucket[reduce_rank] + Args: + grad (Tensor): the gradient slice - def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None): - self._num_elements_in_bucket[reduce_rank] += num_elements + Returns: + int: the id of a parameter which the gradient slice belongs to + """ - def add_param(self, tensor, reduce_rank: int = None): - self._params[reduce_rank].append(tensor) + return self.grad_to_param_mapping[id(grad)] def reset(self): - keys = [None] + list(range(self._world_size)) - self._params = {rank: [] for rank in keys} - self._num_elements_in_bucket = {rank: 0 for rank in keys} - - def reset_by_rank(self, reduce_rank=None): - self._params[reduce_rank] = [] - self._num_elements_in_bucket[reduce_rank] = 0 - - def get_grad(self, reduce_rank: int = None): - param_list = self.get_param(reduce_rank) - for param in param_list: - # the param must have grad for reduction - assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced' - return [param.grad for param in param_list] - - def get_param(self, reduce_rank: int = None): - return self._params[reduce_rank] + """Reset the bucket storage after reduction, only release the tensors have been reduced""" + cur_offset = self.offset_list.pop(0) + self._param_list = self._param_list[cur_offset:] + self._padding_size = self._padding_size[cur_offset:] + for _ in range(cur_offset): + del self.grad_to_param_mapping[next(iter(self.grad_to_param_mapping))] + for rank in range(self._world_size): + self._grad_in_bucket[rank] = self._grad_in_bucket[rank][cur_offset:] diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 942d7186e55f53f97211491025c759a7bed18cf9..3ce688cfa930022b9fc3e05052cd698284d72467 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -6,83 +6,85 @@ from .base_store import BaseStore class GradientStore(BaseStore): - - def __init__(self, *args): + def __init__(self, *args, partition_grad: bool = False): super().__init__(*args) - # bookkeeping data structures - self._averaged_gradients = dict() - - # for backward reduction hooks - self._grad_acc_objs = [] - - def append_accumulate_grad_object(self, obj): """ - Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not - be attached successfully. - - :param obj: An object of :class:`AccumulateGrad` class - :type obj: :class:`AccumulateGrad` + self._grads_of_params mapping the paramater and its gradient slices + data structure: + { + group_id:{ + param_id: [grad_rank0, grad_rank1, ...] + } + } """ + self._grads_of_params = dict() + # for zero2, it's `param_id: [grad_local_rank]` + self._working_index = 0 if partition_grad else self._local_rank - self._grad_acc_objs.append(obj) + def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: + """Return list of gradient slices of a specific parameter - def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]: - """ - Return average gradients of a parameter group - - :param group_id: The index of parameter group - :type group_id: int + Args: + group_id (int): The index of a parameter group + param_id (int): The id of a parameter - :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter. - :rtype: List[torch.Tensor] + Returns: + List: the list of gradient slices of a parameter. """ - if group_id not in self._averaged_gradients: - self._averaged_gradients[group_id] = [] - return self._averaged_gradients[group_id] - - def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None: - """ - Append an average gradient to the list of averaged gradients of a parameter group + if group_id in self._grads_of_params: + if param_id in self._grads_of_params[group_id]: + return self._grads_of_params[group_id][param_id] + # the param has no grad, for instance, in layer drop + return [] - :param group_id: The index of a parameter group - :param tensor: A :class:`torch.Tensor` object - :type group_id: int - :type tensor: torch.Tensor + def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int): + """Append a gradient slice to the parameter's gradient slice list + Args: + grad (Tensor): The gradient slice to append to list + group_id (int): The index of a parameter group + param_id (int): The id of a parameter """ - if group_id in self._averaged_gradients: - self._averaged_gradients[group_id].append(tensor) + if group_id not in self._grads_of_params: + self._grads_of_params[group_id] = dict() + if param_id not in self._grads_of_params[group_id]: + self._grads_of_params[group_id][param_id] = [grad] else: - self._averaged_gradients[group_id] = [tensor] + self._grads_of_params[group_id][param_id].append(grad) - def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None: + def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int): + """Add a gradient slice on an existing slice of the parameter's gradient + Used when no_sync is not activated. + + Args: + grad (Tensor): The split gradient to append to list + grad_idx (int): The index of the existing slice + group_id (int): The index of a parameter group + param_id (int): The id of a parameter """ - Add an average gradient to the list of averaged gradients of a parameter group - :param group_id: The index of a parameter group - :param tensor_idx: The index of a tensor in the list of averaged gradients - :param tensor: A :class:`torch.Tensor` object - :type group_id: int - :type tensor_idx: int - :type tensor: torch.Tensor + self._grads_of_params[group_id][param_id][grad_idx].add_(grad) - """ - self._averaged_gradients[group_id][tensor_idx].add_(tensor) + def get_working_grads_by_group_id(self, group_id: int) -> List: + """Return list of working gradient slices in the group - def reset_average_gradients_by_group(self, group_id: int) -> None: - """ - Reset the bookkeeping data structure for averaged gradients to an empty list + Args: + group_id (int): The index of a parameter group - :param group_id: The index of a parameter group - :type group_id: int + Returns: + List: the list working gradient slices in the group """ - self._averaged_gradients[group_id] = [] + grad_list = [] + for param_grads in self._grads_of_params[group_id].values(): + grad_list.append(param_grads[self._working_index]) - def reset_all_average_gradients(self) -> None: - """ - Reset the bookkeeping data structure for averaged gradients to an empty list - """ - self._averaged_gradients = dict() + return grad_list + + def reset_grads_by_group_id(self, group_id: int): + self._grads_of_params[group_id] = dict() + + def reset_all_gradients(self): + self._grads_of_params = dict() diff --git a/colossalai/zero/low_level/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py index 1f3ba7cbc3bc75c029072a11648628e7b866aa2d..e94fb4de9b9f9f950273efb73a14c85488cca843 100644 --- a/colossalai/zero/low_level/bookkeeping/parameter_store.py +++ b/colossalai/zero/low_level/bookkeeping/parameter_store.py @@ -1,5 +1,3 @@ -from typing import List - from torch import Tensor from torch.distributed import ProcessGroup @@ -7,91 +5,45 @@ from .base_store import BaseStore class ParameterStore(BaseStore): - def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) - # param partitioning data structures - self._param_to_rank = dict() - self._rank_group_id_to_param_list = dict() - self._rank_group_id_to_flat_param = dict() - # param reduction data structures - self._is_param_reduced = dict() - self._reduced_param = [] + # record the padding size of each param + self._padding_map = dict() - def set_param_to_rank(self, tensor: Tensor, rank: int) -> None: - """ - Set the mapping between parameter to rank, each parameter should be owned by a rank. - - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor - :param rank: The rank of which the process is responsible for updating the parameter - :type rank: int - """ + # mapping working param and master param + self.master_to_working_param = dict() + self.working_to_master_param = dict() - self._param_to_rank[tensor] = rank + def record_param_padding_size(self, param: Tensor, padding_size: int): + """Record the padding size of a param - def get_param_rank(self, tensor: Tensor) -> int: + Args: + param (Tensor): The parameter + padding_size (int): The padding size of the parameter """ - Gives the rank which the parameter belongs to - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor - """ - return self._param_to_rank[tensor] + self._padding_map[id(param)] = padding_size - def belongs_to_current_rank(self, tensor) -> bool: - """ - Check whether a parameter is supposed to be updated by the process of the current rank + def get_param_padding_size(self, param: Tensor) -> int: + """Return the padding size of the parameter - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor + Args: + param (Tensor): The parameter - :return: True if the parameter should be updated by the current rank. Otherwise false. - :rtype: bool + Returns: + int: the padding size of the parameter """ - tensor_rank = self._param_to_rank[tensor] - return tensor_rank == self._local_rank - - def add_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None: - if rank not in self._rank_group_id_to_param_list: - self._rank_group_id_to_param_list[rank] = dict() - - if group_id not in self._rank_group_id_to_param_list[rank]: - self._rank_group_id_to_param_list[rank][group_id] = [] + return self._padding_map[id(param)] - self._rank_group_id_to_param_list[rank][group_id].extend(tensor_list) + def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor): + """Mapping master parameter and working parameter - def get_params_by_rank_group(self, rank, group_id) -> List[Tensor]: - return self._rank_group_id_to_param_list[rank][group_id] - - def add_flat_param_by_rank_group(self, rank, group_id, tensor) -> None: - if rank not in self._rank_group_id_to_flat_param: - self._rank_group_id_to_flat_param[rank] = dict() - - self._rank_group_id_to_flat_param[rank][group_id] = tensor - - def get_flat_param_by_rank_group(self, rank, group_id) -> Tensor: - return self._rank_group_id_to_flat_param[rank][group_id] - - def is_param_reduced(self, tensor): - return self._is_param_reduced[tensor] - - def set_param_reduction_state(self, tensor, state): - self._is_param_reduced[tensor] = state - - def get_param_reduction_states(self): - return self._is_param_reduced - - def reset_previous_reduced_params(self): - self._reduced_param = [] - - def add_previous_reduced_param(self, tensor): - self._reduced_param.append(tensor) + Args: + master_param (Tensor): The parameter copy in optimizer + working_param (Tensor): The parameter of the model + """ - def clear_grads_of_previous_reduced_params(self): - if len(self._reduced_param) > 0: - for param in self._reduced_param: - param.grad = None - self.reset_previous_reduced_params() + self.master_to_working_param[id(master_param)] = working_param + self.working_to_master_param[id(working_param)] = master_param diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py index b32816a046cd6a156196e84e957e070c4401d555..16ba8a6d644504c9a494b310a63a491a31f00fba 100644 --- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -2,7 +2,6 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors class TensorBucket: - def __init__(self, size): self._max_size = size self._current_size = 0 @@ -26,8 +25,7 @@ class TensorBucket: tensor_size = tensor.numel() if not allow_oversize and self.will_exceed_max_size(tensor_size): - msg = f"The param bucket max size {self._max_size} is exceeded" \ - + f"by tensor (size {tensor_size})" + msg = f"The param bucket max size {self._max_size} is exceeded" + f"by tensor (size {tensor_size})" raise RuntimeError(msg) self._bucket.append(tensor) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 3e7661ecab769768c5ba0be00f9462d57b696f39..72df93ace302047fdadb9dada36a8830c43df415 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -1,17 +1,24 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch +import copy +from contextlib import contextmanager from functools import partial -from typing import Optional +from typing import Dict, Iterator, Optional, Tuple import torch import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup from torch.optim import Optimizer -from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.amp.naive_amp.mixed_precision_mixin import ( + BF16MixedPrecisionMixin, + FP16MixedPrecisionMixin, + MixedPrecisionMixin, +) +from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.tensor import ColoParameter, ProcessGroup + +# from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device from ._utils import ( @@ -19,45 +26,65 @@ from ._utils import ( compute_norm, flatten, has_inf_or_nan, - reduce_tensor_dp_group, release_param_grad, - split_by_dtype, - sync_param, + sync_tensor, ) -from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket +from .bookkeeping import BucketStore, GradientStore, ParameterStore + + +class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): + def __init__( + self, + num_working_param_groups: int, + grad_store: GradientStore, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__( + initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + ) + self.num_working_param_groups = num_working_param_groups + self.grad_store = grad_store + + def check_local_overflow(self) -> bool: + for group_id in range(self.num_working_param_groups): + for avg_grad in self.grad_store.get_working_grads_by_group_id(group_id): + if avg_grad is not None and has_inf_or_nan(avg_grad): + return True + return False -class LowLevelZeroOptimizer(ColossalaiOptimizer): - """Optimizer used for ZeRO-1 and ZeRO-2. - """ +class LowLevelZeroOptimizer(OptimizerWrapper): + """Optimizer used for ZeRO-1 and ZeRO-2.""" def __init__( - self, - optimizer: Optimizer, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2., - backoff_factor: float = .5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = False, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - forced_dtype: Optional[torch.dtype] = None): - - # TODO: add support for - # 1. fp16 master weights - # 2. contiguous gradients - # 3. cpu offload - # 4. support when some parameters requires_grad = False - # 5. support layer drop + self, + optimizer: Optimizer, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = False, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None, + ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) - self._dtype = self.optim.param_groups[0]['params'][0].dtype + self._dtype = self.optim.param_groups[0]["params"][0].dtype self._logger = get_dist_logger() self._verbose = verbose @@ -66,57 +93,31 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): self._cpu_offload = cpu_offload - colo_pg = self._search_colo_process_group() - if isinstance(colo_pg, ProcessGroup): - self._local_rank = colo_pg.dp_local_rank() - self._world_size = colo_pg.dp_world_size() - self._dp_global_ranks = colo_pg.get_ranks_in_dp() - self._dp_torch_group = colo_pg.dp_process_group() - self._mp_torch_group = None - if colo_pg.tp_world_size() > 1: - self._mp_torch_group = colo_pg.tp_process_group() - elif colo_pg is None: - dp_parallel_mode = ParallelMode.DATA - mp_parallel_mode = ParallelMode.MODEL - - self._dp_parallel_mode = dp_parallel_mode - self._mp_parallel_mode = mp_parallel_mode - self._local_rank = gpc.get_local_rank(dp_parallel_mode) - self._world_size = gpc.get_world_size(dp_parallel_mode) - self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode) - self._dp_torch_group = gpc.get_group(dp_parallel_mode) - self._mp_torch_group = None - if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1: - self._mp_torch_group = gpc.get_group(mp_parallel_mode) - else: - raise NotImplementedError + # grad accumulation + self.require_grad_sync = True + + # if process_group is none, will use the default one + self.dp_pg = dp_process_group + self._local_rank = dist.get_rank(group=self.dp_pg) + self._world_size = dist.get_world_size(group=self.dp_pg) + + self.tp_pg = tp_process_group # working and master params for mixed precision training self._working_param_groups = dict() - self._master_flat_param_groups_of_current_rank = dict() + self._master_param_groups_of_current_rank = dict() # communication params self._overlap_communication = overlap_communication self._reduce_bucket_size = reduce_bucket_size self._communication_dtype = communication_dtype - # gradient scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - verbose=verbose) - self._found_overflow = torch.FloatTensor([0]).to(get_current_device()) - # gradient clipping self._clip_grad_norm = clip_grad_norm if forced_dtype: for group in self.optim.param_groups: - group_params = group['params'] + group_params = group["params"] for param in group_params: param.data = param.data.to(forced_dtype) self._dtype = forced_dtype @@ -126,68 +127,30 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # ParameterStore will manage the tensor buffers used for zero # it will not manage the tensors used by mixed precision training - self._param_store = ParameterStore(self._dp_torch_group) - self._grad_store = GradientStore(self._dp_torch_group) - self._bucket_store = BucketStore(self._dp_torch_group) + self._param_store = ParameterStore(self.dp_pg) + self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad) + self._bucket_store = BucketStore(self.dp_pg) # iterate over the param group in the optimizer # partition these param groups for data parallel training # and add buffers to parameter store for future access for group_id, param_group in enumerate(self.optim.param_groups): group_params = list() - for param in param_group['params']: + for param in param_group["params"]: if param.requires_grad: group_params.append(param) # add the working params to working_param_groups for bookkeeping self._working_param_groups[group_id] = group_params - # assign parameters to ranks - # the params in the list are sorted - params_per_rank = self._partition_param_list(group_params) - - # store the mapping between param to rank - # each param should belong to only one rank - for rank, params in enumerate(params_per_rank): - self._param_store.add_param_list_by_rank_group(rank, group_id, params) - for param in params: - self._param_store.set_param_to_rank(param, rank) + master_param_current_rank = self._create_master_param_current_rank(group_params) - # move to cpu to make room to create the flat tensor - # move_tensor(params, device='cpu') - for param in group_params: - param.data = param.data.cpu() - - # flatten the reordered tensors - for rank in range(self._world_size): - tensor_list = self._param_store.get_params_by_rank_group(rank, group_id) - with torch.no_grad(): - flat_tensor = flatten(tensor_list) - flat_tensor = flat_tensor.data.cuda() - self._param_store.add_flat_param_by_rank_group(rank, group_id, flat_tensor) - - # sync parameters - for rank in range(self._world_size): - flat_tensor = self._param_store.get_flat_param_by_rank_group(rank, group_id) - tensor_list = self._param_store.get_params_by_rank_group(rank, group_id) - sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list) - - # create a copy of fp32 master weights of the parameters for which this rank is responsible - working_flat_current_rank = self._param_store.get_flat_param_by_rank_group(self._local_rank, group_id) - master_flat_current_rank = working_flat_current_rank.float() - device = 'cpu' if self._cpu_offload else get_current_device() - master_flat_current_rank = master_flat_current_rank.to(device) - master_flat_current_rank.requires_grad = True - self._master_flat_param_groups_of_current_rank[group_id] = master_flat_current_rank + self._master_param_groups_of_current_rank[group_id] = master_param_current_rank # need to replace the params in the `params` field in the optimizer # so that when the optimizer calls step(), it only updates the tensors # managed by this data parallel rank - param_group['params'] = [master_flat_current_rank] - - # set reduction state - for param in self._working_param_groups[group_id]: - self._param_store.set_param_reduction_state(param, False) + param_group["params"] = master_param_current_rank # intialize communication stream for # communication-compuation overlapping @@ -200,65 +163,70 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): if self._overlap_communication or self._partition_grads: self._attach_reduction_hook() + # initialize mixed precision mixin + self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None + if self._dtype is torch.float16: + self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin( + self.num_param_groups, + self._grad_store, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) + elif self._dtype is torch.bfloat16: + self.mixed_precision_mixin = BF16MixedPrecisionMixin() + @property def dtype(self): return self._dtype - @property - def loss_scale(self): - return self.grad_scaler.scale - @property def num_param_groups(self): return len(self._working_param_groups) def _sanity_checks(self): - assert torch.cuda.is_available(), 'CUDA is required' + assert torch.cuda.is_available(), "CUDA is required" for param_group in self.optim.param_groups: - group_params = param_group['params'] + group_params = param_group["params"] for param in group_params: - assert param.dtype == self._dtype, \ - f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + assert ( + param.dtype == self._dtype + ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + + def _create_master_param_current_rank(self, param_list): + # split each param evenly by world size + params_current_rank = [] + device = "cpu" if self._cpu_offload else get_current_device() + + for param in param_list: + padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size + self._param_store.record_param_padding_size(param, padding_size) + + with torch.no_grad(): + if padding_size > 0: + padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padding_param = param.data.view(-1) + splited_params = padding_param.split(padding_param.numel() // self._world_size) - def _search_colo_process_group(self): - colo_flag = False - colo_pg = None - for param_group in self.optim.param_groups: - group_params = param_group['params'] - for param in group_params: - if isinstance(param, ColoParameter): - colo_flag = True - if colo_pg is None: - colo_pg = param.get_process_group() - else: - assert colo_pg == param.get_process_group(), "All parameters should be in a same process group" - elif colo_flag: - raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.") - return colo_pg - - def _partition_param_list(self, param_list): - params_per_rank = [[] for _ in range(self._world_size)] - numel_per_rank = [0 for _ in range(self._world_size)] - - # partititon the parameters in a greedy fashion - sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) - for param in sorted_params: - # allocate this parameter to the rank with - # the smallest numel for load balancing purpose - rank_to_go = numel_per_rank.index(min(numel_per_rank)) - params_per_rank[rank_to_go].append(param) - numel_per_rank[rank_to_go] += param.numel() - - if self._verbose: - self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0]) - return params_per_rank + splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device) + params_current_rank.append(splited_param_current_rank) + self._param_store.link_master_and_working_param(splited_param_current_rank, param) + + return params_current_rank ########################### # Backward Reduction Hook # ########################### - def _grad_handler(self, param, grad, reduce_rank): - self._add_to_reduction_bucket(param, reduce_rank) + def _grad_handler(self, param, group_id, grad): + # if run with no_sync context, would not sync grad when backward + if self.require_grad_sync: + self._add_to_bucket(param, group_id) return grad def _attach_reduction_hook(self): @@ -268,148 +236,136 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad: - # determines the reduction destionation rank - # this is only valid for stage 2 - # dst_rank = None means using all-reduce - # else using reduce - if self._partition_grads: - reduce_rank = self._param_store.get_param_rank(param) - else: - reduce_rank = None - - param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank)) - - def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank): - if self._overlap_communication: - torch.cuda.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() - stream = self._comm_stream - else: - stream = torch.cuda.current_stream() - - with torch.cuda.stream(stream): - flat = bucket.flatten() - reduce_global_rank = None - if reduce_rank is not None: - reduce_global_rank = self._dp_global_ranks[reduce_rank] - reduced_flat = reduce_tensor_dp_group(tensor=flat, - dtype=self._communication_dtype, - dst_local_rank=reduce_rank, - dst_global_rank=reduce_global_rank, - group=self._dp_torch_group) - - # update the reduced tensor - if reduce_rank is None or reduce_rank == self._local_rank: - bucket.unflatten_and_copy(reduced_flat) - - def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank): - param_bucket = TensorBucket(size=bucket_size) - - for tensor in tensor_list: - param_bucket.add_to_bucket(tensor, allow_oversize=True) - - if param_bucket.is_full_or_oversized(): - self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) - param_bucket.empty() - - if not param_bucket.is_empty(): - self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) - - def _reduce_grads(self, reduce_rank, grads, bucket_size): - grad_buckets_by_dtype = split_by_dtype(grads) - - for tensor_list in grad_buckets_by_dtype: - self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list, - bucket_size=bucket_size, - reduce_rank=reduce_rank) + param.register_hook(partial(self._grad_handler, param, group_id)) ####################### # Reduction Functions # ####################### - def _run_reduction(self, reduce_rank=None): - # reduce grads - self._reduce_grads(reduce_rank=reduce_rank, - grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), - bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) + def _run_reduction(self): + if self._bucket_store.num_elements_in_bucket() > 0: + self._bucket_store.build_grad_in_bucket() + + flat_grads = self._bucket_store.get_flatten_grad() + flat_grads /= self._world_size + + # ready to add other tensors to bucket + self._bucket_store.reset_num_elements_in_bucket() + + if self._overlap_communication: + stream = self._comm_stream + # in case of the memory being reused in the default stream + flat_grads.record_stream(stream) + # waiting for ops in the default stream finishing + stream.wait_stream(torch.cuda.current_stream()) + else: + stream = torch.cuda.current_stream() + + with torch.cuda.stream(stream): + group_id = self._bucket_store.current_group_id + + grad_dtype = flat_grads.dtype + if self._communication_dtype is not None: + flat_grads = flat_grads.to(self._communication_dtype) + + if not self._partition_grads: + dist.all_reduce(flat_grads, group=self.dp_pg) + if flat_grads.dtype != grad_dtype: + flat_grads = flat_grads.to(grad_dtype) + + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + grad_in_bucket = self._bucket_store.get_grad() + + for rank, grad_list in grad_in_bucket.items(): + sync_tensor(flat_grads_per_rank[rank], grad_list) + for grad in grad_list: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if ( + len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) + < self._world_size + ): + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) - # use communication stream if overlapping - # communication with computation - if self._overlap_communication: - stream = self._comm_stream - else: - stream = torch.cuda.current_stream() - - with torch.cuda.stream(stream): - params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank) - - for param in params_in_bucket: - # the is_param_reduced flag should be False showing that - # this param is not reduced before calling self._reduce_grads_by_rank - is_param_reduced = self._param_store.is_param_reduced(param) - - if is_param_reduced: - msg = f'Parameter of size ({param.size()}) has been reduced, ' + \ - 'duplicate reduction will lead to arithmetic incorrectness' - raise RuntimeError(msg) + else: + flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) - # update the flag - self._param_store.set_param_reduction_state(param, True) + if recieved_grad.dtype != grad_dtype: + recieved_grad = recieved_grad.to(grad_dtype) - # if partition grads = True - # we do not keep the gradient after reduction - if self._partition_grads and not self._param_store.belongs_to_current_rank(param): - if self._overlap_communication: - # we need to keep this gradient for now as reduction may - # be completed yet since it is using a different cuda stream - self._param_store.add_previous_reduced_param(param) - else: - param.grad = None + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + sync_tensor(recieved_grad, grad_in_bucket_current_rank) + for grad in grad_in_bucket_current_rank: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) - self._bucket_store.reset_by_rank(reduce_rank) + self._bucket_store.reset() - def _add_to_reduction_bucket(self, param, reduce_rank=None): + def _add_to_bucket(self, param, group_id): param_size = param.numel() # check if the bucket is full # if full, will reduce the grads already in the bucket + # or got a grad of param from another group # after reduction, the bucket will be empty - if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: - self._run_reduction(reduce_rank) - - # the param must not be reduced to ensure correctness - is_param_reduced = self._param_store.is_param_reduced(param) - if is_param_reduced: - msg = f'Parameter of size ({param.size()}) has already been reduced, ' \ - + 'duplicate reduction will lead to arithmetic incorrectness' - raise RuntimeError(msg) + if ( + self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size + or group_id != self._bucket_store.current_group_id + ): + self._run_reduction() - self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank) - self._bucket_store.add_param(param, reduce_rank) + padding_size = self._param_store.get_param_padding_size(param) + self._bucket_store.add_param_grad(group_id, param, padding_size) ################################ # torch.optim.Optimizer methods ################################ - def backward(self, loss, retain_graph=False, sync_grad=True): - loss = self.loss_scale * loss + def backward(self, loss, retain_graph=False): + assert not ( + self._partition_grads and not self.require_grad_sync + ), "ZeRO2(partition_grads) and no_sync are not compatible" + + if self.mixed_precision_mixin is not None: + loss = self.mixed_precision_mixin.pre_backward(loss) + loss.backward(retain_graph=retain_graph) - # finish gradient reduction - if not self._partition_grads: - self._reduce_grad_stage1() - else: - # TODO: support async comm in reduce - self._reduce_grad_stage2() + if not self.require_grad_sync: + return + + self._reduce_grad(self._partition_grads) # clear reduced grads if self._overlap_communication: torch.cuda.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() - # gradient synchronization - if sync_grad: - self._sync_grad() + self.zero_grad() + + def backward_by_grad(self, tensor, grad): + assert not ( + self._partition_grads and not self.require_grad_sync + ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + + if self.mixed_precision_mixin is not None: + grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) + torch.autograd.backward(tensor, grad) + + if not self.require_grad_sync: + return + self._reduce_grad(self._partition_grads) + + # clear reduced grads + if self._overlap_communication: + torch.cuda.synchronize() + + self.zero_grad() def zero_grad(self, set_to_none=True): """ @@ -419,6 +375,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): :param set_to_none: Whether set the gradient to None. Default value is True. :type set_to_none: bool """ + if self.mixed_precision_mixin is not None: + self.mixed_precision_mixin.pre_zero_grad() for _, param_group in self._working_param_groups.items(): for param in param_group: if set_to_none: @@ -433,163 +391,266 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): #################### def step(self, closure=None): - assert closure is None, 'closure is not supported by step()' - - # check for overflow - found_inf = self._check_overflow() - self.grad_scaler.update(found_inf) + assert closure is None, "closure is not supported by step()" + if not self.require_grad_sync: + return - # update loss scale if overflow occurs - if found_inf: - self._grad_store.reset_all_average_gradients() + if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): + self._grad_store.reset_all_gradients() if self._verbose: - self._logger.info(f'Found overflow. Skip step') + self._logger.info(f"Found overflow. Skip step") self.zero_grad() return - # copy the grad of working param to master param - single_grad_partition_groups = [] + # record all grads for unscale and clip + grad_partition_groups = [] norm_groups = [] + # sometimes not all params are 'really' working + # for instance, when layer drop, the dropped layer has no grad + # and should not be updated + real_working_params = dict() + real_master_params = dict() + + grad_index = 0 if self._partition_grads else self._local_rank + for group_id in range(self.num_param_groups): + master_params = self._master_param_groups_of_current_rank[group_id] + real_working_params[group_id] = [] + real_master_params[group_id] = [] + for splited_param in master_params: + working_param = self._param_store.master_to_working_param[id(splited_param)] + # if a working param requires grad and has no grad + # it is not 'really' working, e.g. the droped layer + # else the splited grad should be attached to the splited param + grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) + if len(grads) > 0: + real_working_params[group_id].append(working_param) + grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device) + splited_param.grad = grad + grad_partition_groups.append(grad) + real_master_params[group_id].append(splited_param) + # compute norm - norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id), - params=self._param_store.get_params_by_rank_group(group_id=group_id, - rank=self._local_rank), - dp_group=self._dp_torch_group, - mp_group=self._mp_torch_group) + working_grads = self._grad_store.get_working_grads_by_group_id(group_id) + norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg) norm_groups.append(norm_group) - # create flat gradient for the flat fp32 master params - working_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id) - flat_working_avg_grads = flatten(working_avg_grads) + self._grad_store.reset_grads_by_group_id(group_id) - dtype = self._master_flat_param_groups_of_current_rank[group_id].dtype - flat_master_avg_grads = flat_working_avg_grads.to(dtype) - - param_shape = self._master_flat_param_groups_of_current_rank[group_id].shape - assert param_shape == flat_master_avg_grads.shape, \ - f'fp32 param and grad have different shape {param_shape} vs {flat_master_avg_grads.shape}' - - single_grad_partition_groups.append(flat_master_avg_grads) - device = self._master_flat_param_groups_of_current_rank[group_id].device - self._master_flat_param_groups_of_current_rank[group_id].grad = flat_master_avg_grads.to(device) - self._grad_store.reset_average_gradients_by_group(group_id) + # update the params in the optimizer + self.optim.param_groups[group_id]["params"] = real_master_params[group_id] # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) - self._unscale_and_clip_grads(single_grad_partition_groups, global_norm) + self._unscale_and_clip_grads(grad_partition_groups, global_norm) # update the parameters self.optim.step() - # release the master grad - release_param_grad(self._master_flat_param_groups_of_current_rank.values()) - # update working partition updated by the current rank - for group_id in range(len(self._working_param_groups)): - working_param = self._param_store.get_flat_param_by_rank_group(rank=self._local_rank, group_id=group_id) - master_param = self._master_flat_param_groups_of_current_rank[group_id] - working_param.data.copy_(master_param) + # release the grad + grad_partition_groups = [] + for group_id in range(self.num_param_groups): + release_param_grad(self._master_param_groups_of_current_rank[group_id]) - # broadcast the updated model weights - handles = [] + # update working partition updated by the current rank + dtype = real_working_params[0][0].dtype for group_id in range(self.num_param_groups): - for index in range(self._world_size): - rank = self._dp_global_ranks[index] - working_param = self._param_store.get_flat_param_by_rank_group(rank=index, group_id=group_id) - handle = dist.broadcast(working_param, src=rank, group=self._dp_torch_group, async_op=True) - handles.append(handle) + master_working_param = self.optim.param_groups[group_id]["params"] + for idx, splited_param in enumerate(master_working_param): + working_param = real_working_params[group_id][idx] + all_splited_param = [ + torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size) + ] + dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg) + working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) - for handle in handles: - handle.wait() + self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] ############################# # Mixed Precision Utilities # ############################# - def _check_overflow(self): - # clear previous overflow record - self._found_overflow.fill_(0.0) - - # check for overflow - for group_id in range(len(self._working_param_groups)): - for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id): - if avg_grad is not None and has_inf_or_nan(avg_grad): - self._found_overflow.fill_(1.0) - break - - # all-reduce across dp group - dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group) - - # all-reduce over model parallel group - if self._mp_torch_group: - dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group) - - if self._found_overflow.item() > 0: - return True - else: - return False - def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): # compute combined scale factor for this group - combined_scale = self.loss_scale + div_scale = 1.0 + if self.mixed_precision_mixin is not None: + div_scale = self.mixed_precision_mixin.get_grad_div_scale() - if self._clip_grad_norm > 0.: + if self._clip_grad_norm > 0.0: # norm is in fact norm*scale - clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm + clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm if clip > 1: - combined_scale = clip * self.loss_scale + div_scale = clip * div_scale for grad in grad_groups_flat: - grad.data.mul_(1. / combined_scale) + grad.data.mul_(1.0 / div_scale) ############################ # Gradient Synchronization # ############################ - def _sync_grad(self): - # update param already reduced flag - reduction_states = self._param_store.get_param_reduction_states() - for tensor, _ in reduction_states.items(): - reduction_states[tensor] = False - - # accumulate gradient + # this method is used to sync gradient manually + def sync_grad(self): for group_id in range(self.num_param_groups): - param_group = self._param_store.get_params_by_rank_group(self._local_rank, group_id) - - avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id) - - param_idx = 0 + param_group = self._working_param_groups[group_id] for param in param_group: - if param.grad is not None: - if len(avg_gradients_group) == param_idx: - self._grad_store.append_average_gradient_by_group(group_id, param.grad) - else: - self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad) - param_idx += 1 - - # the gradients needed are stored in the avg_gradients buffer - # thus, can clear this - self.zero_grad() + if param.requires_grad and param.grad is not None: + self._add_to_bucket(param, group_id) - def _reduce_grad_stage1(self): - # if not overlapping communication (no reduction hook is attached) - # we need to manually reduce these gradients - if not self._overlap_communication: - for group_id in range(len(self._working_param_groups)): - param_group = self._working_param_groups[group_id] - for param in param_group: - if param.grad is not None: - self._add_to_reduction_bucket(param) - - # we need to reduce the gradients - # left in the communication bucket self._run_reduction() - def _reduce_grad_stage2(self): - # when partition_grads is True, reduction hooks - # are attached in the __init__ function, so we - # only need to reduce the gradients - # left in the communication bucket - for reduce_rank in range(self._world_size): - self._run_reduction(reduce_rank) + def _reduce_grad(self, partition_grad): + # if not overlapping communication (no reduction hook is attached) when zero1 + # we need to manually reduce these gradients + if not partition_grad and not self._overlap_communication: + self.sync_grad() + else: + self._run_reduction() + + # this context comes from pytorch DDP + @contextmanager + def no_sync(self): + old_require_grad_sync = self.require_grad_sync + self.require_grad_sync = False + try: + yield + finally: + self.require_grad_sync = old_require_grad_sync + + ############## + # State Dict # + ############## + + def _pack_state(self, state: Dict) -> Dict: + # comes from pytorch optimizer.state_dict() + param_mappings = {} + start_index = 0 + + def pack_group(group): + nonlocal start_index + packed = {k: v for k, v in group.items() if k != "params"} + param_mappings.update( + {id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings} + ) + packed["params"] = [param_mappings[id(p)] for p in group["params"]] + start_index += len(packed["params"]) + return packed + + param_groups = [pack_group(g) for g in self.optim.param_groups] + # Remap state to use order indices as keys + packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()} + + return {"state": packed_state, "param_groups": param_groups} + + def state_dict(self) -> Dict: + """Return a state_dict same with DDP + + Returns: + Dict: the pytorch form state_dict + """ + zero_state = dict() + for param, state in self.optim.state.items(): + zero_state[param] = copy.deepcopy(state) + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != "step": + working_param = self._param_store.master_to_working_param[id(param)] + gather_tensor = [ + torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) + ] + dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) + param_state = ( + torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) + zero_state[param][k] = param_state + + states_dict = self._pack_state(zero_state) + + return states_dict + + def load_state_dict(self, state_dict: Dict): + """Load state dict, requires the state_dict be the pytorch form + + Args: + state_dict (dict): A pytorch form state_dict + """ + zero_state_dict = copy.deepcopy(state_dict) + for param_idx, state in zero_state_dict["state"].items(): + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != "step": + padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + v_list = v.split(v.numel() // self._world_size) + zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone() + + self.optim.load_state_dict(zero_state_dict) + + def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]: + """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. + Only include the 'state' in state_dict. + + Args: + max_shard_size (int, optional): max size of state shard (in MB). Defaults to 1024. + + Yields: + Iterator[OrderedDict]: A generator of state dict shard + """ + ret_block = dict() + ret_block_size = 0 + + local_states = self.optim.state_dict()["state"] + for param_idx, states in local_states.items(): + current_block_size = 0 + current_block = copy.deepcopy(states) + + # find the working param of current param_id + for group_id, pg in self._master_param_groups_of_current_rank.items(): + if (group_id + 1) * len(pg) < param_idx: + continue + master_param = pg[param_idx - (group_id) * len(pg)] + working_param = self._param_store.master_to_working_param[id(master_param)] + + for k, v in states.items(): + if isinstance(v, torch.Tensor) and k != "step": + state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)] + dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) + state_tensor = ( + torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) + current_block_size += state_tensor.numel() + current_block[k] = state_tensor + + if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0: + yield ret_block, ret_block_size + ret_block = dict() + ret_block_size = 0 + + ret_block[param_idx] = current_block + ret_block_size += current_block_size + + yield ret_block, ret_block_size + + def update_master_params(self, model: nn.Module) -> None: + """Update master params from working params + + Args: + model (nn.Module): The model to update master params + """ + for p in model.parameters(): + p_id = id(p) + if p_id in self._param_store.working_to_master_param: + master_param = self._param_store.working_to_master_param[p_id] + padding_size = self._param_store.get_param_padding_size(p) + working_param = p.data.view(-1) + if padding_size > 0: + working_param = torch.nn.functional.pad(working_param, [0, padding_size]) + master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + + def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: + return self._param_store.working_to_master_param + + def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: + return self._param_store.master_to_working_param diff --git a/colossalai/zero/low_level/readme.md b/colossalai/zero/low_level/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..b960a436219d8cfd835fc84f524b1c08deb40777 --- /dev/null +++ b/colossalai/zero/low_level/readme.md @@ -0,0 +1,90 @@ +# Low Level ZeRO +>Low Level ZeRO == ZeRO-DP stage 1 and 2, we would denote it as ZeRO. +## Examples of ZeRO and gradient accumulation + +The code below only shows a typical gradient accumulation process, and it drops a lot of details, such as the processing of loss. + +```python +# examples of ZeRO1 with gradient accumulation +... +outputs = model(input) +loss = SomeLoss(outputs) +if (idx + 1) % ACCUMULATE_STEP != 0: + with booster.no_sync(model, optimizer): + # under this context, the gradient would not sync when backward, + # left each rank having different gradient. + # It saves the backward time + booster.backward(loss, optimizer) + continue +else: + # need to sync all the accumulated gradient + booster.backward(loss, optimizer): + optimizer.step() + ... +``` + +```python +# example of ZeRO2 with gradient accumulation + +... +outputs = model(input) +loss = SomeLoss(outputs) +# ZeRO2 split the gradients and can NOT accumulate gradient with syncing. +booster.backward(loss, optimizer) +if (idx + 1) % ACCUMULATE_STEP == 0: + optimizer.step() +... +``` + + +## Design: +### Notion +`p32` denotes the param copy in the optimizer +`p` denotes the model param +`g` denotes the gradient + +### INIT +In low level zero(1, 2), `p32` is split. Different from the previous implement, we split each `p32` evenly by world_size. Thus, rank0 got a param list as `[p00, p10]`, rank1 got a param list as `[p-01, p-11]`, etc. +image + +For the detailed implementation, we first pad `p` for it can be split by world_size if needed. Then, we would view it to the shape `[world_size, -1]`, and each rank got its own part `p32` by cloning. + +### BWD +To leverage the communication, a gradient would be added to a bucket first. When the bucket is full, each `g` in it would be reshaped as `[world_size, -1]`. And the `[local_rank]` parts would be united. +The data structure looks like this: +``` +{ +0: [g-00, g-10], +1: [g-01, g-11], +2: [g-02, g-12] +} +``` +After that, the gradients would be flattened by rank, and the data structure looks like this: +``` +# g-X0 means flatten([g-00, g-10]) +{ +0: [g-X0], +1: [g-X1], +2: [g-X2] +} +``` +For zero1, we iterate the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`. + +### Optim +For each rank gets its own `p32` and the counterpart `g`, it is quite easy to do `optim.step()`. + +However, we have to consider a situation of layer drop, for instance: +``` +class MlpModel(nn.Module): + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(128, 256) + self.drop_linear = nn.Linear(256, 256) + self.linear2 = nn.Linear(256, 512) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x +``` +And the solution is to build a mapping of `p32`, `p`, and `g`. Before `optim.step()`, we collect `p` which `requires_grad=True` and `p.grad != None` as a real working param. And select the counterpart `p32` and `g`. diff --git a/colossalai/zero/wrapper.py b/colossalai/zero/wrapper.py index 3e48f49fa305fb59e45e66775e930ea36eb1c7db..ed873254e301f4a99f64a2ed302553e56e8af95a 100644 --- a/colossalai/zero/wrapper.py +++ b/colossalai/zero/wrapper.py @@ -7,10 +7,9 @@ import torch.nn as nn from .gemini import GeminiDDP -def zero_model_wrapper(model: nn.Module, - zero_stage: int = 1, - gemini_config: Optional[Dict] = None, - verbose: bool = False): +def zero_model_wrapper( + model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None, verbose: bool = False +): """This wrapper function is used to wrap your training model for ZeRO DDP. Example: @@ -50,19 +49,21 @@ def zero_model_wrapper(model: nn.Module, return wrapped_model -def zero_optim_wrapper(model: nn.Module, - optimizer: torch.optim.Optimizer, - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0, - optim_config: Optional[Dict] = None, - verbose: bool = False): +def zero_optim_wrapper( + model: nn.Module, + optimizer: torch.optim.Optimizer, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + optim_config: Optional[Dict] = None, + verbose: bool = False, +): """This wrapper function is used to wrap your training optimizer for ZeRO DDP. Args: @@ -95,20 +96,22 @@ def zero_optim_wrapper(model: nn.Module, else: config_dict = copy(optim_config) - config_dict['initial_scale'] = initial_scale - config_dict['growth_factor'] = growth_factor - config_dict['backoff_factor'] = backoff_factor - config_dict['growth_interval'] = growth_interval - config_dict['hysteresis'] = hysteresis - config_dict['min_scale'] = min_scale - config_dict['max_scale'] = max_scale + config_dict["initial_scale"] = initial_scale + config_dict["growth_factor"] = growth_factor + config_dict["backoff_factor"] = backoff_factor + config_dict["growth_interval"] = growth_interval + config_dict["hysteresis"] = hysteresis + config_dict["min_scale"] = min_scale + config_dict["max_scale"] = max_scale if zero_stage in [1, 2]: from colossalai.zero.low_level import LowLevelZeroOptimizer - config_dict['partition_grad'] = zero_stage == 2 - config_dict['clip_grad_norm'] = max_norm + + config_dict["partition_grad"] = zero_stage == 2 + config_dict["clip_grad_norm"] = max_norm return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose) else: - from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer - config_dict['clipping_norm'] = max_norm - return ZeroOptimizer(optimizer, model, **config_dict, verbose=verbose) + from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer + + config_dict["clipping_norm"] = max_norm + return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose) diff --git a/docker/Dockerfile b/docker/Dockerfile index 49ff9b344268935ed6cddc42f62ec1081aa2512a..26d3fab1b6d7eb1f37b452709fc196cd2e409aa5 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,17 +5,37 @@ LABEL org.opencontainers.image.source = "https://github.com/hpcaitech/ColossalAI LABEL org.opencontainers.image.licenses = "Apache License 2.0" LABEL org.opencontainers.image.base.name = "docker.io/library/hpcaitech/cuda-conda:11.3" +# enable passwordless ssh +RUN mkdir ~/.ssh && \ + printf "Host * \n ForwardAgent yes\nHost *\n StrictHostKeyChecking no" > ~/.ssh/config && \ + ssh-keygen -t rsa -N "" -f ~/.ssh/id_rsa && \ + cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys + +# enable RDMA support +RUN apt-get update && \ + apt-get install -y infiniband-diags perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + # install torch -RUN conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch +RUN conda install -y pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch + +# install ninja +RUN apt-get update && \ + apt-get install -y --no-install-recommends ninja-build && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* # install apex RUN git clone https://github.com/NVIDIA/apex && \ cd apex && \ + git checkout 91fcaa && \ pip install packaging && \ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./ # install colossalai -RUN git clone https://github.com/hpcaitech/ColossalAI.git \ +ARG VERSION=main +RUN git clone -b ${VERSION} https://github.com/hpcaitech/ColossalAI.git \ && cd ./ColossalAI \ && CUDA_EXT=1 pip install -v --no-cache-dir . @@ -23,8 +43,9 @@ RUN git clone https://github.com/hpcaitech/ColossalAI.git \ RUN pip install --no-cache-dir titans # install tensornvme -RUN conda install cmake && \ +RUN conda install -y cmake && \ git clone https://github.com/hpcaitech/TensorNVMe.git && \ cd TensorNVMe && \ + apt update -y && apt install -y libaio-dev && \ pip install -r requirements.txt && \ pip install -v --no-cache-dir . diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 9d5bcfe3f9747dcfc85123451609915ea53f7cd9..499d67a37c70fccd2eb44ca01fea8951d846cd97 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -16,7 +16,7 @@ [![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest) [![CodeFactor](https://www.codefactor.io/repository/github/hpcaitech/colossalai/badge)](https://www.codefactor.io/repository/github/hpcaitech/colossalai) [![HuggingFace badge](https://img.shields.io/badge/%F0%9F%A4%97HuggingFace-Join-yellow)](https://huggingface.co/hpcai-tech) - [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w) + [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack) [![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png) | [English](README.md) | [中文](README-zh-Hans.md) | @@ -24,15 +24,15 @@
        ## 新闻 +* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) +* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training) +* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth) +* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining) * [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) * [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana) * [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs) * [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt) * [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02) -* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) -* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding) -* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the) - ## 目录
          @@ -41,6 +41,7 @@
        • Colossal-AI 成功案例
            +
          • Colossal-LLaMA-2: 千元预算半天训练,效果媲美主流大模型,开源可商用中文LLaMA-2
          • ColossalChat:完整RLHF流程0门槛克隆ChatGPT
          • AIGC: 加速 Stable Diffusion
          • 生物医药: 加速AlphaFold蛋白质结构预测
          • @@ -49,6 +50,7 @@
          • 并行训练样例展示
              +
            • LLaMA 1/2
            • GPT-3
            • GPT-2
            • BERT
            • @@ -118,15 +120,56 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的

              (返回顶端)

              ## Colossal-AI 成功案例 +### Colossal-LLaMA-2 + +- 千元预算半天训练,效果媲美主流大模型,开源可商用中文LLaMA-2 +[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) +[[博客]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) +[[模型权重]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) + +| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval | +| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: | +| | | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot | +| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | +| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | +| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | +| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | +| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | +| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | +| InternLM-7B | - | 1.6T | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| Qwen-7B | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| | | | | | | | | | +| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | +| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - | +| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - | +| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | +| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - | +| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - | +| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - | +| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - | +| | | | | | | | | | +| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | + + ### ColossalChat -[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat): 完整RLHF流程0门槛克隆 [ChatGPT](https://openai.com/blog/chatgpt/) [[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) [[博客]](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) [[在线样例]](https://chat.colossalai.org) +[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat): 完整RLHF流程0门槛克隆 [ChatGPT](https://openai.com/blog/chatgpt/) +[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) +[[博客]](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) +[[在线样例]](https://www.youtube.com/watch?v=HcTiHzApHm0) +[[教程]](https://www.youtube.com/watch?v=-qFBZFmOJfg) + +

              + +

              + +- 最高可提升RLHF PPO阶段3训练速度10倍

              @@ -199,6 +242,23 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的

              (返回顶端)

              ## 并行训练样例展示 +### LLaMA2 +

              + +

              + +- 700亿参数LLaMA2训练加速195% +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) +[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training) + +### LLaMA1 +

              + +

              + +- 650亿参数大模型预训练加速38% +[[代码]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama) +[[博客]](https://www.hpc-ai.tech/blog/large-model-pretraining) ### GPT-3

              @@ -424,6 +484,7 @@ Colossal-AI项目受一些相关的项目启发而成立,一些项目是我们 } ``` -Colossal-AI 已被 [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/)等顶级会议录取为官方教程。 +Colossal-AI 已被[NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), +[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) ,等顶级会议录取为官方教程。

              (返回顶端)

              diff --git a/docs/README.md b/docs/README.md index f520608d552cf81e6752e66225dee5c884f38404..a5ae2ce96a996c4b0a8011dce97073861cfe74c1 100644 --- a/docs/README.md +++ b/docs/README.md @@ -98,7 +98,7 @@ Lastly, if you want to skip some code, you just need to add the following annota ``` -If you have any dependency required, please add it to `requriements-doc-test.txt` for pip and `conda-doc-test-deps.yml` for Conda. +If you have any dependency required, please add it to `requirements-doc-test.txt` for pip and `conda-doc-test-deps.yml` for Conda. ### 💉 Auto Documentation @@ -108,5 +108,5 @@ We support `autodoc` to extract the docstring and transform it into a Web elemen You just need to add `{{ autodoc: }}` in your markdown as a single line. An example is given below and you can see the outcome in [this PR](https://github.com/hpcaitech/ColossalAI-Documentation/pull/175). ```markdown -{{ autodoc:colossalai.amp.apex_amp.convert_to_apex_amp }} +{{ autodoc:colossalai.legacy.amp.apex_amp.convert_to_apex_amp }} ``` diff --git a/docs/REFERENCE.md b/docs/REFERENCE.md index 2681198191cba708d776663bca2cdcdb497eb931..0984b2dc3f28f45a036d9c6ecd57fbd8bf28caa1 100644 --- a/docs/REFERENCE.md +++ b/docs/REFERENCE.md @@ -1,6 +1,6 @@ # References -The Colossal-AI project aims to provide a wide array of parallelism techniques for the machine learning community in the big-model era. This project is inspired by quite a few reserach works, some are conducted by some of our developers and the others are research projects open-sourced by other organizations. We would like to credit these amazing projects below in the IEEE citation format. +The Colossal-AI project aims to provide a wide array of parallelism techniques for the machine learning community in the big-model era. This project is inspired by quite a few research works, some are conducted by some of our developers and the others are research projects open-sourced by other organizations. We would like to credit these amazing projects below in the IEEE citation format. ## By Our Team diff --git a/docs/sidebars.json b/docs/sidebars.json index 44287c17eadf45d73e96b53a2dc98387141e7244..45e86afc1f611baae9cb4d85bc3159c569f6d660 100644 --- a/docs/sidebars.json +++ b/docs/sidebars.json @@ -26,13 +26,10 @@ "collapsed": true, "items": [ "basics/command_line_tool", - "basics/define_your_config", "basics/launch_colossalai", - "basics/initialize_features", - "basics/engine_trainer", - "basics/configure_parallelization", - "basics/model_checkpoint", - "basics/colotensor_concept" + "basics/booster_api", + "basics/booster_plugins", + "basics/booster_checkpoint" ] }, { @@ -40,10 +37,10 @@ "label": "Features", "collapsed": true, "items": [ - "features/mixed_precision_training", - "features/gradient_accumulation", - "features/gradient_clipping", - "features/gradient_handler", + "features/shardformer", + "features/mixed_precision_training_with_booster", + "features/gradient_accumulation_with_booster", + "features/gradient_clipping_with_booster", "features/zero_with_chunk", { "type": "category", @@ -57,7 +54,9 @@ ] }, "features/pipeline_parallel", - "features/nvme_offload" + "features/nvme_offload", + "features/lazy_init", + "features/cluster_utils" ] }, { @@ -68,10 +67,7 @@ "advanced_tutorials/train_vit_using_pipeline_parallelism", "advanced_tutorials/train_vit_with_hybrid_parallelism", "advanced_tutorials/train_gpt_using_hybrid_parallelism", - "advanced_tutorials/define_your_own_parallel_model", - "advanced_tutorials/add_your_parallel", "advanced_tutorials/meet_gemini", - "advanced_tutorials/parallelize_your_training_like_Megatron", "advanced_tutorials/integrate_mixture_of_experts_into_your_model", "advanced_tutorials/opt_service" ] diff --git a/docs/source/en/advanced_tutorials/add_your_parallel.md b/docs/source/en/advanced_tutorials/add_your_parallel.md deleted file mode 100644 index be7284a7ab64824cb49dfb74a42831e14b4afb59..0000000000000000000000000000000000000000 --- a/docs/source/en/advanced_tutorials/add_your_parallel.md +++ /dev/null @@ -1,124 +0,0 @@ -# Add Your Own Parallel Mode - -Author: Shenggui Li, Yongbin Li - -**Prerequisite:** -- [Define Your Configuration](../basics/define_your_config.md) -- [Configure Parallelization](../basics/configure_parallelization.md) - -## Introduction - -To enable researchers and engineers to extend our system to other novel large-scale distributed training algorithm -with less effort, we have decoupled various components in the training lifecycle. You can implement your own -parallelism by simply inheriting from the base class. - -The main components are: - -1. `ProcessGroupInitializer` -2. `GradientHandler` -3. `Schedule` - -**This currently requires some code to the source code, thus we recommend that you install from source with the `-e` flag. -`-e` flag makes the installation editable, thus, your code change will be reflected in your Python runtime. -We will work on this to avoid change to source code in future releases.** - - -## Process Group Initializer - -Parallelism is often managed by process groups where processes involved in the same parallel algorithm are placed in the same -process group. For different parallel algorithms, different process groups need to be created. Colossal-AI provides a -global context for users to easily manage their process groups. If you wish to add new process group, you can easily -define a new class and set it in your configuration file. To define your own way of creating process groups, you can -follow the steps below to create a new distributed initialization. - -1. Add your parallel mode in `colossalai.context.parallel_mode.ParallelMode`. - ```python - class ParallelMode(Enum): - GLOBAL = 'global' - DATA = 'data' - PIPELINE = 'pipe' - ... - - NEW_MODE = 'new_mode' # define your mode here - ``` - -2. Create a `ProcessGroupInitializer`. You can refer to examples given in `colossalai.context.dist_group_initializer`. The - first six arguments are fixed. `ParallelContext` will pass in these arguments for you. If you need to set other - arguments, you can add it behind like the `arg1, arg2` in the example below. Lastly, register your initializer to the - registry by adding the decorator `@DIST_GROUP_INITIALIZER.register_module`. - ```python - # sample initializer class - @DIST_GROUP_INITIALIZER.register_module - class MyParallelInitializer(ProcessGroupInitializer): - - def __init__(self, - rank: int, - world_size: int, - config: Config, - data_parallel_size: int, - pipeline_parlalel_size: int, - tensor_parallel_size: int, - arg1, - arg2): - super().__init__(rank, world_size, config) - self.arg1 = arg1 - self.arg2 = arg2 - # ... your variable init - - def init_parallel_groups(self): - # initialize your process groups - pass - - ``` - - Then, you can insert your new initializer to the current mode-to-initialize mapping - in `colossalai.constants.INITIALIZER_MAPPING`. You can modify the file or insert new key-value pair dynamically. - - ```python - colossalai.constants.INITIALIZER_MAPPING['new_mode'] = 'MyParallelInitializer' - ``` - -3. Set your initializer in your config file. You can pass in your own arguments if there is any. This allows - the `ParallelContext` to create your initializer and initialize your desired process groups. - - ```python - parallel = dict( - pipeline=dict(size=1), - tensor=dict(size=x, mode='new_mode') # this is where you enable your new parallel mode - ) - ``` - -## Gradient Handler - -Gradient handlers are objects which execute the all-reduce operations on parameters' gradients. As different all-reduce -strategies may be executed for different kinds of parallelism, users can -inherit `colossalai.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library -uses the normal data parallel gradient handler which all-reduces the gradients across data parallel ranks. The data -parallel gradient handler is added to the engine automatically if data parallel is detected. You can add your own -gradient handler like below: - -```python -from colossalai.registry import GRADIENT_HANDLER -from colossalai.engine import BaseGradientHandler - -@GRADIENT_HANDLER.register_module -class YourGradientHandler(BaseGradientHandler): - - def handle_gradient(self): - do_something() - -``` - -Afterwards, you can specify the gradient handler you want to use in your configuration file. - -```python -gradient_handlers = [ - dict(type='YourGradientHandler'), -] -``` - -## Schedule - -Schedule entails how to execute a forward and backward pass. Currently, Colossal-AI provides pipeline and non-pipeline -schedules. If you want to modify how the forward and backward passes are executed, you can -inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_back_step` function. diff --git a/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md b/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md deleted file mode 100644 index 8e48737d2f6435dd55f1f673647b82d6f64abaca..0000000000000000000000000000000000000000 --- a/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md +++ /dev/null @@ -1,36 +0,0 @@ -# Define your own parallel model - -Author: Zhengda Bian, Yongbin Li - -> ⚠️ We are working on this documentation to make it more detailed. We will introduce the mechanism of different parallelism -> and how to use them to write a model. - -Let's say that you have a huge MLP model with billions of parameters and its extremely large hidden layer size makes it -impossible to fit into a single GPU directly. Don't worry, Colossal-AI is here to help you sort things out. With the help of Colossal-AI, -you can write your model in the familiar way in which you used to write models for a single GPU, while Colossal-AI automatically -splits your model weights and fit them perfectly into a set of GPUs. We give a simple example showing how to write a simple -2D parallel model in the Colossal-AI context. - -## Write a simple 2D parallel model - -```python -from colossalai.nn import Linear2D -import torch.nn as nn - -class MLP_2D(nn.Module): - - def __init__(self): - super().__init__() - self.linear_1 = Linear2D(in_features=1024, out_features=16384) - self.linear_2 = Linear2D(in_features=16384, out_features=1024) - - def forward(self, x): - x = self.linear_1(x) - x = self.linear_2(x) - return x -``` - -## Use pre-defined model - -For the sake of your convenience, we kindly provide you in our Model Zoo with some prevalent models such as *BERT*, *ViT*, *MoE*, -and *GPT*. Feel free to customize them into different sizes to fit into your special needs. diff --git a/docs/source/en/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md b/docs/source/en/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md index e01caf76d2b323959a4ce7d7d85521ff919a5385..bfa5539fe3a6516d5e1820171430db925c7b8401 100644 --- a/docs/source/en/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md +++ b/docs/source/en/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md @@ -121,7 +121,7 @@ Inside the initialization of Experts, the local expert number of each GPU will b ## Train Your Model -Do not to forget to use `colossalai.initialize` function in `colosalai` to add gradient handler for the engine. +Do not to forget to use `colossalai.initialize` function in `colossalai` to add gradient handler for the engine. We handle the back-propagation of MoE models for you. In `colossalai.initialize`, we will automatically create a `MoeGradientHandler` object to process gradients. You can find more information about the handler `MoeGradientHandler` in colossal directory. @@ -137,3 +137,4 @@ criterion = MoeLoss( Finally, just use trainer or engine in `colossalai` to do your training. Otherwise, you should take care of gradient by yourself. + diff --git a/docs/source/en/advanced_tutorials/meet_gemini.md b/docs/source/en/advanced_tutorials/meet_gemini.md index 8afb6705b6ae84afc23cc16767cf2581719c1abb..e94e3fea3710af3c40be0012bb9ae1ff7c935ef6 100644 --- a/docs/source/en/advanced_tutorials/meet_gemini.md +++ b/docs/source/en/advanced_tutorials/meet_gemini.md @@ -9,16 +9,21 @@ When you only have a few GPUs for large model training tasks, **heterogeneous tr ## Usage -At present, Gemini supports compatibility with ZeRO parallel mode, and it is really simple to use Gemini. Set attribute of zero model_config, i.e., tensor_placement_policy='auto'. - -``` -zero = dict( - model_config=dict( - tensor_placement_policy='auto', - shard_strategy=BucketTensorShardStrategy() - ), - optimizer_config=dict( - ...) +At present, Gemini supports compatibility with ZeRO parallel mode, and it is really simple to use Gemini: Inject the features of `GeminiPlugin` into training components with `booster`. More instructions of `booster` please refer to [**usage of booster**](../basics/booster_api.md). + +```python +from torchvision.models import resnet18 +from colossalai.booster import Booster +from colossalai.zero import ColoInitContext +from colossalai.booster.plugin import GeminiPlugin +plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5) +booster = Booster(plugin=plugin) +ctx = ColoInitContext() +with ctx: + model = resnet18() +optimizer = HybridAdam(model.parameters(), lr=1e-3) +criterion = lambda x: x.mean() +model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) ) ``` @@ -86,3 +91,5 @@ The important duty of MSC is to adjust the tensor layout position. For example, In the warmup stage, since we haven't finished a complete iteration yet, we don't know actual memory occupation. At this time, we limit the upper bound of memory usage of the model data. For example, only 30% of the GPU memory can be used. This ensures that we can successfully complete the warmup state. In the non-warmup stage, we need to use the memory information of non-model data collected in the warm-up stage to reserve the peak memory required by the computing device for the next Period, which requires us to move some model tensors. In order to avoid frequent replacement of the same tensor in and out of the CPU-GPU, causing a phenomenon similar to [cache thrashing](https://en.wikipedia.org/wiki/Thrashing_(computer_science)). Using the iterative characteristics of DNN training, we design the OPT cache swap out strategy. Specifically, in the warmup stage, we record the sampling time required by each tensor computing device. If we need to expel some HOLD tensors, we will choose the latest tensor needed on this device as the victim. + + diff --git a/docs/source/en/advanced_tutorials/opt_service.md b/docs/source/en/advanced_tutorials/opt_service.md index a43ec7fdd1fe8736a90e05ba186fe554c1e75384..eccfa12f9389a83eaab27ee7dfadaf3c4551e6f0 100644 --- a/docs/source/en/advanced_tutorials/opt_service.md +++ b/docs/source/en/advanced_tutorials/opt_service.md @@ -53,7 +53,7 @@ export CHECKPOINT_DIR="your_opt_checkpoint_path" # the ${CONFIG_DIR} must contain a server.sh file as the entry of service export CONFIG_DIR="config_file_path" -docker run --gpus all --rm -it -p 8020:8020 -v ${CHECKPOINT_DIR}:/model_checkpoint -v ${CONFIG_DIR}:/config --ipc=host energonai:lastest +docker run --gpus all --rm -it -p 8020:8020 -v ${CHECKPOINT_DIR}:/model_checkpoint -v ${CONFIG_DIR}:/config --ipc=host energonai:latest ``` Then open `https://[IP-ADDRESS]:8020/docs#` in your browser to try out! diff --git a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md deleted file mode 100644 index e7698e5e9d1b9d5c99d52cd455281d7e9f358072..0000000000000000000000000000000000000000 --- a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md +++ /dev/null @@ -1,192 +0,0 @@ -# Parallelize Your Training like Megatron-LM via ColoTensor - -Author: [Haichen Huang](https://github.com/1SAA) and [Jiarui Fang](https://github.com/feifeibear) - -**Prerequisite:** -- [ColoTensor Concepts](../basics/colotensor_concept.md) - -## Introduction - -Thanks to the convenience given by ColoTensor, users can apply parallelism with the least edition to their serial code. -In this tutorial, we will illustrate how to modify the training model to automatically adapt the code to parallel training like Megatron-LM. -We take the GPT-2 model offered by HuggingFace as an example and provide a way for you to pre-train the GPT-2 model on a single GPU. - -Megatron-LM provided a profound paradigm to parallelize large transformer language models. -However, in order to train large transformer language models at scale, users have to build their models with those modules provided by Megatron. -It imposes several difficult jobs on users, such as loading the weights from the pre-trained models and constructing the parallelized models. -To mitigate users' trouble, we offer ColoTensor to enable the tensor model parallelism automatically. - -## Definitions of the model and the loss function - -First we use the GPTModel and GPTLoss directly from the HuggingFace library. - -```python -import torch -import torch.nn as nn -from transformers import GPT2Config, GPT2LMHeadModel - -class GPTLMModel(nn.Module): - def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False): - super().__init__() - self.checkpoint = checkpoint - self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers, - n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size)) - if checkpoint: - self.model.gradient_checkpointing_enable() - - def forward(self, input_ids, attention_mask): - # Only return lm_logits - return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] - - -class GPTLMLoss(nn.Module): - def __init__(self): - super().__init__() - self.loss_fn = nn.CrossEntropyLoss() - - def forward(self, logits, labels): - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) -``` - -## Brief Review of GPT-2 - -Now, we recall the structure of each GPT-2 model. -Every GPT-2 model can be represented as a DAG. -As shown in the below pictures, each circle represents an operator and each square represents a weight. -An arrow indicates the flow of the input data, and the notation alongside the arrow demonstrates the shape of the input data. - -Then, let's take an insight into this GPT-2 model. It consists of three parts. -They are the **embedding module**, **transformer layers**, and the **classification head**. - -The embedding module contains two weights, token embedding weight and position embedding weight. -After the forward operation of the embedding module, each word in all sequences of the raw input data will be embedded into a hidden state. - -
              - -
              The embedding module
              -
              - -Each transformer layer contains two blocks. The self-attention operation is called in the first block and a two-layer percepton is located in the second block. - -
              - -
              The transformer layer
              -
              - -In the end, the classification head is just a linear module without bias, which only has a weight inside. - -## Applied with ColoTensor - -Two steps make your serial code adapted to Megatron-LM tensor parallel style. -1. Initialize the model in the context of ColoInitContext. -2. Setting ColoTensorSpec for each parameter. - -### Initialize with ColoInitContext - -We should build the model in the ColoInitContext. -In this context, any parameter initialized would be transformed to ColoParameter and moved to the corresponded device automatically. - -```python -from colossalai.utils.model.colo_init_context import ColoInitContext - -with ColoInitContext(device=torch.device('cpu')): - model = GPTLMModel() -``` - -### Setting ColoTensorSpec for each parameter - -After the creation of the model, we establish the distributed environment through ProcessGroup. -Here, we specify the degree of the tensor parallelism as the same as the number of all GPUs, which means the degree of data parallelism is 1. - -```python -import torch.distributed as dist -from colossalai.tensor import ProcessGroup - -pg = ProcessGroup(tp_degree=dist.get_world_size()) -``` - -Now, some auxiliary functions are necessary for the next step. We define two functions to split a parameter. -Megatron-LM-like tensor parallelism requires splitting a parameter tensor along its first dimension or its last dimension. - -```python -from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern, ColoParameter, ProcessGroup - -def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - if param.process_group.tp_world_size() == 1: - param.set_process_group(pg) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(-1, param, pg) -``` - -Then we adapt the model to the tensor parallelism. -According to the tensor parallelism applied in Megatron, it is supposed to shard along the last dimension of tensors, including the weights of token embedding, position embedding, all linear weights and biases in self-attention blocks, the first weight linear and bias in each MLP. -And it shards the second linear weight along its first dimension. - -```python -for mn, module in model.named_modules(): - for pn, param in module.named_parameters(recurse=False): - # set process group for all parameters - param.set_process_group(pg) - - if 'mlp.c_fc' in mn: - if 'weight' in pn or 'bias' in pn: - split_param_col_tp1d(param, pg) # colmn slice - # keep the shape of the output from c_fc - param.compute_spec.set_output_replicate(False) - elif 'mlp.c_proj' in mn: - if 'weight' in pn: - split_param_row_tp1d(param, pg) # row slice - elif 'wte' in mn or 'wpe' in mn: - split_param_col_tp1d(param, pg) # colmn slice - elif 'c_attn' in mn or 'c_proj' in mn: - split_param_col_tp1d(param, pg) # colmn slice -``` - -The modified model is illustrated below. - -The embedding module: - -
              - -
              The modified embedding module
              -
              - -The transformer layers: - -
              - -
              The modified transformer layer
              -
              - -Once users have specified the distributed pattern of each parameter, ColoTensor is capable of inferring the computation patterns of all operators, including matrix multiplication, the linear function, other elementwise functions in torch.nn.functional, etc. -In this way, users can train their models as usual. - -In our latest example, a Gemini + ZeRO DDP model is also defined to reduce overhead and improve efficiency.For the details of this part, please refer to [ZeRO](../features/zero_with_chunk.md). You can combine these two parts to understand our entire training process: - -```python -def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): - from colossalai.nn.parallel import GeminiDDP - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=placememt_policy, - pin_memory=True, - search_range_mb=32) - return model -``` - -## Pretrain GPT-2 On Single GPU - -The above optimization we made allows us to pretrain the GPT-2 model on a single GPU. We only need to set the parameter `GPUNUM`=1 in `run.sh`, and then we can complete the model training on a single GPU when running the file. - -The GPT-2 example is accessible at [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 715c15eb63003b7c1cde0b28d2533518755db34f..0218264cc258dd3e8c90d17e04f45ad5e351bd03 100644 --- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -36,14 +36,14 @@ import torch import torch.nn as nn from colossalai import nn as col_nn from colossalai.amp import AMP_TYPE -from colossalai.builder.pipeline import partition_uniform -from colossalai.context.parallel_mode import ParallelMode +from colossalai.legacy.builder.pipeline import partition_uniform +from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.engine.schedule import (InterleavedPipelineSchedule, +from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule) from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper -from colossalai.trainer import Trainer, hooks +from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.legacy.trainer import Trainer, hooks from colossalai.utils.timer import MultiTimer from model_zoo.gpt import GPTLMLoss from torch.nn import functional as F @@ -268,3 +268,4 @@ def train(): return_output_label=False, ) ``` + diff --git a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md index b26599740c5f573182bf764d969a0d8c376e8660..6dbe338008fa24a6134eecfcb73153001d82e35e 100644 --- a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md @@ -34,11 +34,11 @@ import colossalai import colossalai.nn as col_nn import torch import torch.nn as nn -from colossalai.builder import build_pipeline_model -from colossalai.engine.schedule import (InterleavedPipelineSchedule, +from colossalai.legacy.builder import build_pipeline_model +from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule) from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.trainer import Trainer, hooks +from colossalai.legacy.trainer import Trainer, hooks from colossalai.utils import MultiTimer, get_dataloader from timm.models import vision_transformer as vit from torchvision import transforms @@ -51,17 +51,17 @@ from torchvision.datasets import CIFAR10 Generally, we provide 3 ways to build a pipelined model: -1. `colossalai.builder.build_pipeline_model_from_cfg` -2. `colossalai.builder.build_pipeline_model` +1. `colossalai.legacy.builder.build_pipeline_model_from_cfg` +2. `colossalai.legacy.builder.build_pipeline_model` 3. Split the model by stages by yourself When your memory can fit the model, you can use the first two methods to build your model, otherwise you must split the model by yourself. The first two methods first build the whole model on CPU, then split the model, and finally you can just move the corresponding part of model to GPU. -`colossalai.builder.build_pipeline_model_from_cfg()` receives a config file of model, and it can split the model uniformly (by layer) or balanced (by parameter size). +`colossalai.legacy.builder.build_pipeline_model_from_cfg()` receives a config file of model, and it can split the model uniformly (by layer) or balanced (by parameter size). -If you are familiar with `PyTorch`, you can use `colossalai.builder.build_pipeline_model()` which receives a `torch.nn.Sequential` model and split it by layer uniformly. +If you are familiar with `PyTorch`, you can use `colossalai.legacy.builder.build_pipeline_model()` which receives a `torch.nn.Sequential` model and split it by layer uniformly. -In this tutorial, we will modify [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential` and then use `colossalai.builder.build_pipeline_model()` to build the pipelined model. +In this tutorial, we will modify [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential` and then use `colossalai.legacy.builder.build_pipeline_model()` to build the pipelined model. When the data is **one** `Tensor`, you can use the positional argument in `forward()` of your model to get the data tensor. For the first stage of pipeline, the first positional argument of `forward()` is the data tensor loaded from data loader. For other stages, the first positional argument of `forward()` is the output tensor from the previous stage. Note that if the stage is not the last stage, the return of `forward()` must be a `Tensor`. @@ -195,7 +195,7 @@ def build_cifar(batch_size): ## Training ViT using pipeline -You can set the size of pipeline parallel and number of microbatches in config. `NUM_CHUNKS` is useful when using interleved-pipeline (for more details see [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) ). The original batch will be split into `num_microbatches`, and each stage will load a micro batch each time. Then we will generate an approriate schedule for you to execute the pipeline training. If you don't need the output and label of model, you can set `return_output_label` to `False` when calling `trainer.fit()` which can further reduce GPU memory usage. +You can set the size of pipeline parallel and number of microbatches in config. `NUM_CHUNKS` is useful when using interleaved-pipeline (for more details see [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) ). The original batch will be split into `num_microbatches`, and each stage will load a micro batch each time. Then we will generate an appropriate schedule for you to execute the pipeline training. If you don't need the output and label of model, you can set `return_output_label` to `False` when calling `trainer.fit()` which can further reduce GPU memory usage. You should `export DATA=/path/to/cifar`. @@ -245,3 +245,4 @@ def train(): hooks=hook_list, display_progress=True) ``` + diff --git a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md index b2438a1cf562a5f6bd0d22c96d12856b8afca4d6..0ec9d5c3c5deb255b72bb36357238a9beb3bd360 100644 --- a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -16,14 +16,14 @@ In this example for ViT model, Colossal-AI provides three different parallelism We will show you how to train ViT on CIFAR-10 dataset with these parallelism techniques. To run this example, you will need 2-4 GPUs. -## Tabel of Contents +## Table of Contents 1. Colossal-AI installation 2. Steps to train ViT with data parallelism 3. Steps to train ViT with pipeline parallelism 4. Steps to train ViT with tensor parallelism or hybrid parallelism ## Colossal-AI Installation -You can install Colossal-AI pacakage and its dependencies with PyPI. +You can install Colossal-AI package and its dependencies with PyPI. ```bash pip install colossalai ``` @@ -31,7 +31,7 @@ pip install colossalai ## Data Parallelism -Data parallism is one basic way to accelerate model training process. You can apply data parallelism to training by only two steps: +Data parallelism is one basic way to accelerate model training process. You can apply data parallelism to training by only two steps: 1. Define a configuration file 2. Change a few lines of code in train script @@ -78,8 +78,8 @@ from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.lr_scheduler import LinearWarmupLR -from colossalai.nn.metric import Accuracy -from colossalai.trainer import Trainer, hooks +from colossalai.legacy.nn.metric import Accuracy +from colossalai.legacy.trainer import Trainer, hooks ``` - Other modules @@ -94,7 +94,7 @@ from torchvision import transforms from torchvision.datasets import CIFAR10 ``` -#### Lauch Colossal-AI +#### Launch Colossal-AI In train script, you need to initialize the distributed environment for Colossal-AI after your config file is prepared. We call this process `launch`. In Colossal-AI, we provided several launch methods to initialize the distributed backend. In most cases, you can use `colossalai.launch` and `colossalai.get_default_parser` to pass the parameters via command line. Besides, Colossal-AI can utilize the existing launch tool provided by PyTorch as many users are familiar with by using `colossalai.launch_from_torch`. For more details, you can view the related [documents](https://www.colossalai.org/docs/basics/launch_colossalai). @@ -273,8 +273,8 @@ SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token ### Build pipeline model (`/hybrid_parallel/model/vit.py`) Colossal-AI provides two methods to build a pipeline model from the existing model. -- `colossalai.builder.build_pipeline_model_from_cfg` -- `colossalai.builder.build_pipeline_model` +- `colossalai.legacy.builder.build_pipeline_model_from_cfg` +- `colossalai.legacy.builder.build_pipeline_model` Besides, you can also build a pipeline model from scratch with Colossal-AI. ```python @@ -284,11 +284,11 @@ from typing import Callable import inspect import torch from colossalai import nn as col_nn -from colossalai.registry import LAYERS, MODELS +from colossalai.legacy.registry import LAYERS, MODELS from colossalai.logging import get_dist_logger from colossalai.core import global_context as gpc from colossalai.context import ParallelMode -from colossalai.builder.pipeline import partition_uniform +from colossalai.legacy.builder.pipeline import partition_uniform from torch import dtype, nn from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead @@ -415,7 +415,7 @@ def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kw #### Import modules ```python -from colossalai.engine.schedule import (InterleavedPipelineSchedule, +from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule) from colossalai.utils import MultiTimer import os @@ -613,7 +613,7 @@ NUM_MICRO_BATCHES = parallel['pipeline'] TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE) ``` -Ohter configs: +Other configs: ```python # hyper parameters # BATCH_SIZE is as per GPU @@ -644,3 +644,4 @@ torchrun --standalone --nproc_per_node train_hybrid.py --config ./co # If your torch >= 1.9.0 # python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_hybrid_parallel.py ``` + diff --git a/docs/source/en/basics/booster_api.md b/docs/source/en/basics/booster_api.md new file mode 100644 index 0000000000000000000000000000000000000000..4d7ffe5a4cbf4062273ee9b2bf28c31ab521aeb5 --- /dev/null +++ b/docs/source/en/basics/booster_api.md @@ -0,0 +1,92 @@ +# Booster API + +Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https://github.com/CjhHa1), [Baizhou Zhang](https://github.com/Fridge003) + +**Prerequisite:** + +- [Distributed Training](../concepts/distributed_training.md) +- [Colossal-AI Overview](../concepts/colossalai_overview.md) + +**Example Code** + +- [Train ResNet on CIFAR-10 with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet) +- [Train LLaMA-1/2 on RedPajama with Booster](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) + +## Introduction + +In our new design, `colossalai.booster` replaces the role of `colossalai.initialize` to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. With these new APIs, you can integrate your model with our parallelism features more friendly. Also, calling `colossalai.booster` is the standard procedure before you run into your training loops. In the sections below, we will cover how `colossalai.booster` works and what we should take note of. + +### Plugin + +Plugin is an important component that manages parallel configuration (eg: The gemini plugin encapsulates the gemini acceleration solution). Currently supported plugins are as follows: + +**_HybridParallelPlugin:_** This plugin wraps the hybrid parallel training acceleration solution. It provides an interface for any combination of tensor parallel, pipeline parallel and data parallel strategies including DDP and ZeRO. + +**_GeminiPlugin:_** This plugin wraps the Gemini acceleration solution, that ZeRO with chunk-based memory management. + +**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallel at the module level which can run across multiple machines. + +**_LowLevelZeroPlugin:_** This plugin wraps the 1/2 stage of Zero Redundancy Optimizer. Stage 1 : Shards optimizer states across data parallel workers/GPUs. Stage 2 : Shards optimizer states + gradients across data parallel workers/GPUs. + +**_TorchFSDPPlugin:_** This plugin wraps the FSDP acceleration solution of Pytorch and can be used to train models with zero-dp. + +More details about usages of each plugin can be found in chapter [Booster Plugins](./booster_plugins.md). + +Some plugins support lazy initialization, which can be used to save memory when initializating large models. For more details, please see [Lazy Initialization](../features/lazy_init.md). + +### API of booster + +{{ autodoc:colossalai.booster.Booster }} + +## Usage + +In a typical workflow, you should launch distributed environment at the beginning of training script and create objects needed (such as models, optimizers, loss function, data loaders etc.) firstly, then call `booster.boost` to inject features into these objects, After that, you can use our booster APIs and these returned objects to continue the rest of your training processes. + +A pseudo-code example is like below: + +```python +import torch +from torch.optim import SGD +from torchvision.models import resnet18 + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin + +def train(): + # launch colossalai + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + + # create plugin and objects for training + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + model = resnet18() + criterion = lambda x: x.mean() + optimizer = SGD((model.parameters()), lr=0.001) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) + + # use booster.boost to wrap the training objects + model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler) + + # do training as normal, except that the backward should be called by booster + x = torch.randn(4, 3, 224, 224) + x = x.to('cuda') + output = model(x) + loss = criterion(output) + booster.backward(loss, optimizer) + optimizer.clip_grad_by_norm(1.0) + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + # checkpointing using booster api + save_path = "./model" + booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True) + + new_model = resnet18() + booster.load_model(new_model, save_path) +``` + +For more design details please see [this page](https://github.com/hpcaitech/ColossalAI/discussions/3046). + + diff --git a/docs/source/en/basics/booster_checkpoint.md b/docs/source/en/basics/booster_checkpoint.md new file mode 100644 index 0000000000000000000000000000000000000000..ea6c11ae2cdcb06495bb47d2396f12299d77cd24 --- /dev/null +++ b/docs/source/en/basics/booster_checkpoint.md @@ -0,0 +1,70 @@ +# Booster Checkpoint + +Author: [Hongxin Liu](https://github.com/ver217) + +**Prerequisite:** +- [Booster API](./booster_api.md) + +## Introduction + +We've introduced the [Booster API](./booster_api.md) in the previous tutorial. In this tutorial, we will introduce how to save and load checkpoints using booster. + +## Model Checkpoint + +{{ autodoc:colossalai.booster.Booster.save_model }} + +Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to saved checkpoint. It can be a file, if `shard=False`. Otherwise, it should be a directory. If `shard=True`, the checkpoint will be saved in a sharded way. This is useful when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers), so you can use huggingface `from_pretrained` method to load model from our sharded checkpoint. + +{{ autodoc:colossalai.booster.Booster.load_model }} + +Model must be boosted by `colossalai.booster.Booster` before loading. It will detect the checkpoint format automatically, and load in corresponding way. + +If you want to load a pretrained model from Huggingface while the model is too large to be directly loaded through `from_pretrained` on a single device, a recommended way is to download the pretrained weights to a local directory, and use `booster.load` to load from that directory after boosting the model. Also, the model should be initialized under lazy initialization context to avoid OOM. Here is an example pseudocode: +```python +from colossalai.lazy import LazyInitContext +from huggingface_hub import snapshot_download +... + +# Initialize model under lazy init context +init_ctx = LazyInitContext(default_device=get_current_device) +with init_ctx: + model = LlamaForCausalLM(config) + +... + +# Wrap the model through Booster.boost +model, optimizer, _, _, _ = booster.boost(model, optimizer) + +# download huggingface pretrained model to local directory. +model_dir = snapshot_download(repo_id="lysandre/arxiv-nlp") + +# load model using booster.load +booster.load(model, model_dir) +... +``` + +## Optimizer Checkpoint + +{{ autodoc:colossalai.booster.Booster.save_optimizer }} + +Optimizer must be boosted by `colossalai.booster.Booster` before saving. + +{{ autodoc:colossalai.booster.Booster.load_optimizer }} + +Optimizer must be boosted by `colossalai.booster.Booster` before loading. + +## LR Scheduler Checkpoint + +{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }} + +LR scheduler must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the local path to checkpoint file. + +{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }} + +LR scheduler must be boosted by `colossalai.booster.Booster` before loading. `checkpoint` is the local path to checkpoint file. + +## Checkpoint design + +More details about checkpoint design can be found in our discussion [A Unified Checkpoint System Design](https://github.com/hpcaitech/ColossalAI/discussions/3339). + + diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md new file mode 100644 index 0000000000000000000000000000000000000000..feb37fc15de2bed4b101b6b59478d58092f71d40 --- /dev/null +++ b/docs/source/en/basics/booster_plugins.md @@ -0,0 +1,91 @@ +# Booster Plugins + +Author: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003), [Pengtai Xu](https://github.com/ppt0011) + +**Prerequisite:** +- [Booster API](./booster_api.md) + +## Introduction + +As mentioned in [Booster API](./booster_api.md), we can use booster plugins to customize the parallel training. In this tutorial, we will introduce how to use booster plugins. + +We currently provide the following plugins: + +- [Torch DDP Plugin](#torch-ddp-plugin): It is a wrapper of `torch.nn.parallel.DistributedDataParallel` and can be used to train models with data parallelism. +- [Torch FSDP Plugin](#torch-fsdp-plugin): It is a wrapper of `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with zero-dp. +- [Low Level Zero Plugin](#low-level-zero-plugin): It wraps the `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with zero-dp. It only supports zero stage-1 and stage-2. +- [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management. +- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixied precision training, TorchDDP and Zero stage 1/2 feature. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below. + +More plugins are coming soon. + +## Choosing Your Plugin + +Generally only one plugin is used to train a model. Our recommended use case for each plugin is as follows. + +- [Torch DDP Plugin](#torch-ddp-plugin): It is suitable for models with less than 2 billion parameters (e.g. Bert-3m, GPT2-1.5b). +- [Torch FSDP Plugin](#torch-fsdp-plugin) / [Low Level Zero Plugin](#low-level-zero-plugin): It is suitable for models with less than 10 billion parameters (e.g. GPTJ-6b, MegatronLM-8b). +- [Gemini Plugin](#gemini-plugin): It is suitable for models with more than 10 billion parameters (e.g. TuringNLG-17b) and is ideal for scenarios with **high cross-node bandwidth and medium to small-scale clusters (below a thousand cards)** (e.g. Llama2-70b). +- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It is suitable for models with more than 60 billion parameters, or special models such as those with exceptionally long sequences, very large vocabularies, and is best suited for scenarios with **low cross-node bandwidth and large-scale clusters (a thousand cards or more)** (e.g. GPT3-175b, Bloom-176b). + +## Plugins + +### Low Level Zero Plugin + +This plugin implements Zero-1 and Zero-2 (w/wo CPU offload), using `reduce` and `gather` to synchronize gradients and weights. + +Zero-1 can be regarded as a better substitute of Torch DDP, which is more memory efficient and faster. It can be easily used in hybrid parallelism. + +Zero-2 does not support local gradient accumulation. Though you can accumulate gradient if you insist, it cannot reduce communication cost. That is to say, it's not a good idea to use Zero-2 with pipeline parallelism. + +{{ autodoc:colossalai.booster.plugin.LowLevelZeroPlugin }} + +We've tested compatibility on some famous models, following models may not be supported: + +- `timm.models.convit_base` +- dlrm and deepfm models in `torchrec` + +Compatibility problems will be fixed in the future. + +### Gemini Plugin + +This plugin implements Zero-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. It also does not support local gradient accumulation. More details can be found in [Gemini Doc](../features/zero_with_chunk.md). + +{{ autodoc:colossalai.booster.plugin.GeminiPlugin }} + + +### Hybrid Parallel Plugin + +This plugin implements the combination of various parallel training strategies and optimization tools. The features of HybridParallelPlugin can be generally divided into four parts: + +1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel setting. Shardformer also overloads the logic of model's forward/backward process to ensure the smooth working of tp/pp. Also, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallel are injected into the overloaded forward/backward method by Shardformer. More details can be found in chapter [Shardformer Doc](../features/shardformer.md). + +2. Mixed Precision Training: Support for fp16/bf16 mixed precision training. More details about its arguments configuration can be found in [Mixed Precision Training Doc](../features/mixed_precision_training_with_booster.md). + +3. Torch DDP: This plugin will automatically adopt Pytorch DDP as data parallel strategy when pipeline parallel and Zero is not used. More details about its arguments configuration can be found in [Pytorch DDP Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). + +4. Zero: This plugin can adopt Zero 1/2 as data parallel strategy through setting the `zero_stage` argument as 1 or 2 when initializing plugin. Zero 1 is compatible with pipeline parallel strategy, while Zero 2 is not. More details about its argument configuration can be found in [Low Level Zero Plugin](#low-level-zero-plugin). + +> ⚠ When using this plugin, only the subset of Huggingface transformers supported by Shardformer are compatible with tensor parallel, pipeline parallel and optimization tools. Mainstream transformers such as Llama 1, Llama 2, OPT, Bloom, Bert and GPT2 etc. are all supported by Shardformer. + +{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} + +### Torch DDP Plugin + +More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). + +{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }} + +### Torch FSDP Plugin + +> ⚠ This plugin is not available when torch version is lower than 1.12.0. + +> ⚠ This plugin does not support save/load sharded model checkpoint now. + +> ⚠ This plugin does not support optimizer that use multi params group. + +More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.html). + +{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} + + diff --git a/docs/source/en/basics/colotensor_concept.md b/docs/source/en/basics/colotensor_concept.md deleted file mode 100644 index 909c5e4d3c6f13933e80a7dbbe00473c141c4613..0000000000000000000000000000000000000000 --- a/docs/source/en/basics/colotensor_concept.md +++ /dev/null @@ -1,96 +0,0 @@ -# ColoTensor Concepts - -Author: [Jiarui Fang](https://github.com/feifeibear), [Hongxin Liu](https://github.com/ver217) and [Haichen Huang](https://github.com/1SAA) - -**Prerequisite:** -- [Colossal-AI Overview](../concepts/colossalai_overview.md) -- [Distributed Training](../concepts/distributed_training.md) -- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) - -## Introduction - -After ColossalAI version 0.1.8, [ColoTensor](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ColoTensor) becomes the basic data structure for tensors in ColossalAI. It is a subclass of torch.Tensor and can be used as a PyTorch Tensor. Additionally, some unique features make it possible to represent a Global Tensor with a payload distributed across multiple GPU devices. With the help of ColoTensor, the users can write distributed DNN training program similar to a serial one.support the following features. - -ColoTensor contains extra attributes capsuled in a [ColoTensorSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.tensor_spec.html#colossalai.tensor.tensor_spec.ColoTensorSpec) instance to describe the tensor's payload distribution and computing pattern. - -- ProcessGroup: how processes are organized as communication groups. -- Distributed Spec: how tensor is distributed among process groups. -- Compute Spec: how the tensor is used during computation. - -We elaborate on them one by one. - -## ProcessGroup - -An instance of class [ProcessGroup](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ProcessGroup) describes how processes are organized in process groups. Processes in a process group can participate in the same collective communication operations together, such as allgather, allreduce, etc. The way the process group is organized is dominated by the Tensor's parallelism strategy. For example, if the user defines the tensor parallel (TP) and data parallel (DP) modes of a tensor, then the process organization of the process group will be automatically deduced. The process group settings can vary among different tensors. Therefore, it enables us to support more complicated hybrid parallel. The pipeline parallel (PP) definition is not in the ProcessGroup, it needs another set of mechanisms . We will supplement the related content of ColoTensor applied to PP in the future. - -Currently, a process group of ColoTensor is defined by two configurations, i.e. tp_degree and dp_degree. In the case of DP+TP hybrid parallelism, the device can be viewed as a 2D mesh. We place TP communication groups on the leading low dimension of the device mesh and then place the data parallel groups along the high dimension of the device mesh. The reason is that tensor parallelism has a larger communication overhead than data parallelism. Neighboring devices are placed inside a TP process group and are often placed in the same node. - -Considering that 8 processes are configured as tp_degree=4, and dp_degree=2, the layout is shown below. Process group tp0 contains gpu 0,1,2,3. Process dp1 contains gpu 1 and 5. - -
              - -
              Process Group using tp_degree=4, dp_degree=2
              -
              - -## Distributed Spec - -An instance of [Distributed Spec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html) describes how a ColoTensor is distributed among the ProcessGroup. - -How tensors are distributed among DP process groups is automatically derived and does not need to be manually specified by the user. If this tensor is a model parameter, it is replicated within the DP process group. If it is an activation tensor, it is split along the process with the highest dimension and evenly distributed the tensor payload among processes in the DP process group. - -Therefore, when using Distributed Spec, we only need to describe the way that the tensor is distributed among TP process groups. There are currently two ways to distribute among TP process group, i.e. [ShardSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ShardSpec) and [ReplicaSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ReplicaSpec). ShardSpec needs to specify the dimension index dim of the partition and the number of partitions num_partitions. Currently, we only support the split on a single dim. Different dist specs on the TP process groups can be converted to each other through the set_dist_spec() interface. The spec conversions are recorded by the autograd mechanism and it will trigger corresponding reverse operations during backward propagation. - -## Compute Spec - -An instance of class [ComputeSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.compute_spec.html#colossalai.tensor.compute_spec.ComputeSpec) describes how a Colotensor be used in DNN training. Currently, we will set the correct Compute Pattern for the ColoTensor as the parameters of the module. The specific application scenarios will be shown in the next document. - -## ColoParameter - -[ColoParameter](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.colo_parameter.html#colossalai.tensor.colo_parameter.ColoParameter) is a subclass of ColoTensor. Used to define a Global Parameter tensor. Its relationship with ColoTensor is consistent with Torch.Tensor and torch.Parameter. The latter allows the tensor to appear in the return values of the module's parameters() and name_parameters() methods. - -## Example - -Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp_degree=4, dp_dgree=2. And then the tensor is sharded along the last dim among the TP process groups. Finally, we reshard it along the first dim (0 dim) among the TP process groups. We encourage users to run the code and observe the shape of each tensor. - - -```python -import torch -import torch.multiprocessing as mp -from colossalai.utils import print_rank_0 -from functools import partial - -import colossalai -from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern -from colossalai.testing import spawn - -import torch - -def run_dist_tests(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(tp_degree=2, dp_degree=2) - - torch.manual_seed(0) - local_tensor = torch.randn(2, 3, 1).cuda() - print_rank_0(f"shape {local_tensor.shape}, {local_tensor.data}") - - spec = ColoTensorSpec(pg, ShardSpec(dims=[-1], num_partitions=[pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - t1 = ColoTensor.from_torch_tensor(local_tensor, spec) - t1 = t1.to_replicate() - print_rank_0(f"shape {t1.shape}, {t1.data}") - - spec2 = ShardSpec([0], [pg.tp_world_size()]) - t1.set_dist_spec(spec2) - print_rank_0(f"shape {t1.shape}, {t1.data}") - -def test_dist_cases(world_size): - spawn(run_dist_tests, world_size) - -if __name__ == '__main__': - test_dist_cases(4) -``` - -:::caution - -The ColoTensor is an experimental feature and may be updated. - -::: diff --git a/docs/source/en/basics/command_line_tool.md b/docs/source/en/basics/command_line_tool.md index 48b199cf78e9b36118443b482ff867c0bfc7996a..4c278aaa0c6a37e0f4a501dff950c832c6642ca7 100644 --- a/docs/source/en/basics/command_line_tool.md +++ b/docs/source/en/basics/command_line_tool.md @@ -30,24 +30,4 @@ This command will inform you information regarding the version compatibility and To launch distributed jobs on single or multiple nodes, the command `colossalai run` can be used for process launching. You may refer to [Launch Colossal-AI](./launch_colossalai.md) for more details. -## Tensor Parallel Micro-Benchmarking - -As Colossal-AI provides an array of tensor parallelism methods, it is not intuitive to choose one for your hardware and -model. Therefore, we provide a simple benchmarking to evaluate the performance of various tensor parallelisms on your system. -This benchmarking is run on a simple MLP model where the input data is of the shape `(batch_size, seq_length, hidden_size)`. -Based on the number of GPUs, the CLI will look for all possible tensor parallel configurations and display the benchmarking results. -You can customize the benchmarking configurations by checking out `colossalai benchmark --help`. - -```shell -# run on 4 GPUs -colossalai benchmark --gpus 4 - -# run on 8 GPUs -colossalai benchmark --gpus 8 -``` - -:::caution - -Only single-node benchmarking is supported currently. - -::: + diff --git a/docs/source/en/basics/configure_parallelization.md b/docs/source/en/basics/configure_parallelization.md deleted file mode 100644 index 4ac0299eac14252eb55f3f635d5747702d2033d5..0000000000000000000000000000000000000000 --- a/docs/source/en/basics/configure_parallelization.md +++ /dev/null @@ -1,156 +0,0 @@ -# Configure Parallelization - -Author: Shenggui Li, Siqi Mai - -**Prerequisite:** -- [Distributed Training](../concepts/distributed_training.md) -- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) -- [Define Your Configuration](./define_your_config.md) - - -## Introduction - -We support multiple parallelization in Colossal-AI. Hybrid parallelism in our codebase refers to namely the combination -of data parallelism, pipeline parallelism and tensor parallelism (1D, 2D, 2.5D, 3D). - -Each parallelism requires different network topology and thus initialize different process groups. -You can initialize the corresponding process group by setting `parallel` in the config file. -The configuration for `parallel` must obey the following format. Data parallel size will be -inferred automatically based on your inputs to pipeline parallelism and tensor parallelism. -`colossalai.launch` will initialize these distributed process groups automatically based on your configuration. - -Some sample configurations are shown below: - -```python -# sampler format -parallel = dict( - pipeline=dict("size": int), - tensor=dict("size": int, "mode": '1d' or '2d' or '2.5d' or '3d', "kwargs": Any) -) - -# this is ok -parallel = dict( - pipeline=dict(size=2), - tensor=dict(size=4, mode='2d') -) - -# this is ok -parallel = dict( - pipeline=2, - tensor=dict(size=4, mode='2d') -) - -# this is not ok -# as you need to specify the mode for tensor parallelism -parallel = dict( - pipeline=2, - tensor=4 -) - -# this is ok as well as tensor will be default to size 1 -# and mode None -parallel = dict( - pipeline=2 -) - -# this is ok as well as pipeline will default to size 1 -parallel = dict( - tensor=dict(size=4, mode='2d') -) - -``` - -The key name `size` refers to the parallel size of the parallelism dimension. For example, pipeline size 2 means there -will be 2 pipeline stages. The key name `mode` in tensor parallel config means the corresponding tensor parallelism -will be initialized. - -**You can choose to not have 'parallel' in your configuration and both pipeline and tensor will default to size 1.** - -**Total number of GPUs must be equal to `data parallel size * tensor parallel size * pipeline parallel size`** - -## Data Parallel - -Data parallel is the most common way to distribute your training task by splitting data into several shards and train on -a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do not -have to explicitly set them in your configurations. There are two ways to handle the all-reduce in data parallel in Colossal-AI. - -1. If you specify gradient handlers, gradients will be all-reduced according to the gradient handlers -2. Otherwise, PyTorch DistributedDataParallel will be used - -In most cases, you will be using the second mode unless you have complex handling of the gradients. - -## 1D, 2D, 2.5D and 3D Parallel - -To enable hybrid parallelism, we provide an array of tensor parallelism. We provide the list of papers which match each -tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI. - -- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) - -- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343) - 2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, model weights and layer - outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of `P = N^2` devices where - `N` is the number of tensor chunks in a single dimension. - -- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500) - Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which - further parallelizes 2D tensor parallelism. An amount of `P = N^2 ∗ d` processors are arranged into `d` layers, where - each layer performs matrix multiplication operations independently with a dimension `N`. - -- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450) - We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method - achieves the optimal, `O(P^{1/3})` communication overhead on $P$ processors, while both computation and memory usage - are evenly distributed through optimized load balancing of parameters as well as activations. - -```python -# 1D parallel -parallel = dict( - tensor=dict(size=4, mode='1d') -) - -# 2D parallel -parallel = dict( - tensor=dict(size=4, mode='2d') -) - -# 2.5D parallel -parallel = dict( - tensor=dict(size=8, mode='2.5d', depth=2) -) - -# 3D parallel -parallel = dict( - tensor=dict(size=8, mode='3d') -) -``` - -Once you specify the tensor parallel mode in your configuration, you can proceed to use its corresponding distributed -operator. For example, if you mode is '2d', you can use `colossalai.nn.Linear2D` in you model construction. - - -## Pipeline Parallel - -Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple -model which consists of two linear layer. We have two GPUs, and we can allocate the first linear layer to the first GPU -and the second layer to the second GPU. - -You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI -will automatically creates the pipeline schedule which defines the forward and backward step. - -```python -parallel = dict( - pipeline=dict(size=4), # number of pipeline stages -) -``` - -## Sequence Parallel - -Sequence parallel is to support long-sequence modelling such as document-level text understanding and medical imaging. -This method is proposed in [Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120). -You can use specify the mode to be `sequence` to initialize its process group. - - -```python -parallel = dict( - tensor=dict(size=4, mode='sequence') -) -``` diff --git a/docs/source/en/basics/define_your_config.md b/docs/source/en/basics/define_your_config.md deleted file mode 100644 index d2569691b7dc2f910bd5e676ef3b8b2978b53bc2..0000000000000000000000000000000000000000 --- a/docs/source/en/basics/define_your_config.md +++ /dev/null @@ -1,82 +0,0 @@ -# Define Your Configuration - -Author: Guangyang Lu, Shenggui Li, Siqi Mai - -**Prerequisite:** -- [Distributed Training](../concepts/distributed_training.md) -- [Colossal-AI Overview](../concepts/colossalai_overview.md) - - -## Introduction - -In Colossal-AI, a configuration file is required to specify the features the system will inject into the training process. -In this tutorial, we will introduce you how to construct your configuration file and how this config file will be used. -Using configuration file has several advantages: - -1. You can store your feature configuration and training hyper-parameters in different configuration files -2. New features released in the future can be specified in the configuration without code change in the training script - -In this tutorial, we will cover how to define your configuration file. - -## Configuration Definition - -In a configuration file, there are two types of variables. One serves as feature specification and the other serves -as hyper-parameters. All feature-related variables are reserved keywords. For example, if you want to use mixed precision -training, you need to use the variable name `fp16` in the config file and follow a pre-defined format. - -### Feature Specification - -There is an array of features Colossal-AI provides to speed up training. Each feature is defined by a corresponding field -in the config file. In this tutorial, we are not giving the config details for all the features, but rather we are providing -an illustration of how to specify a feature. **The details of each feature can be found in its respective tutorial.** - -To illustrate the use of config file, we use mixed precision training as an example here. In order to do so, you need to -follow the steps below. - -1. create a configuration file (e.g. `config.py`, the file name can be anything) -2. define the mixed precision configuration in the config file. For example, in order to use mixed precision training -natively provided by PyTorch, you can just write these lines of code below into your config file. - - ```python - from colossalai.amp import AMP_TYPE - - fp16 = dict( - mode=AMP_TYPE.TORCH - ) - ``` - -3. Tell Colossal-AI where your config file is when launch the distributed environment. For example, the config file is in -the current directory. - - ```python - import colossalai - - colossalai.launch(config='./config.py', ...) - ``` - -In this way, Colossal-AI knows what features you want to use and will inject this feature during `colossalai.initialize`. - -### Global Hyper-parameters - -Besides feature specification, the config file can also serve as a place to define your training hyper-parameters. This -comes handy when you want to perform multiple experiments, each experiment details can be put into a single config file -to avoid confusion. These parameters will be stored in the global parallel context and can be accessed in the training script. - -For example, you can specify the batch size in your config file. - -```python -BATCH_SIZE = 32 -``` - -After launch, you are able to access your hyper-parameters through global parallel context. - -```python -import colossalai -from colossalai.core import global_context as gpc - -colossalai.launch(config='./config.py', ...) - -# access your parameter -print(gpc.config.BATCH_SIZE) - -``` diff --git a/docs/source/en/basics/engine_trainer.md b/docs/source/en/basics/engine_trainer.md deleted file mode 100644 index bbe32ed5a3b579805d99ead1169b69a196916626..0000000000000000000000000000000000000000 --- a/docs/source/en/basics/engine_trainer.md +++ /dev/null @@ -1,387 +0,0 @@ -# Use Engine and Trainer in Training - -Author: Shenggui Li, Siqi Mai - -**Prerequisite:** -- [Initialize Features](./initialize_features.md) - -## Introduction - -In this tutorial, you will learn how to use the engine and trainer provided in Colossal-AI to train your model. -Before we delve into the details, we would like to first explain the concept of engine and trainer. - -### Engine - -Engine is essentially a wrapper class for model, optimizer and loss function. -When we call `colossalai.initialize`, an engine object will be returned, and it has already been equipped with -functionalities such as gradient clipping, gradient accumulation and zero optimizer as specified in your configuration file. -An engine object will use similar APIs to those of PyTorch training components such that the user has minimum change -to their code. - -Below is a table which shows the commonly used APIs for the engine object. - -| Component | Function | PyTorch | Colossal-AI | -| ------------------------------------- | --------------------------------------------- | ------------------------------- | -------------------------------------- | -| optimizer | Set all gradients to zero before an iteration | optimizer.zero_grad() | engine.zero_grad() | -| optimizer | Update the parameters | optimizer.step() | engine.step() | -| model | Run a forward pass | outputs = model(inputs) | outputs = engine(inputs) | -| criterion | Calculate the loss value | loss = criterion(output, label) | loss = engine.criterion(output, label) | -| criterion | Execute back-propagation on the model | loss.backward() | engine.backward(loss) | - -The reason why we need such an engine class is that we can add more functionalities while hiding the implementations in -the `colossalai.initialize` function. -Imaging we are gonna add a new feature, we can manipulate the model, optimizer, dataloader and loss function in the -`colossalai.initialize` function and only expose an engine object to the user. -The user only needs to modify their code to the minimum extent by adapting the normal PyTorch APIs to the Colossal-AI -engine APIs. In this way, they can enjoy more features for efficient training. - -A normal training iteration using engine can be: - -```python -import colossalai - -# build your model, optimizer, criterion, dataloaders -... - -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader) -for img, label in train_dataloader: - engine.zero_grad() - output = engine(img) - loss = engine.criterion(output, label) - engine.backward(loss) - engine.step() -``` - -### Trainer - -Trainer is a more high-level wrapper for the user to execute training with fewer lines of code. However, in pursuit of more abstraction, it loses some flexibility compared to engine. The trainer is designed to execute a forward and backward step to perform model weight update. It is easy to create a trainer object by passing the engine object. The trainer has a default value `None` for the argument `schedule`. In most cases, we leave this value to `None` unless we want to use pipeline parallelism. If you wish to explore more about this parameter, you can go to the tutorial on pipeline parallelism. - -```python -from colossalai.logging import get_dist_logger -from colossalai.trainer import Trainer, hooks - -# build components and initialize with colossalai.initialize -... - -# create a logger so that trainer can log on the console -logger = get_dist_logger() - -# create a trainer object -trainer = Trainer( - engine=engine, - logger=logger -) -``` - - - -In trainer, the user can customize some hooks and attach these hooks to the trainer object. A hook object will execute life-cycle methods periodically based on the training scheme. For example, The `LRSchedulerHook` will execute `lr_scheduler.step()` to update the learning rate of the model during either `after_train_iter` or `after_train_epoch` stages depending on whether the user wants to update the learning rate after each training iteration or only after the entire training epoch. You can store the hook objects in a list and pass it to `trainer.fit` method. `trainer.fit` method will execute training and testing based on your parameters. If `display_process` is True, a progress bar will be displayed on your console to show the training process. - -```python -# define the hooks to attach to the trainer -hook_list = [ - hooks.LossHook(), - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), - hooks.AccuracyHook(accuracy_func=Accuracy()), - hooks.LogMetricByEpochHook(logger), -] - -# start training -trainer.fit( - train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True -) -``` - -If you want to customize your own hook class, you can inherit `hooks.BaseHook` and override the life-cycle methods of your interest. A dummy example to demonstrate how to create a simple log message hook is provided below for your reference. - -```python -from colossalai.logging import get_dist_logger -from colossalai.trainer import hooks - -class LogMessageHook(hooks.BaseHook): - - def __init__(self, priority=10): - self._logger = get_dist_logger() - - def before_train(self, trainer): - self._logger.info('training starts') - - def after_train(self, trainer): - self._logger.info('training finished') - - -... - -# then in your training script -hook_list.append(LogMessageHook()) -``` - - - -In the sections below, I will guide you through the steps required to train a ResNet model with both engine and trainer. - - - -## Explain with ResNet - -### Overview - -In this section we will cover: - -1. Use an engine object to train a ResNet34 model on CIFAR10 dataset -2. Use a trainer object to train a ResNet34 model on CIFAR10 dataset - -The project structure will be like: - -```bash --- config.py --- run_resnet_cifar10_with_engine.py --- run_resnet_cifar10_with_trainer.py -``` - -Steps 1-4 below are commonly used regardless of using engine or trainer. Thus, steps 1-4 + step 5 will be your `run_resnet_cifar10_with_engine.py` and steps 1-4 + step 6 will form `run_resnet_cifar10_with_trainer.py`. - -### Hands-on Practice - -#### Step 1. Create a Config File - -In your project folder, create a `config.py`. This file is to specify some features you may want to use to train your model. A sample config file is as below: - -```python -from colossalai.amp import AMP_TYPE - -BATCH_SIZE = 128 -NUM_EPOCHS = 200 - -fp16=dict( - mode=AMP_TYPE.TORCH -) -``` - -In this config file, we specify that we want to use batch size 128 per GPU and run for 200 epochs. These two parameters are exposed by `gpc.config`. For example, you can use `gpc.config.BATCH_SIZE` to access the value you store in your config file. The `fp16` configuration tells `colossalai.initialize` to use mixed precision training provided by PyTorch to train the model with better speed and lower memory consumption. - -#### Step 2. Initialize Distributed Environment - -We need to initialize the distributed training environment. This has been introduced in the tutorial on how to -[launch Colossal-AI](./launch_colossalai.md). For this demonstration, we use `launch_from_torch` and PyTorch launch utility. - -```python -import colossalai - -# ./config.py refers to the config file we just created in step 1 -colossalai.launch_from_torch(config='./config.py') -``` - -#### Step 3. Create all the training components - -In this step, we can create all the components used for training. These components include: - -1. Model -2. Optimizer -3. Criterion/loss function -4. Training/Testing dataloaders -5. Learning rate Scheduler -6. Logger - - - -To build these components, you need to import the following modules: - -```python -from pathlib import Path -from colossalai.logging import get_dist_logger -import torch -import os -from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader -from torchvision import transforms -from colossalai.nn.lr_scheduler import CosineAnnealingLR -from torchvision.datasets import CIFAR10 -from torchvision.models import resnet34 -``` - - - -Then build your components in the same way as how to normally build them in your PyTorch scripts. In the script below, we set the root path for CIFAR10 dataset as an environment variable `DATA`. You can change it to any path you like, for example, you can change `root=Path(os.environ['DATA'])` to `root='./data'` so that there is no need to set the environment variable. - -```python -# build logger -logger = get_dist_logger() - -# build resnet -model = resnet34(num_classes=10) - -# build datasets -train_dataset = CIFAR10( - root='./data', - download=True, - transform=transforms.Compose( - [ - transforms.RandomCrop(size=32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ - 0.2023, 0.1994, 0.2010]), - ] - ) -) - -test_dataset = CIFAR10( - root='./data', - train=False, - transform=transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ - 0.2023, 0.1994, 0.2010]), - ] - ) -) - -# build dataloaders -train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - -test_dataloader = get_dataloader(dataset=test_dataset, - add_sampler=False, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - -# build criterion -criterion = torch.nn.CrossEntropyLoss() - -# optimizer -optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) - -# lr_scheduler -lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS) -``` - -#### Step 4. Initialize with Colossal-AI - -Next, the essential step is to obtain the engine class by calling `colossalai.initialize`. As stated in `config.py`, we will be using mixed precision training for training ResNet34 model. `colossalai.initialize` will automatically check your config file and assign relevant features to your training components. In this way, our engine object has already been able to train with mixed precision, but you do not have to explicitly take care of it. - -```python -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader, - ) -``` - - - -#### Step 5. Train with engine - -With all the training components ready, we can train ResNet34 just like how to normally deal with PyTorch training. - -```python -for epoch in range(gpc.config.NUM_EPOCHS): - # execute a training iteration - engine.train() - for img, label in train_dataloader: - img = img.cuda() - label = label.cuda() - - # set gradients to zero - engine.zero_grad() - - # run forward pass - output = engine(img) - - # compute loss value and run backward pass - train_loss = engine.criterion(output, label) - engine.backward(train_loss) - - # update parameters - engine.step() - - # update learning rate - lr_scheduler.step() - - # execute a testing iteration - engine.eval() - correct = 0 - total = 0 - for img, label in test_dataloader: - img = img.cuda() - label = label.cuda() - - # run prediction without back-propagation - with torch.no_grad(): - output = engine(img) - test_loss = engine.criterion(output, label) - - # compute the number of correct prediction - pred = torch.argmax(output, dim=-1) - correct += torch.sum(pred == label) - total += img.size(0) - - logger.info( - f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", ranks=[0]) -``` - -#### Step 6. Train with trainer - -If you wish to train with a trainer object, you can follow the code snippet below: - -```python -from colossalai.nn.metric import Accuracy -from colossalai.trainer import Trainer, hooks - - -# create a trainer object -trainer = Trainer( - engine=engine, - logger=logger -) - -# define the hooks to attach to the trainer -hook_list = [ - hooks.LossHook(), - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), - hooks.AccuracyHook(accuracy_func=Accuracy()), - hooks.LogMetricByEpochHook(logger), - hooks.LogMemoryByEpochHook(logger) -] - -# start training -# run testing every 1 epoch -trainer.fit( - train_dataloader=train_dataloader, - epochs=gpc.config.NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True -) -``` - - - -#### Step 7. Start Distributed Training - -Lastly, we can invoke the scripts using the distributed launcher provided by PyTorch as we used `launch_from_torch` in Step 2. You need to replace `` with the number of GPUs available on your machine. This number can be 1 if you only want to use 1 GPU. If you wish to use other launchers, you can refer to the tutorial on How to Launch Colossal-AI. - -```bash -# with engine -python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py -# with trainer -python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py -``` diff --git a/docs/source/en/basics/initialize_features.md b/docs/source/en/basics/initialize_features.md deleted file mode 100644 index e768d2022ad8d706980a1dcd1b90ffd7a8be5367..0000000000000000000000000000000000000000 --- a/docs/source/en/basics/initialize_features.md +++ /dev/null @@ -1,49 +0,0 @@ -# Initialize Features - -Author: Shenggui Li, Siqi Mai - -**Prerequisite:** -- [Distributed Training](../concepts/distributed_training.md) -- [Colossal-AI Overview](../concepts/colossalai_overview.md) - -## Introduction - -In this tutorial, we will cover the use of `colossalai.initialize` which injects features into your training components -(e.g. model, optimizer, dataloader) seamlessly. Calling `colossalai.initialize` is the standard procedure before you run -into your training loops. - -In the section below, I will cover how `colossalai.initialize` works and what we should take note of. - -## Usage - -In a typical workflow, we will launch distributed environment at the beginning of our training script. -Afterwards, we will instantiate our objects such as model, optimizer, loss function, dataloader etc. At this moment, `colossalai.initialize` -can come in to inject features into these objects. A pseudo-code example is like below: - -```python -import colossalai -import torch -... - - -# launch distributed environment -colossalai.launch(config='./config.py', ...) - -# create your objects -model = MyModel() -optimizer = torch.optim.Adam(model.parameters(), lr=0.001) -criterion = torch.nn.CrossEntropyLoss() -train_dataloader = MyTrainDataloader() -test_dataloader = MyTrainDataloader() - -# initialize features -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader) -``` - -The `colossalai.initialize` function will return an `Engine` object. The engine object is a wrapper -for model, optimizer and loss function. **The engine object will run with features specified in the config file.** -More details about the engine can be found in the [Use Engine and Trainer in Training](./engine_trainer.md). diff --git a/docs/source/en/basics/launch_colossalai.md b/docs/source/en/basics/launch_colossalai.md index be487f8539a57ce1995516f4e95a9d8d480b9630..334757ea75af4f7c30cc074fea69458c75f86ae1 100644 --- a/docs/source/en/basics/launch_colossalai.md +++ b/docs/source/en/basics/launch_colossalai.md @@ -87,14 +87,13 @@ import colossalai args = colossalai.get_default_parser().parse_args() # launch distributed environment -colossalai.launch(config=, +colossalai.launch(config=args.config, rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend ) - ``` @@ -107,12 +106,21 @@ First, we need to set the launch method in our code. As this is a wrapper of the use `colossalai.launch_from_torch`. The arguments required for distributed environment such as rank, world size, host and port are all set by the PyTorch launcher and can be read from the environment variable directly. +config.py +```python +BATCH_SIZE = 512 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 2 +``` +train.py ```python import colossalai colossalai.launch_from_torch( - config=, + config="./config.py", ) +... ``` Next, we can easily start multiple processes with `colossalai run` in your terminal. Below is an example to run the code diff --git a/docs/source/en/basics/model_checkpoint.md b/docs/source/en/basics/model_checkpoint.md deleted file mode 100644 index 09d44e7c27097a37486ef4678af07aeb1163d44e..0000000000000000000000000000000000000000 --- a/docs/source/en/basics/model_checkpoint.md +++ /dev/null @@ -1,61 +0,0 @@ -# Model Checkpoint - -Author : Guangyang Lu - -**Prerequisite:** -- [Launch Colossal-AI](./launch_colossalai.md) -- [Initialize Colossal-AI](./initialize_features.md) - -**Example Code:** -- [ColossalAI-Examples Model Checkpoint](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/utils/checkpoint) - -**This function is experiential.** - -## Introduction - -In this tutorial, you will learn how to save and load model checkpoints. - -To leverage the power of parallel strategies in Colossal-AI, modifications to models and tensors are needed, for which you cannot directly use `torch.save` or `torch.load` to save or load model checkpoints. Therefore, we have provided you with the API to achieve the same thing. - -Moreover, when loading, you are not demanded to use the same parallel strategy as saving. - -## How to use - -### Save - -There are two ways to train a model in Colossal-AI, by engine or by trainer. -**Be aware that we only save the `state_dict`.** Therefore, when loading the checkpoints, you need to define the model first. - -#### Save when using engine - -```python -from colossalai.utils import save_checkpoint -model = ... -engine, _, _, _ = colossalai.initialize(model=model, ...) -for epoch in range(num_epochs): - ... # do some training - save_checkpoint('xxx.pt', epoch, model) -``` - -#### Save when using trainer -```python -from colossalai.trainer import Trainer, hooks -model = ... -engine, _, _, _ = colossalai.initialize(model=model, ...) -trainer = Trainer(engine, ...) -hook_list = [ - hooks.SaveCheckpointHook(1, 'xxx.pt', model) - ...] - -trainer.fit(... - hook=hook_list) -``` - -### Load - -```python -from colossalai.utils import load_checkpoint -model = ... -load_checkpoint('xxx.pt', model) -... # train or test -``` diff --git a/docs/source/en/concepts/colossalai_overview.md b/docs/source/en/concepts/colossalai_overview.md index 38b682d49e62dd3ee9b7b03ea34fdd8697bda286..7617c62a4e00b786fda1c20491fd0259f0f51787 100644 --- a/docs/source/en/concepts/colossalai_overview.md +++ b/docs/source/en/concepts/colossalai_overview.md @@ -19,7 +19,7 @@ We aim to make Colossal-AI easy to use and non-intrusive to user code. There is 1. Prepare a configuration file where specifies the features you want to use and your parameters. 2. Initialize distributed backend with `colossalai.launch` -3. Inject the training features into your training components (e.g. model, optimizer) with `colossalai.initialize`. +3. Inject the training features into your training components (e.g. model, optimizer) with `colossalai.booster`. 4. Run training and testing We will cover the whole workflow in the `basic tutorials` section. @@ -34,3 +34,5 @@ The Colossal-AI system will be expanded to include more training skills, these n 4. expansion of existing parallelism methods We welcome ideas and contribution from the community and you can post your idea for future development in our forum. + + diff --git a/docs/source/en/features/1D_tensor_parallel.md b/docs/source/en/features/1D_tensor_parallel.md index 7577e50400e91c5e73224590738aa10a82f2ac7e..37c01db31342873cba3ded5e4f602f8808e1a798 100644 --- a/docs/source/en/features/1D_tensor_parallel.md +++ b/docs/source/en/features/1D_tensor_parallel.md @@ -2,12 +2,8 @@ Author: Zhengda Bian, Yongbin Li -**Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Configure Parallelization](../basics/configure_parallelization.md) - **Example Code** -- [ColossalAI-Examples 1D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_1d.py) +- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples) **Related Paper** - [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) @@ -19,15 +15,15 @@ An efficient 1D tensor parallelism implementation was introduced by [Megatron-LM Let's take a linear layer as an example, which consists of a GEMM $Y = XA$. Given 2 processors, we split the columns of $A$ into $[A_1 ~ A_2]$, and calculate $Y_i = XA_i$ on each processor, which then forms $[Y_1 ~ Y_2] = [XA_1 ~ XA_2]$. This is called a column-parallel fashion. -When a second linear layer $Z=YB$ follows the column-parallel one, we split $B$ into -```math +When a second linear layer $Z=YB$ follows the column-parallel one, we split $B$ into +$$ \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] -``` +$$ which is called a row-parallel fashion. -To calculate -```math +To calculate +$$ Z = [Y_1 ~ Y_2] \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] -``` +$$ we first calculate $Y_iB_i$ on each processor, then use an all-reduce to aggregate the results as $Z=Y_1B_1+Y_2B_2$. We also need to note that in the backward pass, the column-parallel linear layer needs to aggregate the gradients of the input tensor $X$, because on each processor $i$ we only have $\dot{X_i}=\dot{Y_i}A_i^T$. @@ -42,77 +38,7 @@ Given $P$ processors, we present the theoretical computation and memory cost, as ## Usage -To enable 1D tensor parallelism for our model, e.g. on 2 GPUs, we need to configure the parallelism setting as below. -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=2, mode='1d'), -)) -``` -Then Colossal-AI will automatically apply 1D parallelism to all the layers from `colossalai.nn`. - -Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.transpose(0, 1).shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.transpose(0, 1).shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` - -Launch Colossal-AI on 2 GPUs and build the model. - -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. -```shell -Weight of the first linear layer: torch.Size([256, 512]) -Weight of the second linear layer: torch.Size([512, 256]) -``` -The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the column-parallel partitioning, it becomes `[256, 512]`. -Similarly, the second row-parallel layer partitions the weight `[1024, 256]` into `[512, 256]`. - -We can run the model with some random inputs. -```python -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -torch.distributed.broadcast(x, src=0) # synchronize input +1D tensor parallelism is implemented by `Shardformer` feature in the newest version of ColossalAI. +For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md). -x = m(x) -``` -Then we can see the shapes of activation results. -```shell -Output of the first linear layer: torch.Size([16, 512]) -Output of the second linear layer: torch.Size([16, 256]) -``` -The output of the first linear layer is split into 2 partitions (each has the shape `[16, 512]`), while the second layer has identical outputs across the GPUs. + diff --git a/docs/source/en/features/2D_tensor_parallel.md b/docs/source/en/features/2D_tensor_parallel.md index 7b6c10766099f6fec50b656138fa2a7fd0cdd132..692e2702edd98a9e33798ac871bad493c3f237f4 100644 --- a/docs/source/en/features/2D_tensor_parallel.md +++ b/docs/source/en/features/2D_tensor_parallel.md @@ -3,12 +3,10 @@ Author: Zhengda Bian, Yongbin Li **Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Configure Parallelization](../basics/configure_parallelization.md) - [1D Tensor Parallelism](./1D_tensor_parallel.md) **Example Code** -- [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2d.py) +- [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md) **Related Paper** - [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/pdf/2104.05343.pdf) @@ -22,33 +20,33 @@ Let's still take a linear layer $Y = XA$ as an example. Given $P=q\times q$ processors (necessary condition), e.g. $q=2$, we split both the input $X$ and weight $A$ into $$ -\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right] +\left[\begin{matrix} X_{00} & X_{01} \\ X_{10} & X_{11} \end{matrix} \right] \text{~and~} -\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right]. +\left[\begin{matrix} A_{00} & A_{01} \\ A_{10} & A_{11} \end{matrix} \right]. $$ The calculation includes $q$ steps. When $t=1$, $X_{i0}$ is broadcasted in its row, and $A_{0j}$ is broadcasted in its column. So, we have $$ -\left[\begin{matrix} X_{10},A_{00} & X_{10},A_{01} \\ X_{00},A_{00} & X_{00},A_{01} \end{matrix} \right]. +\left[\begin{matrix} X_{00},A_{00} & X_{00},A_{01} \\ X_{10},A_{00} & X_{10},A_{01} \end{matrix} \right]. $$ Then we multiply $X_{i0}$ and $A_{0j}$ on each processor $(i, j)$ as $$ -\left[\begin{matrix} X_{10}A_{00} & X_{10}A_{01} \\ X_{00}A_{00} & X_{00}A_{01} \end{matrix} \right] (1). +\left[\begin{matrix} X_{00}A_{00} & X_{00}A_{01} \\ X_{10}A_{00} & X_{10}A_{01} \end{matrix} \right] (1). $$ Similarly, when $t=2$, $X_{i1}$ is broadcasted in its row, $A_{1j}$ is broadcasted in its column, and we multiply them as $$ -\left[\begin{matrix} X_{11}A_{10} & X_{11}A_{11} \\ X_{01}A_{10} & X_{01}A_{11} \end{matrix} \right] (2). +\left[\begin{matrix} X_{01}A_{10} & X_{01}A_{11} \\ X_{11}A_{10} & X_{11}A_{11} \end{matrix} \right] (2). $$ By adding $(1)$ and $(2)$ up, we have $$ -Y = XA = \left[\begin{matrix} X_{10}A_{00}+X_{11}A_{10} & X_{10}A_{01}+X_{11}A_{11} \\ X_{00}A_{00}+X_{01}A_{10} & X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right]. +Y = XA = \left[\begin{matrix} X_{00}A_{00}+X_{01}A_{10} & X_{00}A_{01}+X_{01}A_{11} \\ X_{10}A_{00}+X_{11}A_{10} & X_{10}A_{01}+X_{11}A_{11} \end{matrix} \right]. $$ ## Efficiency @@ -60,83 +58,9 @@ Given $P=q\times q$ processors, we present the theoretical computation and memor ## Usage -To enable 2D tensor parallelism for our model, e.g. on 4 GPUs, we need to configure the parallelism setting as below. -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=4, mode='2d'), -)) -``` -Then Colossal-AI will automatically apply 2D parallelism to all the layers from `colossalai.nn`. - -Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -Launch Colossal-AI on 4 GPUs and build the model -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. -```shell -Weight of the first linear layer: torch.Size([128, 512]) -Weight of the second linear layer: torch.Size([512, 128]) -``` -The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2D parallelism, it becomes `[128, 512]` on each GPU. -Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`. - -We can run the model with some random inputs. -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -Then we can see the shapes of activation results. -```shell -Input: torch.Size([8, 128]) -Output of the first linear layer: torch.Size([8, 512]) -Output of the second linear layer: torch.Size([8, 128]) -``` -The activation tensors in 2D parallelism are all split in both row and column. -E.g. the output of the first linear layer has the shape `[8, 512]`, while the second layer has the output of `[8, 128]`. +Currently the newest version of ColossalAI doesn't support 2D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases. +For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md). + +For users of older version of ColossalAI, please refer to [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md). + + diff --git a/docs/source/en/features/2p5D_tensor_parallel.md b/docs/source/en/features/2p5D_tensor_parallel.md index 6076562e6dca51a7f41c080d9e115d8b3f2a3415..4a97a39e1effac8ca51c9d44fc30abd36c97b48e 100644 --- a/docs/source/en/features/2p5D_tensor_parallel.md +++ b/docs/source/en/features/2p5D_tensor_parallel.md @@ -3,13 +3,11 @@ Author: Zhengda Bian, Yongbin Li **Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Configure Parallelization](../basics/configure_parallelization.md) - [1D Tensor Parallelism](./1D_tensor_parallel.md) - [2D Tensor Parallelism](./2D_tensor_parallel.md) **Example Code** -- [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2p5d.py) +- [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md) **Related Paper** - [2.5-dimensional distributed model training](https://arxiv.org/pdf/2105.14500.pdf) @@ -23,29 +21,30 @@ Let's still take a linear layer $Y = XA$ as an example. Given $P=q \times q \times d$ processors (necessary condition), e.g. $q=d=2$, we split the input $X$ into $d\times q$ rows and $q$ columns as $$ -\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \\ X_{10} & X_{11} \\ X_{00} & X_{01}\end{matrix} \right], +\left[\begin{matrix} X_{00} & X_{01} \\ X_{10} & X_{11} \\ X_{20} & X_{21} \\ X_{30} & X_{31}\end{matrix} \right], $$ + which can be reshaped into $d$ layers as $$ -\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right] \text{~and~}\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \end{matrix} \right]. +\left[\begin{matrix} X_{00} & X_{01} \\ X_{10} & X_{11} \end{matrix} \right] \text{~and~}\left[\begin{matrix} X_{20} & X_{21} \\ X_{30} & X_{31} \end{matrix} \right]. $$ Also, the weight $A$ is split into $$ -\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right]. +\left[\begin{matrix} A_{00} & A_{01} \\ A_{10} & A_{11} \end{matrix} \right]. $$ For each layer of $X$, we use the SUMMA algorithm to multiply $X$ and $A$. Then, we have the output $$ -\left[\begin{matrix} Y_{10}=X_{10}A_{00}+X_{11}A_{10} & Y_{11}=X_{10}A_{01}+X_{11}A_{11} \\ Y_{00}=X_{00}A_{00}+X_{01}A_{10} & Y_{01}=X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right] +\left[\begin{matrix} Y_{00}=X_{00}A_{00}+X_{01}A_{10} & Y_{01}=X_{00}A_{01}+X_{01}A_{11} \\ Y_{10}=X_{10}A_{00}+X_{11}A_{10} & Y_{11}=X_{10}A_{01}+X_{11}A_{11} \end{matrix} \right] \text{~and~} $$ $$ -\left[\begin{matrix} Y_{30}=X_{30}A_{00}+X_{31}A_{10} & Y_{31}=X_{30}A_{01}+X_{31}A_{11} \\ Y_{20}=X_{20}A_{00}+X_{21}A_{10} & Y_{21}=X_{20}A_{01}+X_{21}A_{11} \end{matrix} \right]. +\left[\begin{matrix} Y_{20}=X_{20}A_{00}+X_{21}A_{10} & Y_{21}=X_{20}A_{01}+X_{21}A_{11} \\ Y_{30}=X_{30}A_{00}+X_{31}A_{10} & Y_{31}=X_{30}A_{01}+X_{31}A_{11} \end{matrix} \right]. $$ ## Efficiency @@ -57,86 +56,9 @@ Given $P=q \times q \times d$ processors, we present the theoretical computation ## Usage -To enable 2.5D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallelism setting as below. -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=8, mode='2.5d', depth=2), -)) - -``` -Then Colossal-AI will automatically apply 2.5D parallelism to all the layers from `colossalai.nn`. - -Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -Launch Colossal-AI on 8 GPUs and build the model -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. -```shell -Weight of the first linear layer: torch.Size([128, 512]) -Weight of the second linear layer: torch.Size([512, 128]) -``` -The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2.5D parallelism, it becomes `[128, 512]` on each GPU. -Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`. - -We can run the model with some random inputs. -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)] -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -Then we can see the shapes of activation results. -```shell -Input: torch.Size([4, 128]) -Output of the first linear layer: torch.Size([4, 512]) -Output of the second linear layer: torch.Size([4, 128]) -``` -The activation tensors in 2.5D parallelism are all split by $d \times q$ in the row and $q$ in the column. -E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`. -Note, 2.5D parallelism use the same partition method as 2D parallelism for weights, where the difference is the partition of input. +Currently the newest version of ColossalAI doesn't support 2.5D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases. +For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md). + +For users of older version of ColossalAI, please refer to [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md). + + diff --git a/docs/source/en/features/3D_tensor_parallel.md b/docs/source/en/features/3D_tensor_parallel.md index 1207376335cea4b7865b210f8b9ea09703e09a45..8f7deb5b6b7408fbbc4cfd5afa3a3e9e1ce34ef2 100644 --- a/docs/source/en/features/3D_tensor_parallel.md +++ b/docs/source/en/features/3D_tensor_parallel.md @@ -3,13 +3,11 @@ Author: Zhengda Bian, Yongbin Li **Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Configure Parallelization](../basics/configure_parallelization.md) - [1D Tensor Parallelism](./1D_tensor_parallel.md) - [2D Tensor Parallelism](./2D_tensor_parallel.md) **Example Code** -- [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_3d.py) +- [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md) **Related Paper** - [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/pdf/2105.14450.pdf) @@ -67,85 +65,9 @@ Given $P=q \times q \times q$ processors, we present the theoretical computation ## Usage -To enable 3D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallism setting as below. -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=8, mode='3d'), -)) -``` -Then Colossal-AI will automatically apply 3D parallelism to all the layers from `colossalai.nn`. - -Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below. -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -Launch Colossal-AI on 8 GPUs and build the model -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -We will see the shapes of partitioned parameters(e.g. weights) in the MLP model. -```shell -Weight of the first linear layer: torch.Size([128, 256]) -Weight of the second linear layer: torch.Size([512, 64]) -``` -The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 3D parallelism, it becomes `[128, 256]` on each GPU. -Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 64]`. - -We can run the model with some random inputs. -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)] -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -Then we can see the shapes of activation results. -```shell -Input: torch.Size([4, 128]) -Output of the first linear layer: torch.Size([4, 512]) -Output of the second linear layer: torch.Size([4, 128]) -``` -The activation tensors in 3D parallelism are all split by $q^2$ in the row and $q$ in the column. -E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`. -Note, although the results of 3D parallelism have the same shape as that of 2.5D parallelism for weights here, the content of each partition is different. +Currently the newest version of ColossalAI doesn't support 3D tensor parallelism, but this feature will be integrated into `Shardformer` in future releases. +For more details about ideas and usages of `Shardformer`, please refer to [Shardformer Doc](./shardformer.md). + +For users of older version of ColossalAI, please refer to [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md). + + diff --git a/docs/source/en/features/cluster_utils.md b/docs/source/en/features/cluster_utils.md new file mode 100644 index 0000000000000000000000000000000000000000..7331d5e73ae0b228bab5e8fa1933d677684422d8 --- /dev/null +++ b/docs/source/en/features/cluster_utils.md @@ -0,0 +1,16 @@ +# Cluster Utilities + +Author: [Hongxin Liu](https://github.com/ver217) + +**Prerequisite:** +- [Distributed Training](../concepts/distributed_training.md) + +## Introduction + +We provide a utility class `colossalai.cluster.DistCoordinator` to coordinate distributed training. It's useful to get various information about the cluster, such as the number of nodes, the number of processes per node, etc. + +## API Reference + +{{ autodoc:colossalai.cluster.DistCoordinator }} + + diff --git a/docs/source/en/features/gradient_accumulation.md b/docs/source/en/features/gradient_accumulation.md deleted file mode 100644 index ecc209fbac8d043d5330c6f971a5cd52efe55b38..0000000000000000000000000000000000000000 --- a/docs/source/en/features/gradient_accumulation.md +++ /dev/null @@ -1,45 +0,0 @@ -# Gradient Accumulation - -Author: Shenggui Li, Yongbin Li - -**Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Use Engine and Trainer in Training](../basics/engine_trainer.md) - -**Example Code** -- [ColossalAI-Examples Gradient Accumulation](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) - -## Introduction - -Gradient accumulation is a common way to enlarge your batch size for training. -When training large-scale models, memory can easily become the bottleneck and the batch size can be very small, (e.g. 2), -leading to unsatisfactory convergence. Gradient accumulation works by adding up the gradients calculated in multiple iterations, -and only update the parameters in the preset iteration. - -## Usage - -It is simple to use gradient accumulation in Colossal-AI. Just add this following configuration into your config file. -The integer represents the number of iterations to accumulate gradients. - -```python -gradient_accumulation = -``` - -## Hands-on Practice - -We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) -to demonstrate gradient accumulation. In this example, we set the gradient accumulation size to be 4. You can run the script using this command: - -```shell -python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py -``` - -You will see output similar to the text below. This shows gradient is indeed accumulated as the parameter is not updated -in the first 3 steps, but only updated in the last step. - -```text -iteration 0, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) -iteration 1, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) -iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) -iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) -``` diff --git a/docs/source/en/features/gradient_accumulation_with_booster.md b/docs/source/en/features/gradient_accumulation_with_booster.md new file mode 100644 index 0000000000000000000000000000000000000000..347cd6e519bb006202fae7af8d4b885dc021f02d --- /dev/null +++ b/docs/source/en/features/gradient_accumulation_with_booster.md @@ -0,0 +1,145 @@ +# Gradient Accumulation + +Author: [Mingyan Jiang](https://github.com/jiangmingyan) + +**Prerequisite** +- [Training Booster](../basics/booster_api.md) + +## Introduction + +Gradient accumulation is a common way to enlarge your batch size for training. When training large-scale models, memory can easily become the bottleneck and the batch size can be very small, (e.g. 2), leading to unsatisfactory convergence. Gradient accumulation works by adding up the gradients calculated in multiple iterations, and only update the parameters in the preset iteration. + +## Usage + +It is simple to use gradient accumulation in Colossal-AI. Just call `booster.no_sync()` which returns a context manager. It accumulate gradients without synchronization, meanwhile you should not update the weights. + +## Hands-on Practice + +We now demonstrate gradient accumulation. In this example, we let the gradient accumulation size to be 4. + +### Step 1. Import libraries in train.py +Create a `train.py` and import the necessary dependencies. The version of `torch` should not be lower than 1.8.1. + +```python +import os +from pathlib import Path + +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 +from torch.utils.data import DataLoader + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.logging import get_dist_logger +from colossalai.cluster.dist_coordinator import priority_execution +``` + +### Step 2. Initialize Distributed Environment +We then need to initialize distributed environment. For demo purpose, we uses `launch_from_torch`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) for other initialization methods. + +```python +# initialize distributed setting +parser = colossalai.get_default_parser() +args = parser.parse_args() +# launch from torch +colossalai.launch_from_torch(config=dict()) +``` + +### Step 3. Create training components +Build your model, optimizer, loss function, lr scheduler and dataloaders. Note that the root path of the dataset is obtained from the environment variable `DATA`. You may `export DATA=/path/to/data` or change `Path(os.environ['DATA'])` to a path on your machine. Data will be automatically downloaded to the root path. + +```python +# define the training hyperparameters +BATCH_SIZE = 128 +GRADIENT_ACCUMULATION = 4 + +# build resnet +model = resnet18(num_classes=10) + +# build dataloaders +with priority_execution(): + train_dataset = CIFAR10(root=Path(os.environ.get('DATA', './data')), + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(size=32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ])) + +# build criterion +criterion = torch.nn.CrossEntropyLoss() + +# optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) +``` + +### Step 4. Inject Feature +Create a `TorchDDPPlugin` object to instantiate a `Booster`, and boost these training components. + +```python +plugin = TorchDDPPlugin() +booster = Booster(plugin=plugin) +train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) +model, optimizer, criterion, train_dataloader, _ = booster.boost(model=model, + optimizer=optimizer, + criterion=criterion, + dataloader=train_dataloader) +``` + +### Step 5. Train with Booster +Use booster in a normal training loops, and verify gradient accumulation. `param_by_iter` is to record the distributed training information. +```python +optimizer.zero_grad() +for idx, (img, label) in enumerate(train_dataloader): + sync_context = booster.no_sync(model) + img = img.cuda() + label = label.cuda() + if idx % (GRADIENT_ACCUMULATION - 1) != 0: + with sync_context: + output = model(img) + train_loss = criterion(output, label) + train_loss = train_loss / GRADIENT_ACCUMULATION + booster.backward(train_loss, optimizer) + else: + output = model(img) + train_loss = criterion(output, label) + train_loss = train_loss / GRADIENT_ACCUMULATION + booster.backward(train_loss, optimizer) + optimizer.step() + optimizer.zero_grad() + + ele_1st = next(model.parameters()).flatten()[0] + param_by_iter.append(str(ele_1st.item())) + + if idx != 0 and idx % (GRADIENT_ACCUMULATION - 1) == 0: + break + + for iteration, val in enumerate(param_by_iter): + print(f'iteration {iteration} - value: {val}') + + if param_by_iter[-1] != param_by_iter[0]: + print('The parameter is only updated in the last iteration') + +``` + +### Step 6. Invoke Training Scripts +To verify gradient accumulation, we can just check the change of parameter values. When gradient accumulation is set, parameters are only updated in the last step. You can run the script using this command: +```shell +colossalai run --nproc_per_node 1 train.py +``` + +You will see output similar to the text below. This shows gradient is indeed accumulated as the parameter is not updated +in the first 3 steps, but only updated in the last step. + +```text +iteration 0, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 1, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) +``` + + diff --git a/docs/source/en/features/gradient_clipping.md b/docs/source/en/features/gradient_clipping.md deleted file mode 100644 index f606dde6c393e56269f4ad97c402529dc52569f3..0000000000000000000000000000000000000000 --- a/docs/source/en/features/gradient_clipping.md +++ /dev/null @@ -1,62 +0,0 @@ -# Gradient Clipping - -Author: Boxiang Wang, Haichen Huang, Yongbin Li - -**Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Use Engine and Trainer in Training](../basics/engine_trainer.md) - -**Example Code** -- [ColossalAI-Examples Gradient Clipping](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) - -**Related Paper** -- [On the difficulty of training Recurrent Neural Networks](https://arxiv.org/abs/1211.5063) - -## Introduction - -In order to speed up training process and seek global optimum for better performance, more and more learning -rate schedulers have been proposed. People turn to control learning rate to adjust descent pace during training, -which makes gradient vector better to be uniformed in every step. In that case, the descent pace can be -controlled as expected. As a result, gradient clipping, a technique which can normalize the gradient vector -to circumscribe it in a uniformed length, becomes indispensable for those who desire their better -performance of their models. - -You do not have to worry about implementing gradient clipping when using Colossal-AI, we support gradient -clipping in a powerful and convenient way. All you need is just an additional command in your configuration -file. - -## Why you should use gradient clipping provided by Colossal-AI - -The reason of why we do not recommend users to write gradient clipping by themselves is that naive gradient clipping -may fail when applying tensor parallelism, pipeline parallelism or MoE. - -According to the illustration below, each GPU only owns a portion of parameters of the weight in a linear layer. -To get correct norm of gradient vector of the weight of the linear layer, the norm of every gradient vector in each GPU -should be summed together. -More complicated thing is that the distribution of bias is different from the distribution of the weight. -The communication group is different in the sum operation. - -(PS: This situation is an old version of 2D parallelism, the implementation in the code is not the same. -But it is a good example about the difficulty to unify all communication in gradient clipping.) - -
              - -
              Layout of parameters
              -
              - -Do not worry about it, since Colossal-AI have handled it for you. - -### Usage -To use gradient clipping, you can just simply add gradient clipping norm in your configuration file. -```python -clip_grad_norm = 1.0 -``` - -### Hands-On Practice - -We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) -to demonstrate gradient clipping. In this example, we set the gradient clipping vector norm to be 1.0. You can run the script using this command: - -```shell -python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 train_with_engine.py -``` diff --git a/docs/source/en/features/gradient_clipping_with_booster.md b/docs/source/en/features/gradient_clipping_with_booster.md new file mode 100644 index 0000000000000000000000000000000000000000..14eee67bc019b6c076399f7a7f2bfec1254c3fb9 --- /dev/null +++ b/docs/source/en/features/gradient_clipping_with_booster.md @@ -0,0 +1,141 @@ +# Gradient Clipping + +Author: [Mingyan Jiang](https://github.com/jiangmingyan) + +**Prerequisite** +- [Training Booster](../basics/booster_api.md) + +**Related Paper** +- [On the difficulty of training Recurrent Neural Networks](https://arxiv.org/abs/1211.5063) + +## Introduction + +In order to speed up training process and seek global optimum for better performance, more and more learning rate schedulers have been proposed. People turn to control learning rate to adjust descent pace during training, which makes gradient vector better to be uniformed in every step. In that case, the descent pace can be controlled as expected. As a result, gradient clipping, a technique which can normalize the gradient vector to circumscribe it in a uniformed length, becomes indispensable for those who desire their better performance of their models. + +You do not have to worry about implementing gradient clipping when using Colossal-AI, we support gradient clipping in a powerful and convenient way. All you need is just an additional command in your configuration file. + +## Why you should use gradient clipping provided by Colossal-AI + +The reason of why we do not recommend users to write gradient clipping by themselves is that naive gradient clipping may fail when applying tensor parallelism, pipeline parallelism or MoE. + +According to the illustration below, each GPU only owns a portion of parameters of the weight in a linear layer. To get correct norm of gradient vector of the weight of the linear layer, the norm of every gradient vector in each GPU should be summed together. More complicated thing is that the distribution of bias is different from the distribution of the weight. The communication group is different in the sum operation. + +(PS: This situation is an old version of 2D parallelism, the implementation in the code is not the same. But it is a good example about the difficulty to unify all communication in gradient clipping.) + +
              + +
              Layout of parameters
              +
              + +Do not worry about it, since Colossal-AI have handled it for you. + +## Usage +To use gradient clipping, you can just add the following code to your configuration file, and after boosted, you can call `clip_grad_by_norm` or `clip_grad_by_value` method of optimizer, if it support clip gradients. + +## Hands-On Practice + +We now demonstrate how to use gradient clipping. In this example, we set the gradient clipping vector norm to be 1.0. + +### step 1. Import libraries in train.py +Create a `train.py` and import the necessary dependencies. + +```python +import os +from pathlib import Path + +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet34 +from tqdm import tqdm + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingLR +``` + +### Step 2. Initialize Distributed Environment +We then need to initialize distributed environment. For demo purpose, we uses `launch_from_torch`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) +for other initialization methods. + +```python +colossalai.launch_from_torch(config=dict()) +logger = get_dist_logger() +``` + + +### Step 3. Create training components + +Build your model, optimizer, loss function, lr scheduler and dataloaders. Note that the root path of the dataset is obtained from the environment variable `DATA`. You may `export DATA=/path/to/data` or change `Path(os.environ['DATA'])` to a path on your machine. Data will be automatically downloaded to the root path. +```python +# define training hyperparameters +NUM_EPOCHS = 200 +BATCH_SIZE = 128 +GRADIENT_CLIPPING = 0.1 +# build resnet +model = resnet34(num_classes=10) +# build dataloaders +train_dataset = CIFAR10(root=Path(os.environ.get('DATA', './data')), + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(size=32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ])) +# build criterion +criterion = torch.nn.CrossEntropyLoss() + +# optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) + +# lr_scheduler +lr_scheduler = CosineAnnealingLR(optimizer, total_steps=NUM_EPOCHS) + +``` +### Step 4. Inject Gradient Clipping Feature + +Create a `TorchDDPPlugin` object and `Booster` object, get a data loader from plugin, then boost all training components. +```python +plugin = TorchDDPPlugin() +booster = Booster(mixed_precision='fp16', plugin=plugin) +train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) +model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model,optimizer, criterion,train_dataloader, lr_scheduler) + +``` + +### Step 5. Train with Booster +Use booster in a normal training loops. +```python +# verify gradient clipping +model.train() +for idx, (img, label) in enumerate(train_dataloader): + img = img.cuda() + label = label.cuda() + + model.zero_grad() + output = model(img) + train_loss = criterion(output, label) + booster.backward(train_loss, optimizer) + optimizer.clip_grad_by_norm(max_norm=GRADIENT_CLIPPING) + optimizer.step() + lr_scheduler.step() + + ele_1st = next(model.parameters()).flatten()[0] + logger.info(f'iteration {idx}, loss: {train_loss}, 1st element of parameters: {ele_1st.item()}') + + # only run for 4 iterations + if idx == 3: + break +``` + +### Step 6. Invoke Training Scripts +You can run the script using this command: + +```shell +colossalai run --nproc_per_node 1 train.py +``` + + diff --git a/docs/source/en/features/gradient_handler.md b/docs/source/en/features/gradient_handler.md deleted file mode 100644 index 757016fcb53a5c16ac02810d88eee447fb659b2b..0000000000000000000000000000000000000000 --- a/docs/source/en/features/gradient_handler.md +++ /dev/null @@ -1,63 +0,0 @@ -# Gradient Handler - -Author: Shenggui Li, Yongbin Li - -**Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Use Engine and Trainer in Training](../basics/engine_trainer.md) - -**Example Code** -- [ColossalAI-Examples Gradient Handler](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) - -## Introduction - -In distributed training, gradient synchronization is required at the end of each iteration. This is important because we -need to make sure the parameters are updated with the same gradients in different machines so that the resulting parameters -are the same. This is often seen in data parallel as the model is replicated across data parallel ranks. - -In Colossal-AI, we provide an interface for users to customize how they want to handle the synchronization. This brings -flexibility in cases such as implementing a new parallelism method. - -When gradient handlers are used, PyTorch `DistributedDataParallel` will not be used as it will synchronize automatically. - -## Customize Your Gradient Handlers - -To implement a customized gradient handler, you need to follow these steps. -1. inherit `BaseGradientHandler` in Colossal-AI. -2. register the gradient handler into the `GRADIENT_HANDLER`. -3. implement `handle_gradient` method. - -```python -from colossalai.registry import GRADIENT_HANDLER -from colossalai.engine.gradient_handler import BaseGradientHandler - - -@GRADIENT_HANDLER.register_module -class MyGradientHandler(BaseGradientHandler): - - def handle_gradient(self): - do_something() - - -``` - - -## Usage - -To use a gradient handler, you need to specify your gradient handler in the config file. The gradient handler -will be automatically built and attached to the engine. - -```python -gradient_handler = [dict(type='MyGradientHandler')] -``` - - -### Hands-On Practice - -We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) -to demonstrate the use of gradient handler. In this example, we used `DataParallelGradientHandler` instead of PyTorch -`DistributedDataParallel` for data parallel training. - -```shell -python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py -``` diff --git a/docs/source/en/features/lazy_init.md b/docs/source/en/features/lazy_init.md new file mode 100644 index 0000000000000000000000000000000000000000..133fd799280a15d177209f53c646c63c89c3ef36 --- /dev/null +++ b/docs/source/en/features/lazy_init.md @@ -0,0 +1,76 @@ +# Lazy initialization + +Author: [Hongxiu Liu](https://github.com/ver217) + +**Prerequisite:** +- [Train with booster](../basics/booster_api.md) + +## Introduction + +Lazy initialization defers model initialization. It saves memory when initializing large models. + +If your model has `N` billion parameters and your memory (or GPU memory) is `M` GB, we recommend you use lazy initialization when `4N >= M`. Otherwise, it is optional. + +## Usage + +Lazy initialization must be used with booster. + +### API reference + +{{ autodoc:colossalai.lazy.LazyInitContext }} + +### Example + +```python +import colossalai +from colossalai.lazy import LazyInitContext +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin + +from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining + +colossalai.launch({}) +plugin = GeminiPlugin() +booster = Booster(plugin) + +# 1. Initialize model from scratch +# Initialization on cuda will accelerate the initialization process but take more GPU memory. +with LazyInitContext(default_device="cuda"): + model = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=172, num_hidden_layers=4, num_attention_heads=4)) +model, *_ = booster.boost(model) + +# 2. Initialize model from pretrained +with LazyInitContext(): + model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny") +model, *_ = booster.boost(model) +``` + +> ⚠️ Lazy initialization from pretrained is supported for colossalai>0.3.3 or main branch. + +## Limitations + +As we claimed, lazy initialization must be used with booster. And only several plugins support it. + +| Plugin | Supported | Remarks | +|-----------------|-----------|--------------| +| Gemini | Yes | | +| Hybrid Parallel | Yes | | +| Low Level Zero | No | No need | +| Torch DDP | No | Incompatible | +| Torch FSDP | No | Incompatible | + +Not all models can be lazily initialized. In some cases, a part of parameters/buffers may be early initialized. But don't worry, this part usually takes a small proportion of the whole model. + +And some models are not supported at all which will raise an error. We tested models in torchvision, diffusers, timm, transformers, torchaudio and torchrec. Below models are not supported: + +| Model | Category | +|-------------------------------|--------------| +| wav2vec2_base | torchaudio | +| hubert_base | torchaudio | +| ViTModel | transformers | +| ViTForMaskedImageModeling | transformers | +| ViTForImageClassification | transformers | +| Blip2Model | transformers | +| Blip2ForConditionalGeneration | transformers | + + diff --git a/docs/source/en/features/mixed_precision_training.md b/docs/source/en/features/mixed_precision_training.md deleted file mode 100644 index 11aa5235301a36f4fc75695e06fe52f8937386da..0000000000000000000000000000000000000000 --- a/docs/source/en/features/mixed_precision_training.md +++ /dev/null @@ -1,367 +0,0 @@ -# Auto Mixed Precision Training - -Author: Chuanrui Wang, Shenggui Li, Yongbin Li - -**Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Use Engine and Trainer in Training](../basics/engine_trainer.md) - -**Example Code** -- [ColossalAI-Examples AMP](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) - -**Related Paper** -- [Accelerating Scientific Computations with Mixed Precision Algorithms](https://arxiv.org/abs/0808.2794) - - -## Introduction - -AMP stands for automatic mixed precision training. -In Colossal-AI, we have incorporated different implementations of mixed precision training: - -1. torch.cuda.amp -2. apex.amp -3. naive amp - - -| Colossal-AI | support tensor parallel | support pipeline parallel | fp16 extent | -| ----------- | ----------------------- | ------------------------- | ----------- | -| AMP_TYPE.TORCH | ✅ | ❌ | Model parameters, activation, gradients are downcast to fp16 during forward and backward propagation | -| AMP_TYPE.APEX | ❌ | ❌ | More fine-grained, we can choose opt_level O0, O1, O2, O3 | -| AMP_TYPE.NAIVE | ✅ | ✅ | Model parameters, forward and backward operations are all downcast to fp16 | - -The first two rely on the original implementation of PyTorch (version 1.6 and above) and NVIDIA Apex. -The last method is similar to Apex O2 level. -Among these methods, apex AMP is not compatible with tensor parallelism. -This is because that tensors are split across devices in tensor parallelism, thus, it is required to communicate among different processes to check if inf or nan occurs in the whole model weights. -We modified the torch amp implementation so that it is compatible with tensor parallelism now. - -> ❌️ fp16 and zero configuration are not compatible -> -> ⚠️ Pipeline only support naive AMP currently - -We recommend you to use torch AMP as it generally gives better accuracy than naive AMP if no pipeline is used. - -## Table of Contents - -In this tutorial we will cover: - -1. AMP introduction -2. AMP in Colossal-AI -3. Hands-on Practice - -## AMP Introduction - -Automatic Mixed Precision training is a mixture of FP16 and FP32 training. - -Half-precision float point format (FP16) has lower arithmetic complexity and higher compute efficiency. -Besides, fp16 requires half of the storage needed by fp32 and saves memory & network bandwidth, which makes more memory -available for large batch size and model size. - -However, there are other operations, like reductions, which require the dynamic range of fp32 to avoid numeric overflow/underflow. That's the reason why we introduce automatic mixed precision, attempting to match each operation to its appropriate data type, which can reduce the memory footprint and augment training efficiency. - -
              - -
              Illustration of an ordinary AMP (figure from PatrickStar paper)
              -
              - -## AMP in Colossal-AI - -We supported three AMP training methods and allowed the user to train with AMP with no code. You can just simply add `fp16` -configuration in your configuration file to use AMP. - - -```python -from colossalai.amp import AMP_TYPE - -# use Torch AMP -fp16=dict( - mode = AMP_TYPE.TORCH -) - -# use naive AMP -fp16=dict( - mode = AMP_TYPE.NAIVE -) - -# use NVIDIA Apex AMP -fp16=dict( - mode = AMP_TYPE.APEX -) - -``` - -> These are the minimum configuration, full configuration are stated in the section later - -### AMP Modularity - -AMP module is designed to be completely modular and can be used independently. -If you wish to only use AMP in your code base without `colossalai.initialize`, -you can use `colossalai.amp.convert_to_amp`. - -```python -from colossalai.amp import AMP_TYPE - -# example of using torch amp -model, optimizer, criterion = colossalai.amp.convert_to_amp(model, - optimizer, - criterion, - AMP_TYPE.TORCH) -``` - -### Torch AMP Configuration - -```python -from colossalai.amp import AMP_TYPE - -fp16=dict( - mode=AMP_TYPE.TORCH, - - # below are default values for grad scaler - init_scale=2.**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True -) -``` - -With optional arguments: -- init_scale(float, optional, default=2.**16): Initial scale factor -- growth_factor(float, optional, default=2.0): Factor by which the scale is multiplied during `update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. -- backoff_factor(float, optional, default=0.5): Factor by which the scale is multiplied during `update` if inf/NaN gradients occur in an iteration. -- growth_interval(int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``. -- enabled(bool, optional, default=True): If ``False``, disables gradient scaling. `step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops. - -### Apex AMP Configuration - -For this mode, we rely on the Apex implementation for mixed precision training. -We support this plugin because it allows for finer control on the granularity of mixed precision. -For example, O2 level (optimization level 2) will keep batch normalization in fp32. - -If you look for more details, please refer to [Apex Documentation](https://nvidia.github.io/apex/). - -```python -from colossalai.amp import AMP_TYPE - -fp16 = dict( - mode=AMP_TYPE.APEX, - - # below are the default values - enabled=True, - opt_level='O1', - cast_model_type=None, - patch_torch_functions=None, - keep_batchnorm_fp32=None, - master_weights=None, - loss_scale=None, - cast_model_outputs=None, - num_losses=1, - verbosity=1, - min_loss_scale=None, - max_loss_scale=16777216.0 -) -``` - -Parameters: -- enabled(bool, optional, default=True): If False, renders all AMP calls no-ops, so your script should run as if Amp were not present. - -- opt_level(str, optional, default="O1" ): Pure or mixed precision optimization level. -Accepted values are “O0”, “O1”, “O2”, and “O3”, explained in detail above Apex AMP Documentation. - -- num_losses(int, optional, default=1): Option to tell AMP in advance how many losses/backward passes you plan to use. -When used in conjunction with the loss_id argument to `amp.scale_loss`, enables Amp to use a different loss scale per -loss/backward pass, which can improve stability. If num_losses is left to 1, Amp will still support multiple -losses/backward passes, but use a single global loss scale for all of them. - -- verbosity(int, default=1): Set to 0 to suppress Amp-related output. - -- min_loss_scale(float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. -The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored. - -- max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss -scaling. If dynamic loss scaling is not used, max_loss_scale is ignored. - -Currently, the under-the-hood properties that govern pure or mixed precision training are the following: -cast_model_type, patch_torch_functions, keep_batchnorm_fp32, master_weights, loss_scale. -They are optional properties override once opt_level is determined - -- cast_model_type: Casts your model’s parameters and buffers to the desired type. -- patch_torch_functions: Patch all Torch functions and Tensor methods to perform Tensor Core-friendly ops like GEMMs and convolutions in FP16, and any ops that benefit from FP32 precision in FP32. -- keep_batchnorm_fp32: To enhance precision and enable cudnn batchnorm (which improves performance), it’s often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16. -- master_weights: Maintain FP32 master weights to accompany any FP16 model weights. FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients. -- loss_scale: If loss_scale is a float value, use this value as the static (fixed) loss scale. If loss_scale is the string "dynamic", adaptively adjust the loss scale over time. Dynamic loss scale adjustments are performed by Amp automatically. - - -### Naive AMP Configuration - -In Naive AMP mode, we achieved mixed precision training while maintaining compatibility with complex tensor and pipeline parallelism. -This AMP mode will cast all operations into fp16. -The following code block shows the `config.py` file for this mode. - -```python -from colossalai.amp import AMP_TYPE - -fp16 = dict( - mode=AMP_TYPE.NAIVE, - - # below are the default values - log_num_zeros_in_grad=False, - initial_scale=2 ** 32, - min_scale=1, - growth_factor=2, - backoff_factor=0.5, - growth_interval=1000, - hysteresis=2 -) -``` - -The default parameters of Naive AMP: -- log_num_zeros_in_grad(bool): return number of zeros in the gradients. -- initial_scale(int): initial scale of gradient scaler -- growth_factor(int): the growth rate of loss scale -- backoff_factor(float): the decrease rate of loss scale -- hysteresis(int): delay shift in dynamic loss scaling -- max_scale(int): maximum loss scale allowed -- verbose(bool): if set to `True`, will print debug info - -When using `colossalai.initialize`, you are required to first instantiate a model, an optimizer and a criterion. -The output model is converted to AMP model of smaller memory consumption. -If your input model is already too large to fit in a GPU, please instantiate your model weights in `dtype=torch.float16`. -Otherwise, try smaller models or checkout more parallelization training techniques! - - -## Hands-on Practice - -We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) which demonstrates -the use of AMP with Colossal-AI. In this practice, we will use Torch AMP as an example, but do note that config files are provided for all AMP modes. - -### Step 1. Create a config file - -Create a `config.py` and add the `fp16` configuration. - -```python -# in config.py -from colossalai.amp import AMP_TYPE - -BATCH_SIZE = 128 -DROP_RATE = 0.1 -NUM_EPOCHS = 300 - -fp16 = dict( - mode=AMP_TYPE.TORCH, -) - -clip_grad_norm = 1.0 -``` - -### Step 2. Import libraries in train_with_engine.py - -Create a `train_with_engine.py` and import the necessary dependencies. Remember to install `scipy` and `timm` by running -`pip install timm scipy`. - -```python -import os -import colossalai -import torch -from pathlib import Path -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.utils import get_dataloader -from colossalai.trainer import Trainer, hooks -from colossalai.nn.lr_scheduler import LinearWarmupLR -from timm.models import vit_base_patch16_224 -from torchvision import datasets, transforms - -``` - -### Step 3. Initialize Distributed Environment - -We then need to initialize distributed environment. For demo purpose, we uses `launch_from_torch`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) -for other initialization methods. - -```python -# initialize distributed setting -parser = colossalai.get_default_parser() -args = parser.parse_args() - -# launch from torch -colossalai.launch_from_torch(config=args.config) - -``` - -### Step 4. Create training components - -Build your model, optimizer, loss function, lr scheduler and dataloaders. Note that the root path of the dataset is -obtained from the environment variable `DATA`. You may `export DATA=/path/to/data` or change `Path(os.environ['DATA'])` -to a path on your machine. Data will be automatically downloaded to the root path. - -```python -# build model - model = vit_base_patch16_224(drop_rate=0.1) - - # build dataloader - train_dataset = datasets.Caltech101( - root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.Resize(256), - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - Gray2RGB(), - transforms.Normalize([0.5, 0.5, 0.5], - [0.5, 0.5, 0.5]) - ])) - - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - - # build optimizer - optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=0.1) - - # build loss - criterion = torch.nn.CrossEntropyLoss() - - # lr_scheduler - lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS) -``` - -### Step 5. Inject AMP Feature - -Call `colossalai.initialize` to convert the training components to be running with FP16. - -```python -engine, train_dataloader, _, _ = colossalai.initialize( - model, optimizer, criterion, train_dataloader, - ) -``` - -### Step 6. Train with Engine - -Use engine in a normal training loops. - -```python -engine.train() -for epoch in range(gpc.config.NUM_EPOCHS): - for img, label in enumerate(train_dataloader): - img = img.cuda() - label = label.cuda() - engine.zero_grad() - output = engine(img) - loss = engine.criterion(output, label) - engine.backward(loss) - engine.step() - lr_scheduler.step() -``` - -### Step 7. Invoke Training Scripts - -Use the following command to start the training scripts. You can change `--nproc_per_node` to use a different number of GPUs. - -```python -python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py --config config/config_AMP_torch.py -``` diff --git a/docs/source/en/features/mixed_precision_training_with_booster.md b/docs/source/en/features/mixed_precision_training_with_booster.md new file mode 100644 index 0000000000000000000000000000000000000000..8e702a578ea4280a40f90e073df9d4ed96604730 --- /dev/null +++ b/docs/source/en/features/mixed_precision_training_with_booster.md @@ -0,0 +1,258 @@ +# Auto Mixed Precision Training + +Author: [Mingyan Jiang](https://github.com/jiangmingyan) + +**Prerequisite** + +- [Training Booster](../basics/booster_api.md) + +**Related Paper** + +- [Accelerating Scientific Computations with Mixed Precision Algorithms](https://arxiv.org/abs/0808.2794) + +## Introduction + +AMP stands for automatic mixed precision training. +In Colossal-AI, we have incorporated different implementations of mixed precision training: + +1. torch.cuda.amp +2. apex.amp +3. naive amp + +| Colossal-AI | support tensor parallel | support pipeline parallel | fp16 extent | +| -------------- | ----------------------- | ------------------------- | ---------------------------------------------------------------------------------------------------- | +| AMP_TYPE.TORCH | ✅ | ❌ | Model parameters, activation, gradients are downcast to fp16 during forward and backward propagation | +| AMP_TYPE.APEX | ❌ | ❌ | More fine-grained, we can choose opt_level O0, O1, O2, O3 | +| AMP_TYPE.NAIVE | ✅ | ✅ | Model parameters, forward and backward operations are all downcast to fp16 | + +The first two rely on the original implementation of PyTorch (version 1.6 and above) and NVIDIA Apex. +The last method is similar to Apex O2 level. +Among these methods, apex AMP is not compatible with tensor parallelism. +This is because that tensors are split across devices in tensor parallelism, thus, it is required to communicate among different processes to check if inf or nan occurs in the whole model weights. +We modified the torch amp implementation so that it is compatible with tensor parallelism now. + +> ❌️ fp16 and zero are not compatible +> +> ⚠️ Pipeline only support naive AMP currently + +We recommend you to use torch AMP as it generally gives better accuracy than naive AMP if no pipeline is used. + +## Table of Contents + +In this tutorial we will cover: + +1. [AMP introduction](#amp-introduction) +2. [AMP in Colossal-AI](#amp-in-colossal-ai) +3. [Hands-on Practice](#hands-on-practice) + +## AMP Introduction + +Automatic Mixed Precision training is a mixture of FP16 and FP32 training. + +Half-precision float point format (FP16) has lower arithmetic complexity and higher compute efficiency. Besides, fp16 requires half of the storage needed by fp32 and saves memory & network bandwidth, which makes more memory available for large batch size and model size. + +However, there are other operations, like reductions, which require the dynamic range of fp32 to avoid numeric overflow/underflow. That's the reason why we introduce automatic mixed precision, attempting to match each operation to its appropriate data type, which can reduce the memory footprint and augment training efficiency. + +
              + +
              Illustration of an ordinary AMP (figure from PatrickStar paper)
              +
              + +## AMP in Colossal-AI + +We supported three AMP training methods and allowed the user to train with AMP with no code. If you want to train with amp, just assign `mixed_precision` with `fp16` when you instantiate the `Booster`. Next we will support `bf16`, `fp8`. + +### Start with Booster + +instantiate `Booster` with `mixed_precision="fp16"`, then you can train with torch amp. + + + +```python +""" + Mapping: + 'fp16': torch amp + 'fp16_apex': apex amp, + 'bf16': bf16, + 'fp8': fp8, + 'fp16_naive': naive amp +""" +from colossalai import Booster +booster = Booster(mixed_precision='fp16',...) +``` + + + +or you can create a `FP16TorchMixedPrecision` object, such as: + + + +```python +from colossalai.mixed_precision import FP16TorchMixedPrecision +mixed_precision = FP16TorchMixedPrecision( + init_scale=2.**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000) +booster = Booster(mixed_precision=mixed_precision,...) +``` + + + +The same goes for other types of amps. + +### Torch AMP Configuration + +{{ autodoc:colossalai.booster.mixed_precision.FP16TorchMixedPrecision }} + +### Apex AMP Configuration + +For this mode, we rely on the Apex implementation for mixed precision training. +We support this plugin because it allows for finer control on the granularity of mixed precision. +For example, O2 level (optimization level 2) will keep batch normalization in fp32. + +If you look for more details, please refer to [Apex Documentation](https://nvidia.github.io/apex/). + +{{ autodoc:colossalai.booster.mixed_precision.FP16ApexMixedPrecision }} + +### Naive AMP Configuration + +In Naive AMP mode, we achieved mixed precision training while maintaining compatibility with complex tensor and pipeline parallelism. +This AMP mode will cast all operations into fp16. +The following code block shows the mixed precision api for this mode. + +{{ autodoc:colossalai.booster.mixed_precision.FP16NaiveMixedPrecision }} + +When using `colossalai.booster`, you are required to first instantiate a model, an optimizer and a criterion. +The output model is converted to AMP model of smaller memory consumption. +If your input model is already too large to fit in a GPU, please instantiate your model weights in `dtype=torch.float16`. +Otherwise, try smaller models or checkout more parallelization training techniques! + +## Hands-on Practice + +Now we will introduce the use of AMP with Colossal-AI. In this practice, we will use Torch AMP as an example. + +### Step 1. Import libraries in train.py + +Create a `train.py` and import the necessary dependencies. Remember to install `scipy` and `timm` by running +`pip install timm scipy`. + +```python +import os +from pathlib import Path + +import torch +from timm.models import vit_base_patch16_224 +from titans.utils import barrier_context +from torchvision import datasets, transforms + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import LinearWarmupLR +``` + +### Step 2. Initialize Distributed Environment + +We then need to initialize distributed environment. For demo purpose, we uses `launch_from_torch`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) +for other initialization methods. + +```python +# initialize distributed setting +parser = colossalai.get_default_parser() +args = parser.parse_args() + +# launch from torch +colossalai.launch_from_torch(config=dict()) + +``` + +### Step 3. Create training components + +Build your model, optimizer, loss function, lr scheduler and dataloaders. Note that the root path of the dataset is +obtained from the environment variable `DATA`. You may `export DATA=/path/to/data` or change `Path(os.environ['DATA'])` +to a path on your machine. Data will be automatically downloaded to the root path. + +```python +# define the constants +NUM_EPOCHS = 2 +BATCH_SIZE = 128 + +# build model +model = vit_base_patch16_224(drop_rate=0.1) + +# build dataloader +train_dataset = datasets.Caltech101( + root=Path(os.environ['DATA']), + download=True, + transform=transforms.Compose([ + transforms.Resize(256), + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + Gray2RGB(), + transforms.Normalize([0.5, 0.5, 0.5], + [0.5, 0.5, 0.5]) + ])) + +# build optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=0.1) + +# build loss +criterion = torch.nn.CrossEntropyLoss() + +# lr_scheduler +lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=NUM_EPOCHS) +``` + +### Step 4. Inject AMP Feature + +Create a `MixedPrecision`(if needed) and `TorchDDPPlugin` object, call `colossalai.boost` convert the training components to be running with FP16. + +```python +plugin = TorchDDPPlugin() +train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) +booster = Booster(mixed_precision='fp16', plugin=plugin) + +# if you need to customize the config, do like this +# >>> from colossalai.mixed_precision import FP16TorchMixedPrecision +# >>> mixed_precision = FP16TorchMixedPrecision( +# >>> init_scale=2.**16, +# >>> growth_factor=2.0, +# >>> backoff_factor=0.5, +# >>> growth_interval=2000) +# >>> plugin = TorchDDPPlugin() +# >>> booster = Booster(mixed_precision=mixed_precision, plugin=plugin) + +# boost model, optimizer, criterion, dataloader, lr_scheduler +model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler) +``` + +### Step 5. Train with Booster + +Use booster in a normal training loops. + +```python +model.train() +for epoch in range(NUM_EPOCHS): + for img, label in enumerate(train_dataloader): + img = img.cuda() + label = label.cuda() + optimizer.zero_grad() + output = model(img) + loss = criterion(output, label) + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() +``` + +### Step 6. Invoke Training Scripts + +Use the following command to start the training scripts. You can change `--nproc_per_node` to use a different number of GPUs. + +```shell +colossalai run --nproc_per_node 1 train.py +``` + + diff --git a/docs/source/en/features/nvme_offload.md b/docs/source/en/features/nvme_offload.md index 4374da3c9c4558f5aca8e7d15603e95ef0477573..6ed6f2dee5d67816a1991844d2a326eaa7f8dfe4 100644 --- a/docs/source/en/features/nvme_offload.md +++ b/docs/source/en/features/nvme_offload.md @@ -53,7 +53,7 @@ It's compatible with all parallel methods in ColossalAI. > ⚠ It only offloads optimizer states on CPU. This means it only affects CPU training or Zero/Gemini with offloading. -## Exampls +## Examples Let's start from two simple examples -- training GPT with different methods. These examples relies on `transformers`. @@ -78,8 +78,9 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.zero import zero_model_wrapper, zero_optim_wrapper from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin ``` Then we define a loss function: @@ -192,17 +193,23 @@ def train_gemini_cpu(nvme_offload_fraction: float = 0.0): optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction) print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B') - gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(), - placement_policy='cpu', pin_memory=True, hidden_dim=config.n_embd) - model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config) - optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**5) + plugin = GeminiPlugin( + strict_ddp_mode=True, + device=torch.cuda.current_device(), + placement_policy='cpu', + pin_memory=True, + hidden_dim=config.n_embd, + initial_scale=2**5 + ) + booster = Booster(plugin) + model, optimizer, criterion, _* = booster.boost(model, optimizer, criterion) start = time.time() for step in range(3): data = get_data(4, 128, config.vocab_size) outputs = model(**data) loss = criterion(outputs.logits, data['input_ids']) - optimizer.backward(loss) + booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() print(f'[{step}] loss: {loss.item():.3f}') diff --git a/docs/source/en/features/pipeline_parallel.md b/docs/source/en/features/pipeline_parallel.md index ac49863b3c719241f8e7cc6349c3730c85519d03..cb19f9815bf2772b5f80eff53854d9bc66d16e75 100644 --- a/docs/source/en/features/pipeline_parallel.md +++ b/docs/source/en/features/pipeline_parallel.md @@ -1,14 +1,15 @@ # Pipeline Parallel -Author: Guangyang Lu, Hongxin Liu, Yongbin Li +Author: Guangyang Lu, Hongxin Liu, Yongbin Li, Mingyan Jiang **Prerequisite** -- [Define Your Configuration](../basics/define_your_config.md) -- [Use Engine and Trainer in Training](../basics/engine_trainer.md) -- [Configure Parallelization](../basics/configure_parallelization.md) +- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) +- [Use Booster to Training](../basics/booster_api.md) +- [Shardformer](../features/shardformer.md) +- [Plugin of Booster](../basics/booster_plugins.md) **Example Code** -- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel) +- [Fine-tune Bert with pipeline](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py) **Related Paper** - [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883) @@ -17,7 +18,7 @@ Author: Guangyang Lu, Hongxin Liu, Yongbin Li ## Quick introduction -In this tutorial, you will learn how to use pipeline parallel. In Colossal-AI, we use 1F1B pipeline, introduced by Nvidia. In this case, ViT and Imagenet are too large to use. Therefore, here we use ResNet and Cifar as example. +In this tutorial, you will learn how to use pipeline parallel. In Colossal-AI, we use 1F1B pipeline, introduced by Nvidia. In this case, ViT and Imagenet are too large to use. Therefore, here we use bert model and glue dataset as example. ## Table Of Content @@ -25,7 +26,7 @@ In this tutorial we will cover: 1. Introduction of 1F1B pipeline. 2. Usage of non-interleaved and interleaved schedule. -3. Training ResNet with pipeline. +3. Finetune Bert with pipeline. ## Introduction of 1F1B pipeline @@ -60,100 +61,158 @@ In this schedule, each device can perform computation for multiple subsets of la This mode is both memory-efficient and time-efficient. -## Usage of non-interleaved and interleaved schedule +## Colossal-AI's Implementation -In Colossal-AI, we provided both non-interleaved(as `PipelineSchedule`) and interleaved schedule(as `InterleavedPipelineSchedule`). +In Colossal-AI, pipeline parallelism relies on the `scheduler` and [`Shardformer`](../features/shardformer.md). We provide both non-interleaved (`OneForwardOneBackwardSchedule`) and interleaved (`InterleavedSchedule`) schedules. While `Shardformer` implements layer splitting for models and replaces the `forward` function of the model to make it compatible with the scheduler. -You just need to set `NUM_MICRO_BATCHES` in config file and set `NUM_CHUNKS` in config file if you want to use Interleaved Pipeline Schedule. If you certainly know the shape of each pipeline stage's output tensor and the shapes are all the same, you can set `TENSOR_SHAPE` in config file to further reduce communication. Otherwise, you can just ignore `tensor_shape`, and the shape will be exchanged over pipeline stages automatically. Then we will generate an appropriate schedule for you. +In Colossal-AI, the `HybridParallelPlugin` encapsulates pipeline execution strategies. It manages pipeline parallel communication groups and a scheduler. When boosting the model with this plugin, the model's layers are split by calling the `shardformer.optimize` function, and then `execute_pipeline` is called to execute the model in segments using `OneForwardOneBackwardSchedule` which is default scheduler used in `HybridParallelPlugin`, and `InterleavedSchedule` will be integrated later. -## Training ResNet with pipeline +You can customize your parallel strategy by setting parameters for the `HybridParallelPlugin`. -Let's build the `ResNet` model first with Colossal PipelinableContext: +For more usage details, please refer to the [documentation](../basics/booster_plugins.md) for `HybridParallelPlugin`. + +## Fine-tune Bert with pipeline + +First, we define the necessary training components, including model, dataloader, optimizer, lr_scheduler, criterion: ```python -import os -from typing import Callable, List, Optional, Type, Union +import argparse +from typing import Callable, List, Union + import torch import torch.nn as nn +from data import GLUEDataBuilder +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ( + AlbertForSequenceClassification, + AutoConfig, + BertForSequenceClassification, + get_linear_schedule_with_warmup, +) + import colossalai -import colossalai.nn as col_nn +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.trainer import Trainer, hooks -from colossalai.utils import MultiTimer, get_dataloader -from colossalai.context import ParallelMode -from colossalai.pipeline.pipelinable import PipelinableContext +# Define some config +NUM_EPOCHS = 3 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 + +coordinator = DistCoordinator() + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + + +# Define 'criterion' function with two inputs, which will be passed to 'execute_pipeline'. +def _criterion(outputs, inputs): + return outputs.loss + +# Define optimizer +lr = LEARNING_RATE +no_decay = ["bias", "LayerNorm.weight"] +optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, +] -from titans.dataloader.cifar10 import build_cifar -from torchvision.models import resnet50 -from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 +optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) -# Define some config -BATCH_SIZE = 64 -NUM_EPOCHS = 2 -NUM_CHUNKS = 1 -CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2)) - -# Train -disable_existing_loggers() -parser = colossalai.get_default_parser() -args = parser.parse_args() -colossalai.launch_from_torch(backend=args.backend, config=CONFIG) -logger = get_dist_logger() -pipelinable = PipelinableContext() - -# build model -with pipelinable: - model = resnet50() -``` -Define an execution sequence. -```python -exec_seq = [ - 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', - (lambda x: torch.flatten(x, 1), "behind"), 'fc' -] -pipelinable.to_layer_list(exec_seq) +# Define lr_scheduler +total_steps = len(train_dataloader) * NUM_EPOCHS +num_warmup_steps = int(WARMUP_FRACTION * total_steps) +lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, +) + + +# Define Bert model +model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=cfg).cuda() + +# Define a dataloader +data_builder = GLUEDataBuilder(model_name, + plugin, + args.task, + train_batch_size=BATCH_SIZE, + eval_batch_size=BATCH_SIZE) +train_dataloader = data_builder.train_dataloader() ``` -Partition the model into pipeline. +Define a booster with the `HybridParallelPlugin`. ```python -model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) +plugin = HybridParallelPlugin(tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision='fp16', + initial_scale=1) +booster = Booster(plugin=plugin) ``` -In this tutorial, we use `Trainer` to train `ResNet`: +Boost these train componts with the booster created. ```python -# build criterion -criterion = nn.CrossEntropyLoss() - -# optimizer -optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - -# build dataloader -root = os.environ.get('DATA', './data') -train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32) - -lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1) -engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion, - train_dataloader, test_dataloader, - lr_scheduler) -timer = MultiTimer() +model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, + optimizer, + criterion=_criterion, + lr_scheduler=lr_scheduler) +``` -trainer = Trainer(engine=engine, timer=timer, logger=logger) +Train the model at last. -hook_list = [ - hooks.LossHook(), - hooks.AccuracyHook(col_nn.metric.Accuracy()), - hooks.LogMetricByEpochHook(logger), - hooks.LRSchedulerHook(lr_scheduler, by_epoch=True) -] - -trainer.fit(train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True) +```python +# Define a train function +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, + train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): + + is_pp_last_stage = booster.plugin.stage_manager.is_last_stage() + total_step = len(train_dataloader) + + model.train() + optimizer.zero_grad() + # convert train_dataloader to a iterator + train_dataloader_iter = iter(train_dataloader) + with tqdm(range(total_step), + desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', + disable=not (is_pp_last_stage)) as pbar: + # Forward pass + for _ in pbar: + outputs = booster.execute_pipeline(train_dataloader_iter, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + # Backward and optimize + if is_pp_last_stage: + loss = outputs['loss'] + pbar.set_postfix({'loss': loss.item()}) + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + +# Train model +for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` -We use `2` pipeline stages and the batch will be splitted into `4` micro batches. +We use `2` pipeline stages and the micro batches is 1. (these parameters can be configured to an appropriate value) + diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md new file mode 100644 index 0000000000000000000000000000000000000000..a6e32d2c05fa0994ae11735f070fa47b56624f7e --- /dev/null +++ b/docs/source/en/features/shardformer.md @@ -0,0 +1,349 @@ +# Shardformer + +Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.com/FoolPlayer) + +**Prerequisite** +- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) +- [Booster API](../basics/booster_api.md) +- [Booster Plugins](../basics/booster_plugins.md) + +**Example Code** +- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples) +- [Enabling Shardformer using HybridPrallelPlugin](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) + +**Related Paper** +- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) +- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965) +- [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691) +- [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120) +- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198) + +## Introduction + +When training large transformer models such as LLaMa-2 70B or OPT 175B, model parallelism methods that divide a huge model into smaller shards, including tensor parallelism or pipeline parallism, are essential so as to meet the limitation of GPU memory. +However, manually cutting model and rewriting its forward/backword logic could be difficult for users who are not familiar with distributed training. +Meanwhile, the Huggingface transformers library has gradually become users' first choice of model source, and most mainstream large models have been open-sourced in Huggingface transformers model library. + +Out of this motivation, the ColossalAI team develops **Shardformer**, a feature that automatically does preparation of model parallelism (tensor parallelism/pipeline parallelism) for popular transformer models in HuggingFace. +This module aims to make parallelization hassle-free for users who are not from the system background. +Within a few lines of codes, users can turn a model into a state ready for distributed training. +Also, Shardformer contains various optimization tools for acceleration and memory saving during forward/backward pass. + +## Supporting Information + +Model/Feature Compatibility Matrix: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
              Model/FeatureTensor
              Parallel
              Pipeline
              Parallel
              Lazy
              Initialization
              xFormersFlash
              Attention 2
              JIT Fused
              Operators
              Fused
              LayerNorm
              Sequence
              Parallel
              Sequence
              Overlap
              Llama V1/V2✔️✔️✔️✔️✔️✔️✔️
              OPT✔️✔️✔️✔️✔️✔️✔️
              BLOOM✔️✔️✔️✔️✔️✔️✔️✔️✔️
              ChatGLM 2✔️✔️✔️✔️✔️✔️✔️✔️✔️
              BERT✔️✔️✔️✔️✔️✔️✔️✔️✔️
              GPT 2✔️✔️✔️✔️✔️✔️✔️✔️✔️
              T5✔️✔️✔️✔️✔️✔️✔️
              ViT✔️✔️✔️✔️✔️✔️
              Whisper✔️✔️✔️✔️✔️✔️
              SAM✔️✔️✔️✔️✔️
              Blip2✔️✔️✔️✔️✔️
              + +List of model families we plan to support in the near future: +- RoBERTa +- ALBERT +- ERNIE +- GPT Neo +- GPT-J +- BEiT +- SwinTransformer V1/V2 +- qwen + +The support matrix will grow larger as more models and optimization tools emerge in the future. If you have any suggestions on the models/optimization we should support, please feel free to mention it in [Issues](https://github.com/hpcaitech/ColossalAI/issues) section of our project. + +## Usage + +### Shardformer Configuration + +The configuration of Shardformer is controlled by class `ShardConfig`: + +{{ autodoc:colossalai.shardformer.ShardConfig }} + +If you want to enable Apex Fused Layernorm, please install `apex`. +If you want to enable the usage of flash attention, please install `flash_attn`. +In addition, xFormers's `cutlass_op` can serve as a backup for flash attention. + +### Enabling Shardformer + +#### 1. Enabling Shardformer Through Booster (Recommended) + +Enabling `Shardformer` through `Booster` initialized with `HybridParallelPlugin` is the recommended way to awaken the power of Shardformer. +The main reason is that pipeline parallelism cannot successfully work without the calling of `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the capacity to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero. + +[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Move to the root directory of this example, and execute +```bash +torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert" +``` +Then you can start finetuning a bert model wrapped by `Shardformer`. The process of wrapping is operated by `HybridParallelPlugin`. + +Let's delve into the code of `finetune.py`: + +In the `main` function, the plugin is created through the following codes: +```python +... +elif args.plugin == "hybrid_parallel": + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, + ) +``` +Here you can change the configuration of plugin by setting `tp_size`, `pp_size` or `zero_stage` to other values. More details about plugin configuration can be found in [Booster Plugins Doc](../basics/booster_plugins.md). + +If pipeline parallel is not enabled, just do the training in the same way of other booster plugins(first boost with Booster, then do forward and backward through normal way). +However, if pipeline parallel is enabled, there are several usages different from other normal cases: + +1. Before doing forward or backward, the criterion function (loss function) is processed to meet the argument demand of running pipeline: + ```python + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + ``` + +2. In `train_epoch` function, dataloader is converted into `Iterator` class before running pipeline: + ```python + train_dataloader_iter = iter(train_dataloader) + ``` + +3. Do forward and backward passing through calling `Booster.execute_pipeline` method: + ```python + outputs = booster.execute_pipeline( + train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) + ``` + Backward passing has been completed by this method, so there is no need to call `loss.backward()` after executing this method. + More details about `Booster.execute_pipeline` can be found in [Booster API Doc](../basics/booster_api.md). + + +#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended) + +You can also use Shardformer through manually calling Shardformer APIs. However, this usage is not recommended since pipeline parallelism can't run without `Booster`. + +[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py) +is an example on how to trigger `Shardformer` through calling Shardformer APIs. In the `train` function of example code, the model is wrapped by `Shardformer` through the following few codes: +```python +... +if dist.get_world_size() > 1: + tp_group = dist.new_group(backend="nccl") + + # First create configuration for Shardformer + shard_config = ShardConfig( + tensor_parallel_process_group=tp_group, + enable_tensor_parallelism=True, + enable_all_optimization=True + ) + + # Then create ShardFormer object with created config + shard_former = ShardFormer(shard_config=shard_config) + + # Finally shard the model using ShardFormer.optimize method + model, _ = shard_former.optimize(model) +... +``` + +### Precautions + +1. When enabling pipeline parallel, please don't do the forward/backward pass in the conventional way (`model(input)`, `loss.backward()`), which will cause unexpected errors. Rather, please do forward/backward pass through calling `booster.execute_pipeline` method. + +2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification`, `ViTForImageClassification`, please ensure that the total number of labels should be integer multiple of tensor parallel size, otherwise Shardformer can't process the classifier layer correctly. A simple fix could be appending dummy labels in transformers config. This bug will be fixed in future version of Shardformer. + +3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through + ```python + from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig + from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + ``` + when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes. + +## How Shardformer Works + +### Main Idea + +Generally, Shardformer works through the following four kinds of *replacements*: + +1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module. +The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters. +Also, new `forward` methods will replace original ones so as to execute distributed computation, such as linear layers' split /gather operations executed under tensor parallelism. +Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module. + +2. Replacing attributes of original Huggingface Transformers layers with appropriate attributes for distributed training. +For example, when training LlaMa-2 with tensor parallel size as 2, the attribute `num_heads` of `LlamaDecoderLayer` (the number of attention heads in each layer) should be replaced with `model.config.num_attention_heads // 2`. + +3. Replacing the `forward` methods implemented by original Huggingface +Transformers libraries with our customized `forward` methods. +This replacement is essential for pipeline paralellism, where a customiozed function is needed to pass intermediate hidden states between different pipeline stages. +Also, optimization methods such as flash attention or sequence parallel can be injected into the `forward` process through our customized `forward` method. + +4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by current device (this is why it's called Shardformer). +By executing `ModelSharder.shard` method, current device will only keep the part of model parameters it's supposed to take care of. +To be specific, they should be the assigned parameter shards when using tensor parallelism, or the parameters belonging to current pipeline stage when using pipeline parallelism, or both of them. +All other parameters are released so as to liberate memory usage. +As a result, the optimizer will only compute the states corresponding to these part of parameters, causing the usage of memory to be further saved. + +All of these replacements are implemented with manually written policies and forward functions. +If you want to delve deeper into the design of Shardformer or customize your own Shardformer policies, please refer to our [Shardformer development document](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md) and [pipeline parallelism design](https://github.com/hpcaitech/ColossalAI/discussions/4050) for more details. + +### Sequence Parallelism + +Sequence parallelism is a special optimization method supported by `Shardformer`. Sequence parallelism in `Shardformer` is a little different from [this one](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel) which focuses on ring attention. In `Shardformer`, sequence parallelism is only used along with 1D tensor parallelism to further reduce memory occupation of activation tensors during computation. + +1. In normal [1D tensor parallel](https://colossalai.org/docs/features/1D_tensor_parallel), there are 2 communication operations, $g$ and $\vec{g}$, $g$ will do one time All-Reduce in backward to get all gradients from all the devices and $\vec{g}$ will do one time All-Reduce in forward to get whole outputs from all the devices. + +2. When using sequence parallelism, $\vec{g}$ needs to do All-Gather to gather the inputs along sequence dimension during forward, and Reduce-Scatter to split the gradient during backward. $\vec{g}$ needs to do Reduce-Scatter to split the output of `Row Linear` layer of tensor parallel to all devices along sequence dimension, and All-Gather to get the whole gradient during backward. + +3. NCCL's implementation of All-Reduce adopts the `Ring All-Reduce` approach, which consists of a Reduce-Scatter operation and an All-Gather operation with equal costs. Therefore, compared with sequence parallelism and tensor parallelism, it does not introduce additional communication overhead. + +4. One important thing to note is that when using sequence parallelism along with `Column Linear` module of tensor parallelism, the complete input needs to be obtained during the backward computation of gradients. During the forward pass, only the portion of the input that is split along the sequence dimension is retained, in the shape of $(batch, sequence_len/k, hidden_states)$. Therefore, an additional All-Gather operation is required to obtain the complete input for gradient computation. However, it is possible to overlap the gradient computation with the All-Gather communication operation in our implementation, which would not introduce additional communication overhead (corresponding to the `enable_sequence_overlap` parameter in `Shardformer`). + + + diff --git a/docs/source/en/features/zero_with_chunk.md b/docs/source/en/features/zero_with_chunk.md index a105831a54099456cc6489cbb9c5599b0b262933..42305182b8b8cce088003b1dbf6e730be05d08c3 100644 --- a/docs/source/en/features/zero_with_chunk.md +++ b/docs/source/en/features/zero_with_chunk.md @@ -3,7 +3,7 @@ Author: [Hongxiu Liu](https://github.com/ver217), [Jiarui Fang](https://github.com/feifeibear), [Zijian Ye](https://github.com/ZijianYY) **Prerequisite:** -- [Define Your Configuration](../basics/define_your_config.md) +- [Train with booster](../basics/booster_api.md) **Example Code** @@ -54,32 +54,38 @@ We also provide a lightweight chunk search mechanism to help users automatically We will use `GeminiDDP` to use ZeRO with chunk-based memory management. This is our new torch.Module wrapper which uses ZeRO-DP and Gemini. ZeRO is for parallelism and Gemini is for memory management. -Also Make sure that your model is initialized under the context of ColoInitContext. +Gemini allows LazyInitContext, which can save memory when initializing large models with multi-GPUs. +If your model has `N` billion parameters and your GPU memory is `M` GB, we recommend you use LazyInitContext when `4N >= M`. Otherwise, LazyInitContext is optional. + + ```python -with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): +with LazyInitContext(default_device=torch.device('cuda')): model = gpt2_medium(checkpoint=True) ``` + + +We've provided `Booster` API which is user-friendly. We recommend you use `Booster` API. But if you still want to use low level API, you can read below content of this section. -Define the model parameters as follows: +Wrap the model with `GeminiDDP`. + ```python -chunk_manager = init_chunk_manager(model=module, - init_device=device, - hidden_dim=hidden_dim, - search_range_mb=search_range_mb, - min_chunk_size_mb=min_chunk_size_mb) -gemini_manager = GeminiManager(placement_policy, chunk_manager) +model = GeminiDDP(model, hidden_dim=hidden_dim, min_chunk_size_m=min_chunk_size_m) ``` + -`hidden_dim` is the hidden dimension of DNN. Users can provide this argument to speed up searching. If users do not know this argument before training, it is ok. We will use a default value 1024. `min_chunk_size_mb` is the the minimum chunk size in MegaByte. If the aggregate size of parameters is still samller than the minimum chunk size, all parameters will be compacted into one small chunk. +`hidden_dim` is the hidden dimension of DNN. Users can provide this argument to speed up searching. If users do not know this argument before training, it is ok. We will use a default value 1024. `min_chunk_size_m` is a floating point, being the minimum chunk size divided by 2^20 (e.g., if min_chunk_size_m=2.5, then the minimum chunk size should be 2.5*(2^20)).If the aggregate size of parameters is still smaller than the minimum chunk size, all parameters will be compacted into one small chunk. Initialization of the optimizer. + ```python optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) ``` + Training + ```python optimizer.zero_grad() outputs = model(input_ids, attn_mask) @@ -87,6 +93,7 @@ loss = criterion(outputs, input_ids) optimizer.backward(loss) optimizer.step() ``` + > ⚠️ Note: Please do not use `loss.backward()`, the standard way of writing is `optimizer.backward(loss)`. ### Train GPT @@ -97,6 +104,7 @@ For simplicity, we just use randomly generated data here. First we only need to import `GPT2LMHeadModel` from `Huggingface transformers` to define our model, which does not require users to define or modify the model, so that users can use it more conveniently. +Define a GPT model: ```python class GPTLMModel(nn.Module): @@ -141,74 +149,6 @@ class GPTLMLoss(nn.Module): return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) ``` -Define tensor parallel and parameter sharding strategies for tensor parallelism: - -```python -def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): - for mn, module in model.named_modules(): - for pn, param in module.named_parameters(recurse=False): - if hasattr(param, 'visited'): - continue - param.set_dist_spec(ReplicaSpec()) - if 'mlp.c_fc' in mn: - if 'weight' in pn or 'bias' in pn: - split_param_col_tp1d(param, pg) - param.compute_spec.set_output_replicate(False) - else: - param.set_dist_spec(ReplicaSpec()) - elif 'mlp.c_proj' in mn: - if 'weight' in pn: - split_param_row_tp1d(param, pg) - else: - param.set_dist_spec(ReplicaSpec()) - elif 'wte' in mn or 'wpe' in mn: - split_param_col_tp1d(param, pg) - elif 'c_attn' in mn or 'c_proj' in mn: - split_param_col_tp1d(param, pg) - else: - param.set_dist_spec(ReplicaSpec()) - - param.visited = True -def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(-1, param, pg) -``` - -Define a model which uses Gemini + ZeRO DDP: - -```python -def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): - cai_version = colossalai.__version__ - if version.parse(cai_version) > version.parse("0.1.10"): - from colossalai.nn.parallel import GeminiDDP - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=placememt_policy, - pin_memory=True, - search_range_mb=32) - elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): - from colossalai.gemini import ChunkManager, GeminiManager - chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) - gemini_manager = GeminiManager(placememt_policy, chunk_manager) - chunk_manager = ChunkManager(chunk_size, - pg, - enable_distributed_storage=True, - init_device=GeminiManager.get_default_device(placememt_policy)) - model = ZeroDDP(model, gemini_manager) - else: - raise NotImplemented(f"CAI version {cai_version} is not supported") - return model -``` - -As we pre-train GPT in this example, we just use a simple language model loss. Write a function to get random inputs: @@ -219,9 +159,15 @@ def get_data(batch_size, seq_len, vocab_size): return input_ids, attention_mask ``` -Finally, we can define our training loop: +Finally, we define a model which uses Gemini + ZeRO DDP and define our training loop, As we pre-train GPT in this example, we just use a simple language model loss: ```python +from colossalai.nn.optimizer import HybridAdam + +from colossalai.booster import Booster +from colossalai.lazy import LazyInitContext +from colossalai.booster.plugin import GeminiPlugin + def main(): args = parse_args() BATCH_SIZE = 8 @@ -232,22 +178,19 @@ def main(): # build criterion criterion = GPTLMLoss() + optimizer = HybridAdam(model.parameters(), lr=0.001) torch.manual_seed(123) - default_pg = ProcessGroup(tp_degree=args.tp_degree) - default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None # build GPT model - with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): + with ColoInitContext(default_device=torch.device('cuda')): model = gpt2_medium(checkpoint=True) - pg = default_pg - # Tensor Parallelism (TP) - tensor_parallelize(model, pg) - # Gemini + ZeRO DP, Note it must be used after TP - model = gemini_zero_dpp(model, pg, args.placement) - # build optimizer - optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) - numel = sum([p.numel() for p in model.parameters()]) - get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) + + + # Gemini + ZeRO DP + plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5) + booster = Booster(plugin=plugin) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + torch.cuda.synchronize() model.train() for n in range(NUM_STEPS): @@ -256,10 +199,12 @@ def main(): optimizer.zero_grad() outputs = model(input_ids, attn_mask) loss = criterion(outputs, input_ids) - optimizer.backward(loss) + booster.backward(loss, optimizer) optimizer.step() torch.cuda.synchronize() ``` -> ⚠️ Note: If you want to use the Gemini module, please do not use the [Gradient Accumulation](../features/gradient_accumulation.md) we mentioned before。 +> ⚠️ Note: If you want to use the Gemini module, please do not use the [Gradient Accumulation](../features/gradient_accumulation_with_booster.md) we mentioned before。 The complete example can be found on [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). + + diff --git a/docs/source/en/get_started/installation.md b/docs/source/en/get_started/installation.md index 290879219074bfc940705bf20f8167d08c941166..6fc4ce2c922a2190787fa7d7968af5fe9c29b710 100644 --- a/docs/source/en/get_started/installation.md +++ b/docs/source/en/get_started/installation.md @@ -29,7 +29,7 @@ CUDA_EXT=1 pip install colossalai ## Download From Source -> The version of Colossal-AI will be in line with the main branch of the repository. Feel free to raise an issue if you encounter any problem. :) +> The version of Colossal-AI will be in line with the main branch of the repository. Feel free to raise an issue if you encounter any problem. ```shell git clone https://github.com/hpcaitech/ColossalAI.git @@ -39,14 +39,29 @@ cd ColossalAI pip install -r requirements/requirements.txt # install colossalai -pip install . +CUDA_EXT=1 pip install . ``` -If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer): +If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer), just don't specify the `CUDA_EXT`: ```shell -CUDA_EXT=1 pip install . +pip install . ``` +For Users with CUDA 10.2, you can still build ColossalAI from source. However, you need to manually download the cub library and copy it to the corresponding directory. + +```bash +# clone the repository +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI + +# download the cub library +wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip +unzip 1.8.0.zip +cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ + +# install +CUDA_EXT=1 pip install . +``` diff --git a/docs/source/en/get_started/run_demo.md b/docs/source/en/get_started/run_demo.md index f47bdbbd62fc13d9168212558a678cfa2f5b8127..1ce185e26db0e3210c3e3f5a51270e89d9ca0c02 100644 --- a/docs/source/en/get_started/run_demo.md +++ b/docs/source/en/get_started/run_demo.md @@ -7,19 +7,18 @@ can also run on systems with only one GPU. Quick demos showing how to use Coloss ## Single GPU Colossal-AI can be used to train deep learning models on systems with only one GPU and achieve baseline -performances. We provided an example to [train ResNet on CIFAR10 dataset](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/resnet) -with only one GPU. You can find the example in [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples). +performances. We provided an example to [train ResNet on CIFAR10 dataset](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/resnet) +with only one GPU. You can find the example in [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples). Detailed instructions can be found in its `README.md`. ## Multiple GPUs Colossal-AI can be used to train deep learning models on distributed systems with multiple GPUs and accelerate the -training process drastically by applying efficient parallelization techniques. When we have several parallelism for you -to try out. +training process drastically by applying efficient parallelization techniques. When we have several parallelism for you to try out. #### 1. data parallel -You can use the same [ResNet example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/resnet) as the +You can use the same [ResNet example](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/resnet) as the single-GPU demo above. By setting `--nproc_per_node` to be the number of GPUs you have on your machine, the example is turned into a data parallel example. @@ -27,17 +26,19 @@ is turned into a data parallel example. Hybrid parallel includes data, tensor, and pipeline parallelism. In Colossal-AI, we support different types of tensor parallelism (i.e. 1D, 2D, 2.5D and 3D). You can switch between different tensor parallelism by simply changing the configuration -in the `config.py`. You can follow the [GPT example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt). +in the `config.py`. You can follow the [GPT example](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). Detailed instructions can be found in its `README.md`. #### 3. MoE parallel -We provided [an example of WideNet](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) to demonstrate +We provided [an example of ViT-MoE](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/moe) to demonstrate MoE parallelism. WideNet uses mixture of experts (MoE) to achieve better performance. More details can be found in [Tutorial: Integrate Mixture-of-Experts Into Your Model](../advanced_tutorials/integrate_mixture_of_experts_into_your_model.md) #### 4. sequence parallel Sequence parallel is designed to tackle memory efficiency and sequence length limit problems in NLP tasks. We provided -[an example of BERT](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/bert/sequene_parallel) in -[ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples). You can follow the `README.md` to execute the code. +[an example of BERT](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/sequence_parallel) in +[ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples). You can follow the `README.md` to execute the code. + + diff --git a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md deleted file mode 100644 index 4825a6fa1d6c8b9c7c74fe8f3d769057bad60980..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md +++ /dev/null @@ -1,112 +0,0 @@ -# 添加你自己的并行模式 - -作者: Shenggui Li, Yongbin Li - -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [并行配置](../basics/configure_parallelization.md) - -## 引言 - -为了使研究人员和工程师能够以更少的努力将我们的系统扩展到其他新颖的大规模分布式训练算法,我们已经将训练生命周期中的各种组件解耦。你可以通过简单地继承基类来实现你自己的并行模式。 - -主要组件有: - -1. `ProcessGroupInitializer` -2. `GradientHandler` -3. `Schedule` - -**目前这需要对源代码进行一些改动,因此我们建议你用`-e`标志从源代码安装。`-e`标志使得安装是可编辑的,因此,你的代码变化将反映在你的Python运行时中。我们将在这方面努力,以避免在未来的版本中改变源代码。** - - -## 进程组初始化器 - -并行通常由进程组来管理,参与相同并行算法的进程被置于同一进程组。对于不同的并行算法,需要创建不同的进程组。 -Colossal-AI 为用户提供了一个全局 context,使他们能够轻松地管理进程组。如果你想添加新的进程组,你可以很容易地定义一个新的类并在你的配置文件中设置它。为了定义你自己的进程组创建方式,你可以按照下面的步骤来创建一个新的分布式初始化。 - -1. 在 `colossalai.context.parallel_mode.ParallelMode` 中添加你自己的并行模式。 - ```python - class ParallelMode(Enum): - GLOBAL = 'global' - DATA = 'data' - PIPELINE = 'pipe' - ... - - NEW_MODE = 'new_mode' # define your mode here - ``` - -2. 创建一个 `ProcessGroupInitializer`。 你可以参考 `colossalai.context.dist_group_initializer` 中给出的例子,前六个参数是固定的。 -`ParallelContext` 将为你传入这些参数。如果你需要设置其他参数,可以像下面的例子中的 `arg1, arg2` 一样,在后面添加它。 -最后,通过添加装饰器 `@DIST_GROUP_INITIALIZER.register_module` 将你的初始化程序注册到注册表。 - ```python - # sample initializer class - @DIST_GROUP_INITIALIZER.register_module - class MyParallelInitializer(ProcessGroupInitializer): - - def __init__(self, - rank: int, - world_size: int, - config: Config, - data_parallel_size: int, - pipeline_parlalel_size: int, - tensor_parallel_size: int, - arg1, - arg2): - super().__init__(rank, world_size, config) - self.arg1 = arg1 - self.arg2 = arg2 - # ... your variable init - - def init_parallel_groups(self): - # initialize your process groups - pass - - ``` - 然后,你可以将你的新初始化器插入到 `colossalai.constants.INITIALIZER_MAPPING` 当前的模式与初始化映射中。你可以修改该文件或动态插入新的键值对。 - - ```python - colossalai.constants.INITIALIZER_MAPPING['new_mode'] = 'MyParallelInitializer' - ``` - -3. 在你的配置文件中设置你的初始化器。你可以传入你的自定义参数。这允许 - `ParallelContext` 创建你的初始化器并初始化你期望的进程组。 - - ```python - parallel = dict( - pipeline=dict(size=1), - tensor=dict(size=x, mode='new_mode') # this is where you enable your new parallel mode - ) - ``` - -## 梯度 Handler - -梯度 handler 是对参数的梯度执行 all-reduce 操作的对象。由于不同的 all-reduce 策略或许在不同的并行中被执行,用户可以继承 -`colossalai.engine.gradient_handler.BaseGradientHandler` 来实现其策略。目前,Colossal-AI 使用普通的数据并行梯度 handler 在数据并行的 rank 间 all-reduce 梯度。 -如果数据并行被检测到,梯度 handler 会被自动添加进 engine。 - -你可以添加你自己的梯度 handler,如下所示: - -```python -from colossalai.registry import GRADIENT_HANDLER -from colossalai.engine import BaseGradientHandler - -@GRADIENT_HANDLER.register_module -class YourGradientHandler(BaseGradientHandler): - - def handle_gradient(self): - do_something() - -``` - -之后,你可以在配置文件中指定你要使用的梯度 handler。 - -```python -gradient_handlers = [ - dict(type='YourGradientHandler'), -] -``` - -## Schedule - -Schedule 包含了如何执行前向和后向计算。目前, Colossal-AI 提供了流水和非流水的 schedule。 -如果你想修改前向和后向计算的执行方式,你可以继承 `colossalai.engine.schedule.BaseSchedule` 并实现 `forward_back_step` 函数。 diff --git a/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md b/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md deleted file mode 100644 index 64e8d8bcd14a19c841ccb146257b2db018249f64..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md +++ /dev/null @@ -1,31 +0,0 @@ -# 定义你自己的并行模型 - -作者: Zhengda Bian, Yongbin Li - -> ⚠️ 我们正在编写此文档以使其更加详细。 我们将介绍不同并行的机制以及如何使用它们来编写模型。 - -假设您有一个具有数十亿参数的巨大 MLP 模型,其极大的隐藏层大小使其无法直接被单个 GPU 容纳。别担心,Colossal-AI 可以帮你解决这个问题。 -在 Colossal-AI 的帮助下,您可以用所熟悉的为单个 GPU 编写模型的方式编写大模型,而 Colossal-AI 会自动拆分您的模型权重,并将它们完美地分配到一组 GPU 中。我们给出一个简单的示例,展示如何在 Colossal-AI 中编写简单的 2D 并行模型。 - -## 写一个简单的2D并行模型 - -```python -from colossalai.nn import Linear2D -import torch.nn as nn - -class MLP_2D(nn.Module): - - def __init__(self): - super().__init__() - self.linear_1 = Linear2D(in_features=1024, out_features=16384) - self.linear_2 = Linear2D(in_features=16384, out_features=1024) - - def forward(self, x): - x = self.linear_1(x) - x = self.linear_2(x) - return x -``` - -## 使用预定义的模型 - -为了方便您的使用,我们在 Colossal-AI 的 Model Zoo 中提供一些流行的模型,如*BERT*, *ViT*, *MoE* 和 *GPT*,请自由地将它们定制为不同的尺寸,以满足您的特殊需求。 diff --git a/docs/source/zh-Hans/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md b/docs/source/zh-Hans/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md index 456878caa14715ac07549b4d61736d4a84e64a2c..8ed9a1e43cdd82d5dd93dab28157cc74f4f4c459 100644 --- a/docs/source/zh-Hans/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md +++ b/docs/source/zh-Hans/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md @@ -9,44 +9,24 @@ - [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) - [Go Wider Instead of Deeper](https://arxiv.org/abs/2107.11817) -(中文版教程将会在近期提供) - ## Introduction -Since the advent of Switch Transformer, the AI community has found Mixture of Experts (MoE) a useful technique to enlarge the capacity of deep learning models. - -Colossal-AI provides an early access version of parallelism specifically designed for MoE models. -The most prominent advantage of MoE in Colossal-AI is convenience. -We aim to help our users to easily combine MoE with model parallelism and data parallelism. - -However, the current implementation has two main drawbacks now. -The first drawback is its poor efficiency in large batch size and long sequence length training. -The second drawback is incompatibility with tensor parallelism. -We are working on system optimization to overcome the training efficiency problem. -The compatibility problem with tensor parallelism requires more adaptation, and we will tackle this issue in the future. - -Here, we will introduce how to use MoE with model parallelism and data parallelism. - -## Table of Content -In this tutorial we will cover: -1. Set up MoE running environment -2. Create MoE layer -3. Train your model +自从`Switch Transformer`出现以来,人工智能社区发现专家混合 (MoE) 是一种扩大深度学习模型容量的有用技术。 +Colossal-AI 提供了专为MoE模型设计的并行性的早期访问版本。Colossal-AI中MoE最突出的优势就是方便。我们的目标是帮助我们的用户轻松地将MoE与模型并行性和数据并行性结合起来。 +但是,当前的实施现在有两个主要缺点。第一个缺点是它在大批量和长序列长度训练中效率低下。第二个缺点是与张量并行性不兼容。我们正在致力于系统优化,以克服训练效率问题。与张量并行的兼容性问题需要更多的适应,我们将在未来解决这个问题。 +在这里,我们将介绍如何使用具有模型并行性和数据并行性的 MoE。 -We provided the [example code](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) for this tutorial in [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples). -This example uses [WideNet](https://arxiv.org/abs/2107.11817) as an example of MoE-based model. +## 目录 +在本教程中,我们将介绍: +1. [搭建MoE运行环境](#搭建moe运行环境) +2. [创建MoE层](#创建moe层) +3. [定义训练模型](#训练模型) +我们提供[示例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet), 详细介绍请参考 [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples). +该示例使用 [WideNet](https://arxiv.org/abs/2107.11817) 作为基于 MoE 的模型的示例. -## Set up MoE running environment -In your project folder, create a `config.py`. - -This file is to specify some features you may want to use to train your model. -In order to enable MoE, you need to add a dict called parallel and specify the value of key moe. -You can assign a value for the key size of moe, which represents the model parallel size of experts (i.e. the number of experts in one group to parallelize training). - -For example, if the size is 4, 4 processes will be assigned to 4 consecutive GPUs and these 4 processes form a moe model parallel group. -Each process on the 4 GPUs will only get a portion of experts. Increasing the model parallel size will reduce communication cost, but increase computation cost in each GPU and activation cost in memory. -The total data parallel size is auto-detected and set as the number of GPUs by default. +## 搭建MoE运行环境 +在您的项目文件夹中,创建`config.py`文件。在该文件中,您可以指定希望用于训练模型的一些功能。为了启用 MoE,您需要在`config.py`中定义`parallel`字段,并指定`moe`的值。`moe`表示一组moe并行化训练组的并行大小。例如,`moe`设置为4,则4个进程将分配给4个连续的GPU,这4个进程组成一个moe模型并行组。每个进程只会得到一部分专家。增加mo e并行的大小将降低通信成本,但会增加每个GPU的计算成本和内存中activation的存储成本。总的数据并行的大小是自动检测的,默认情况下设置为GPU的数量。 ```python MOE_MODEL_PARALLEL_SIZE = ... @@ -55,37 +35,29 @@ parallel = dict( ) ``` -If `MOE_MODEL_PARALLEL_SIZE = E` and set the number of experts as `E` where `E` is a constant number, the process flow of forward pass of a transformer encoder in a model parallel group is shown below. +如果`MOE_MODEL_PARALLEL_SIZE = E`,即设置专家的总数为`E`(`E`为一个常数)。在模型并行中,transformer编码器中前向部分的处理流程如下图所示。
              MoE Transformer, image source: GShard
              -Since all experts are allocated to all GPUs in a model parallel group and a GPU only owns a portion of experts, -original data parallel groups are no longer correct for the parameters of experts during gradient handling in backward pass anymore. -So we create a new kind of parallel group called moe data parallel group. -The difference among different kinds of parallel group, when the configuration is set as `WORLD_SIZE=4`, -`MOE_MODEL_PARALLEL_SIZE=2`, is shown here. +所有专家都分配给模型并行组中的GPU,每一个GPU只拥有一部分专家,原始数据并行组在反向传递的梯度处理期间不再适用于专家参数。所以我们创建了一个新的并行组,叫做moe数据并行组。当配置设置为`WORLD_SIZE=4`,`MOE_MODEL_PARALLEL_SIZE=2`时,两个并行组的区别如下图所示。
              -
              MoE process group
              +
              MoE并行处理
              +至于梯度处理,我们提供了`MoeGradientHandler`来all-reduce模型的每个参数。如果您使用`colossalai.initialize`函数创建您的训练引擎,MoE梯度处理程序将自动添加到您的引擎中。否则,你应该自己处理梯度。MoE运行环境的所有参数都保存在`colossalai.global_variables.moe_env`中。您可以访问您的配置参数来检查您的设置是否正确。 -As for gradient handling, we provide MoeGradientHandler to all-reduce every parameter of the model. -If you use `colossalai.initialize` function to create your training engine, the MoE gradient handler will be added to your engine automatically. -Otherwise, you should take care of gradient by yourself. -All parameters of MoE running environment are stored in colossalai.global_variables.moe_env. -You can access your configuration parameters to check whether your setup is correct. ```python from colossalai.global_variables import moe_env ``` -## Create MoE layer -You can create a MoE layer from `colossalai.nn.moe`. -But before doing that, you should set up random seeds for all processes like this. +## 创建MoE层 + +您可以从`colossalai.nn.moe`创建MoE层。但在此之前,您应该为所有进程设置随机种子。 ```python from colossalai.context.random import moe_set_seed @@ -95,10 +67,7 @@ moe_set_seed(42) model = Widenet(num_experts=4, capacity_factor=1.2) ``` -`moe_set_seed` will set different seed for different processes in a moe model parallel group. -This helps initialize parameters in experts. -Then create an instance of experts and an instance of router. -Here is the example in model zoo. +`moe_set_seed` 会为一个moe模型并行组中的不同进程设置不同的种子(这有助于在专家中初始化参数),创建一个专家实例和一个路由器实例,示例如下。 ```python from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator @@ -118,16 +87,11 @@ ffn=MoeLayer(dim_model=d_model, num_experts=num_experts, router=shared_router, experts=shared_experts) ``` -Inside the initialization of Experts, the local expert number of each GPU will be calculated automatically. You just need to specify the class of each expert and its parameters used in its initialization. As for routers, we have provided top1 router and top2 router. You can find them in colossalai.nn.layer.moe. After creating the instance of experts and router, the only thing initialized in Moelayer is gate module. More definitions of each class can be found in our API document and code. - +在Experts的初始化中,会自动计算每个GPU的本地expert数量,您只需指定每个专家的类型及其在初始化时使用的参数。此外,我们提供了`Top1Router`和`Top2Router`,您可以在`colossalai.nn.layer.moe` 找到它们。在创建experts和router的实例时,`Moelayer`只初始化了`gate`模块,类型的更多详细信息您可以参考我们的API文档和代码。 -## Train Your Model -Do not to forget to use `colossalai.initialize` function in `colosalai` to add gradient handler for the engine. -We handle the back-propagation of MoE models for you. -In `colossalai.initialize`, we will automatically create a `MoeGradientHandler` object to process gradients. -You can find more information about the handler `MoeGradientHandler` in colossal directory. +## 定义训练模型 -The loss criterion should be wrapped by `Moeloss` to add auxiliary loss of MoE. Example is like this. +使用colossalai中的`colossalai.initialize`函数为引擎添加梯度处理程序以处理 MoE模型的反向传播。在 `colossalai.initialize` 中,我们会自动创建一个`MoeGradientHandler`对象来处理梯度。您可以在colossal目录中找到有关`MoeGradientHandler`的更多信息。为了添加MoE的相关损失处理,损失函数应使用`Moeloss`封装,示例如下。 ```python criterion = MoeLoss( aux_weight=0.01, @@ -135,6 +99,6 @@ criterion = MoeLoss( label_smoothing=0.1 ) ``` +最后,您只需使用 `colossalai` 中的`trainer`或`engine`进行训练即可。 -Finally, just use trainer or engine in `colossalai` to do your training. -Otherwise, you should take care of gradient by yourself. + diff --git a/docs/source/zh-Hans/advanced_tutorials/meet_gemini.md b/docs/source/zh-Hans/advanced_tutorials/meet_gemini.md index 2bf0a9c98c3f9dc92423f1320dd3da02a61a6fc5..594823862de1377b3c28f7e92c77740c209e16df 100644 --- a/docs/source/zh-Hans/advanced_tutorials/meet_gemini.md +++ b/docs/source/zh-Hans/advanced_tutorials/meet_gemini.md @@ -8,21 +8,21 @@ ## 用法 -目前Gemini支持和ZeRO并行方式兼容,它的使用方法很简单,在训练策略的配置文件里设置zero的model_config属性tensor_placement_policy='auto' - -``` -zero = dict( - model_config=dict( - reduce_scatter_bucket_size_mb=25, - fp32_reduce_scatter=False, - gradient_predivide_factor=1.0, - tensor_placement_policy="auto", - shard_strategy=TensorShardStrategy(), - ... - ), - optimizer_config=dict( - ... - ) +目前Gemini支持和ZeRO并行方式兼容,它的使用方法很简单:使用booster将`GeminiPlugin`中的特性注入到训练组件中。更多`booster`介绍请参考[booster使用](../basics/booster_api.md)。 + +```python +from torchvision.models import resnet18 +from colossalai.booster import Booster +from colossalai.zero import ColoInitContext +from colossalai.booster.plugin import GeminiPlugin +plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5) +booster = Booster(plugin=plugin) +ctx = ColoInitContext() +with ctx: + model = resnet18() +optimizer = HybridAdam(model.parameters(), lr=1e-3) +criterion = lambda x: x.mean() +model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) ) ``` @@ -48,7 +48,7 @@ zero = dict( -ColossalAI设计了Gemini,就像双子星一样,它管理CPU和GPU二者内存空间。它可以让张量在训练过程中动态分布在CPU-GPU的存储空间内,从而让模型训练突破GPU的内存墙。内存管理器由两部分组成,分别是MemStatsCollector(MSC)和StatefuleTensorMgr(STM)。 +ColossalAI设计了Gemini,就像双子星一样,它管理CPU和GPU二者内存空间。它可以让张量在训练过程中动态分布在CPU-GPU的存储空间内,从而让模型训练突破GPU的内存墙。内存管理器由两部分组成,分别是MemStatsCollector(MSC)和StatefulTensorMgr(STM)。 我们利用了深度学习网络训练过程的迭代特性。我们将迭代分为warmup和non-warmup两个阶段,开始时的一个或若干迭代步属于预热阶段,其余的迭代步属于正式阶段。在warmup阶段我们为MSC收集信息,而在non-warmup阶段STM入去MSC收集的信息来移动tensor,以达到最小化CPU-GPU数据移动volume的目的。 @@ -75,7 +75,7 @@ STM管理所有model data tensor的信息。在模型的构造过程中,Coloss 我们在算子的开始和结束计算时,触发内存采样操作,我们称这个时间点为**采样时刻(sampling moment)**,两个采样时刻之间的时间我们称为**period**。计算过程是一个黑盒,由于可能分配临时buffer,内存使用情况很复杂。但是,我们可以较准确的获取period的系统最大内存使用。非模型数据的使用可以通过两个统计时刻之间系统最大内存使用-模型内存使用获得。 -我们如何设计采样时刻呢。我们选择preOp的model data layout adjust之前。如下图所示。我们采样获得上一个period的system memory used,和下一个period的model data memoy used。并行策略会给MSC的工作造成障碍。如图所示,比如对于ZeRO或者Tensor Parallel,由于Op计算前需要gather模型数据,会带来额外的内存需求。因此,我们要求在模型数据变化前进行采样系统内存,这样在一个period内,MSC会把preOp的模型变化内存捕捉。比如在period 2-3内,我们考虑的tensor gather和shard带来的内存变化。 +我们如何设计采样时刻呢。我们选择preOp的model data layout adjust之前。如下图所示。我们采样获得上一个period的system memory used,和下一个period的model data memory used。并行策略会给MSC的工作造成障碍。如图所示,比如对于ZeRO或者Tensor Parallel,由于Op计算前需要gather模型数据,会带来额外的内存需求。因此,我们要求在模型数据变化前进行采样系统内存,这样在一个period内,MSC会把preOp的模型变化内存捕捉。比如在period 2-3内,我们考虑的tensor gather和shard带来的内存变化。 尽管可以将采样时刻放在其他位置,比如排除gather buffer的变动新信息,但是会给造成麻烦。不同并行方式Op的实现有差异,比如对于Linear Op,Tensor Parallel中gather buffer的分配在Op中。而对于ZeRO,gather buffer的分配是在PreOp中。将放在PreOp开始时采样有利于将两种情况统一。 @@ -94,3 +94,5 @@ MSC的重要职责是在调整tensor layout位置,比如在上图S2时刻, 在non-warmup阶段,我们需要利用预热阶段采集的非模型数据内存信息,预留出下一个Period在计算设备上需要的峰值内存,这需要我们移动出一些模型张量。 为了避免频繁在CPU-GPU换入换出相同的tensor,引起类似[cache thrashing](https://en.wikipedia.org/wiki/Thrashing_(computer_science))的现象。我们利用DNN训练迭代特性,设计了OPT cache换出策略。具体来说,在warmup阶段,我们记录每个tensor被计算设备需要的采样时刻。如果我们需要驱逐一些HOLD tensor,那么我们选择在本设备上最晚被需要的tensor作为受害者。 + + diff --git a/docs/source/zh-Hans/advanced_tutorials/opt_service.md b/docs/source/zh-Hans/advanced_tutorials/opt_service.md index a213584fd41d52b6492e309b7a4cef9bd500065c..1f8324a53ecbd3bb58f3fbbd72e286d6163c3530 100644 --- a/docs/source/zh-Hans/advanced_tutorials/opt_service.md +++ b/docs/source/zh-Hans/advanced_tutorials/opt_service.md @@ -52,7 +52,7 @@ export CHECKPOINT_DIR="your_opt_checkpoint_path" # the ${CONFIG_DIR} must contain a server.sh file as the entry of service export CONFIG_DIR="config_file_path" -docker run --gpus all --rm -it -p 8020:8020 -v ${CHECKPOINT_DIR}:/model_checkpoint -v ${CONFIG_DIR}:/config --ipc=host energonai:lastest +docker run --gpus all --rm -it -p 8020:8020 -v ${CHECKPOINT_DIR}:/model_checkpoint -v ${CONFIG_DIR}:/config --ipc=host energonai:latest ``` 接下来,您就可以在您的浏览器中打开 `https://[IP-ADDRESS]:8020/docs#` 进行测试。 diff --git a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md deleted file mode 100644 index f3c6247c38e452f98ba198195d8cccde390922f4..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md +++ /dev/null @@ -1,176 +0,0 @@ -# 使用ColoTensor让串行程序像Megatron-LM一样并行 - -Author: [Haichen Huang](https://github.com/1SAA) and [Jiarui Fang](https://github.com/feifeibear) - -**Prerequisite:** -- [ColoTensor Concepts](../basics/colotensor_concept.md) - -## 介绍 - -在新版本中,我们引入了ColoTensor。ColoTensor为用户使用并行训练提供了极大的便利,使得用户可以在原本的串行代码上,通过较小的修改将训练改为并行。在本教程中,我们将说明如何修改训练模型以自动使代码采取像 Megatron-LM 一样的方式并行训练。我们以 HuggingFace 提供的 GPT-2 模型为例,并提供一种方式让你可以在单个GPU上预训练GPT-2模型。 - -Megatron-LM 提供了一个具有影响力的并行化范式,这个范式主要应用于Transformer大模型的训练。然而,为了大规模训练 Transformer 语言大模型,用户必须使用Megatron-LM提供的特殊模块来构建他们的模型。这给用户带来了一些困难的工作,例如从预先训练的模型中加载权重,或是构建自己的并行训练模型。为了减轻用户的麻烦,我们提供 ColoTensor 类,以完成自动启用张量模型并行。 - -## 定义模型和损失函数 - -首先,我们直接调用 HuggingFace 库中的 GPTModel 和 GPTLoss。 - -```python -import torch -import torch.nn as nn -from transformers import GPT2Config, GPT2LMHeadModel - -class GPTLMModel(nn.Module): - def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False): - super().__init__() - self.checkpoint = checkpoint - self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers, - n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size)) - if checkpoint: - self.model.gradient_checkpointing_enable() - - def forward(self, input_ids, attention_mask): - # Only return lm_logits - return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] - - -class GPTLMLoss(nn.Module): - def __init__(self): - super().__init__() - self.loss_fn = nn.CrossEntropyLoss() - - def forward(self, logits, labels): - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) -``` - -## 对GPT-2的简短回顾 - -现在,我们回顾一下 GPT-2 模型的结构。每个 GPT-2 模型都可以表示为一个 DAG。如下图所示,每个圆圈代表一个算子,每个方块代表一个权重。每个箭头表示输入数据的流向,而箭头旁边的符号表示输入数据的形状。 - -然后,让我们深入了解一下这个 GPT-2 模型。它由三部分组成,分别是**嵌入模块**、**转换器层**和**分类头**。 - -嵌入模块包含两个权重,符号嵌入权重和位置嵌入权重。在嵌入模块的前向操作之后,原始输入数据的所有序列中的每个单词都会被嵌入到隐藏状态。 - -
              - -
              嵌入模块
              -
              - -每个转换器层包含两个块。自注意操作在第一个块中调用,同时一个双层感知器位于第二个块中。 - -
              - -
              转换器层
              -
              - -最后,分类头只是一个不加偏差的线性模块,里面只有一个线性权重。 - -## 应用ColoTensor - -两个步骤使您的串行代码采取 Megatron-LM 张量并行风格。 -1. 在ColoInitContext的上下文中初始化模型。 -2. 为每个参数设置 ColoTensorSpec。 - -### 使用 ColoInitContext 初始化 - -我们应该在 ColoInitContext 中构建模型。在该种上下文中,任何初始化的参数都将转换为 ColoParameter 并自动移动到相应的设备上。 - -```python -from colossalai.utils.model.colo_init_context import ColoInitContext - -with ColoInitContext(device=torch.device('cpu')): - model = GPTLMModel() -``` - -### 为每个参数设置 ColoTensorSpec - -模型创建完成后,我们通过ProcessGroup建立分布式环境。这里,我们将张量并行度指定为所有GPU的数量,即数据并行度为一。 - -```python -import torch.distributed as dist -from colossalai.tensor import ProcessGroup - -pg = ProcessGroup(tp_degree=dist.get_world_size()) -``` - -现在,我们需要一些辅助函数为下一步做准备。我们定义了两个函数来切分参数。Megatron-LM张量并行需要沿参数的第一维或最后一维切分参数张量。 - -```python -from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern, ColoParameter, ProcessGroup - -def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - if param.process_group.tp_world_size() == 1: - param.set_process_group(pg) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(-1, param, pg) -``` - -然后我们使模型采用张量并行。根据 Megatron 中使用的张量并行,应该沿着张量的最后一个维度进行切片,包括符号嵌入的权重,位置嵌入的权重,自注意力块中的所有线性权重和偏差,以及每个双层感知器中的第一个线性权重和偏差。且需要沿第一个维度切分双层感知器中的第二个线性权重。 - -```python -for mn, module in model.named_modules(): - for pn, param in module.named_parameters(recurse=False): - # set process group for all parameters - param.set_process_group(pg) - - if 'mlp.c_fc' in mn: - if 'weight' in pn or 'bias' in pn: - split_param_col_tp1d(param, pg) # colmn slice - # keep the shape of the output from c_fc - param.compute_spec.set_output_replicate(False) - elif 'mlp.c_proj' in mn: - if 'weight' in pn: - split_param_row_tp1d(param, pg) # row slice - elif 'wte' in mn or 'wpe' in mn: - split_param_col_tp1d(param, pg) # colmn slice - elif 'c_attn' in mn or 'c_proj' in mn: - split_param_col_tp1d(param, pg) # colmn slice -``` - -修改后的模型如下图所示。 - -嵌入模块: - -
              - -
              修改后的嵌入模块
              -
              - -转换器层: - -
              - -
              修改后的转换器层
              -
              - -一旦用户指定了每个参数的在并行中的分布模式,ColoTensor 就能够推断出所有算子的计算模式,包括矩阵乘法、线性函数、torch.nn.functional 中的其他逐元素函数,以及其他的一些常用函数。这样,用户可以像往常一样训练他们的模型。 - -在我们最新示例中还定义了一个Gemini + ZeRO DDP 的模型从而减小开销,提升效率。这一部分的详细内容可以参考[ZeRO](../features/zero_with_chunk.md),你可以将这两部分内容结合起来看从而理解我们整个训练流程: - -```python -def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): - from colossalai.nn.parallel import GeminiDDP - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=placememt_policy, - pin_memory=True, - search_range_mb=32) - return model -``` - -## 在单个GPU上预训练GPT-2 - -我们做的上述优化让我们可以在单GPU上训练GPT-2模型,只需要将`run.sh`中设置参数`GPUNUM`=1,再运行文件时就可以在单个GPU上完成模型的训练。 - -GPT-2 示例在[Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). 获得。 diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 6c6dcf6e850db886cec080e131065af306187c13..a1d58e9fddc26e4ecf2a820e7ec521d3ff387393 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -36,14 +36,14 @@ import torch import torch.nn as nn from colossalai import nn as col_nn from colossalai.amp import AMP_TYPE -from colossalai.builder.pipeline import partition_uniform -from colossalai.context.parallel_mode import ParallelMode +from colossalai.legacy.builder.pipeline import partition_uniform +from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.engine.schedule import (InterleavedPipelineSchedule, +from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule) from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper -from colossalai.trainer import Trainer, hooks +from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.legacy.trainer import Trainer, hooks from colossalai.utils.timer import MultiTimer from model_zoo.gpt import GPTLMLoss from torch.nn import functional as F @@ -273,3 +273,4 @@ def train(): return_output_label=False, ) ``` + diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md index 495c7fa36cc1be2089f9736c700f950b0d3a5a33..5ef863dcd42315e4d2e5f78dcdb975bf561d23f9 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md @@ -32,11 +32,11 @@ import colossalai import colossalai.nn as col_nn import torch import torch.nn as nn -from colossalai.builder import build_pipeline_model -from colossalai.engine.schedule import (InterleavedPipelineSchedule, +from colossalai.legacy.builder import build_pipeline_model +from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule) from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.trainer import Trainer, hooks +from colossalai.legacy.trainer import Trainer, hooks from colossalai.utils import MultiTimer, get_dataloader from timm.models import vision_transformer as vit from torchvision import transforms @@ -48,17 +48,17 @@ from torchvision.datasets import CIFAR10 总的来说, 我们提供3种方法来建立一个流水并行的模型: -1. `colossalai.builder.build_pipeline_model_from_cfg` -2. `colossalai.builder.build_pipeline_model` +1. `colossalai.legacy.builder.build_pipeline_model_from_cfg` +2. `colossalai.legacy.builder.build_pipeline_model` 3. 自己按阶段拆分模型 当你的内存能够容纳模型时,你可以使用前两种方法来建立你的模型,否则你必须自己分割模型。前两种方法首先在 CPU 上建立整个模型,然后分割模型,最后你可以直接把模型的相应部分移到 GPU 上。 -`colossalai.builder.build_pipeline_model_from_cfg()` 接收一个模型的配置文件,它可以均匀地(按层)或平衡地(按参数大小)分割模型。 +`colossalai.legacy.builder.build_pipeline_model_from_cfg()` 接收一个模型的配置文件,它可以均匀地(按层)或平衡地(按参数大小)分割模型。 -如果你熟悉 `PyTorch`, 你可以使用 `colossalai.builder.build_pipeline_model()` 它接收一个 `torch.nn.Sequential` 模型并按层均匀分割。 +如果你熟悉 `PyTorch`, 你可以使用 `colossalai.legacy.builder.build_pipeline_model()` 它接收一个 `torch.nn.Sequential` 模型并按层均匀分割。 -在本教程中,我们将修改 [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential`,然后使用 `colossalai.builder.build_pipeline_model()` 来建立流水线模型。 +在本教程中,我们将修改 [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential`,然后使用 `colossalai.legacy.builder.build_pipeline_model()` 来建立流水线模型。 当数据是 **一个** `Tensor`, 你可以使用你的模型 `forward()` 中的位置参数来获得数据张量。对于流水线的第一阶段,`forward()` 的第一个位置参数是从数据加载器加载的数据张量。对于其他阶段,`forward()` 的第一个位置参数是上一阶段的输出张量。注意,如果该阶段不是最后一个阶段,则 `forward()` 的返回必须是一个 `Tensor`。 @@ -244,3 +244,4 @@ def train(): hooks=hook_list, display_progress=True) ``` + diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md index 6dc5eccf44218ced733de4bd5bb22fecc302c61e..f7dd8d477a661e1b824e6c7939f06724d82da610 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -73,8 +73,8 @@ from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.lr_scheduler import LinearWarmupLR -from colossalai.nn.metric import Accuracy -from colossalai.trainer import Trainer, hooks +from colossalai.legacy.nn.metric import Accuracy +from colossalai.legacy.trainer import Trainer, hooks ``` - 其他模块 @@ -150,7 +150,7 @@ Colossal-AI 提供了自己的优化器、损失函数和学习率调度器。Py optimizer = colossalai.nn.Lamb(model.parameters(), lr=1.8e-2, weight_decay=0.1) # build loss criterion = torch.nn.CrossEntropyLoss() -# lr_scheduelr +# lr_scheduler lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS) ``` @@ -256,8 +256,8 @@ SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token ### 构建流水线模型 (`/hybrid_parallel/model/vit.py`) Colossal-AI 提供了两种从现有模型构建流水线模型的方法。 -- `colossalai.builder.build_pipeline_model_from_cfg` -- `colossalai.builder.build_pipeline_model` +- `colossalai.legacy.builder.build_pipeline_model_from_cfg` +- `colossalai.legacy.builder.build_pipeline_model` 此外,您还可以使用 Colossal-AI 从头开始构建流水线模型。 ```python @@ -266,11 +266,11 @@ from typing import Callable import inspect import torch from colossalai import nn as col_nn -from colossalai.registry import LAYERS, MODELS +from colossalai.legacy.registry import LAYERS, MODELS from colossalai.logging import get_dist_logger from colossalai.core import global_context as gpc from colossalai.context import ParallelMode -from colossalai.builder.pipeline import partition_uniform +from colossalai.legacy.builder.pipeline import partition_uniform from torch import dtype, nn from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead @MODELS.register_module @@ -380,7 +380,7 @@ def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kw #### 导入模块 ```python -from colossalai.engine.schedule import (InterleavedPipelineSchedule, +from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule) from colossalai.utils import MultiTimer import os @@ -477,7 +477,7 @@ def build_cifar(batch_size): return train_dataloader, test_dataloader -# craete dataloaders +# create dataloaders train_dataloader , test_dataloader = build_cifar() # create loss function criterion = CrossEntropyLoss(label_smoothing=0.1) @@ -492,7 +492,7 @@ lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, #### 启动 Colossal-AI 引擎 ```python -# intiailize +# initialize engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, optimizer=optimizer, criterion=criterion, @@ -589,3 +589,4 @@ torchrun --standalone --nproc_per_node train_hybrid.py --config ./co # If your torch >= 1.9.0 # python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_hybrid_parallel.py ``` + diff --git a/docs/source/zh-Hans/basics/booster_api.md b/docs/source/zh-Hans/basics/booster_api.md new file mode 100644 index 0000000000000000000000000000000000000000..f9310374d823d76decee10aba62293adbd0f6576 --- /dev/null +++ b/docs/source/zh-Hans/basics/booster_api.md @@ -0,0 +1,97 @@ +# Booster API + +作者: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https://github.com/CjhHa1), [Baizhou Zhang](https://github.com/Fridge003) + +**预备知识:** + +- [分布式训练](../concepts/distributed_training.md) +- [Colossal-AI 总览](../concepts/colossalai_overview.md) + +**示例代码** + + + +- [使用Booster在CIFAR-10数据集上训练ResNet](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet) +- [使用Booster在RedPajama数据集上训练Llama-1/2](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) + +## 简介 + +在我们的新设计中, `colossalai.booster` 代替 `colossalai.initialize` 将特征(例如,模型、优化器、数据加载器)无缝注入到您的训练组件中。 使用 booster API, 您可以更友好地将我们的并行策略整合到待训练模型中. 调用 `colossalai.booster` 是您进入训练流程前的正常操作。 +在下面的章节中,我们将介绍 `colossalai.booster` 是如何工作的以及使用时我们要注意的细节。 + +### Booster 插件 + +Booster 插件是管理并行配置的重要组件(eg:gemini 插件封装了 gemini 加速方案)。目前支持的插件如下: + +**_HybridParallelPlugin:_** HybirdParallelPlugin 插件封装了混合并行的加速解决方案。它提供的接口可以在张量并行,流水线并行以及两种数据并行方法(DDP, Zero)间进行任意的组合。 + +**_GeminiPlugin:_** GeminiPlugin 插件封装了 gemini 加速解决方案,即基于块内存管理的 ZeRO 优化方案。 + +**_TorchDDPPlugin:_** TorchDDPPlugin 插件封装了Pytorch的DDP加速方案,实现了模型级别的数据并行,可以跨多机运行。 + +**_LowLevelZeroPlugin:_** LowLevelZeroPlugin 插件封装了零冗余优化器的 1/2 阶段。阶段 1:切分优化器参数,分发到各并发进程或并发 GPU 上。阶段 2:切分优化器参数及梯度,分发到各并发进程或并发 GPU 上。 + +**_TorchFSDPPlugin:_** TorchFSDPPlugin封装了 Pytorch的FSDP加速方案,可以用于零冗余优化器数据并行(ZeroDP)的训练。 + +若想了解更多关于插件的用法细节,请参考[Booster 插件](./booster_plugins.md)章节。 + +有一些插件支持懒惰初始化,它能节省初始化大模型时的内存占用。详情请参考[懒惰初始化](../features/lazy_init.md)。 + +### Booster 接口 + + + +{{ autodoc:colossalai.booster.Booster }} + +## 使用方法及示例 + +在使用 colossalai 训练时,首先需要在训练脚本的开头启动分布式环境,并创建需要使用的模型、优化器、损失函数、数据加载器等对象。之后,调用`booster.boost` 将特征注入到这些对象中,您就可以使用我们的 booster API 去进行您接下来的训练流程。 + +以下是一个伪代码示例,将展示如何使用我们的 booster API 进行模型训练: + +```python +import torch +from torch.optim import SGD +from torchvision.models import resnet18 + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin + +def train(): + # launch colossalai + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + + # create plugin and objects for training + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + model = resnet18() + criterion = lambda x: x.mean() + optimizer = SGD((model.parameters()), lr=0.001) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) + + # use booster.boost to wrap the training objects + model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler) + + # do training as normal, except that the backward should be called by booster + x = torch.randn(4, 3, 224, 224) + x = x.to('cuda') + output = model(x) + loss = criterion(output) + booster.backward(loss, optimizer) + optimizer.clip_grad_by_norm(1.0) + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + # checkpointing using booster api + save_path = "./model" + booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True) + + new_model = resnet18() + booster.load_model(new_model, save_path) +``` + +更多的Booster设计细节请参考这一[页面](https://github.com/hpcaitech/ColossalAI/discussions/3046) + + diff --git a/docs/source/zh-Hans/basics/booster_checkpoint.md b/docs/source/zh-Hans/basics/booster_checkpoint.md new file mode 100644 index 0000000000000000000000000000000000000000..1ff2e330521c5b5ee41b7c5f9df98d20c4ef44d1 --- /dev/null +++ b/docs/source/zh-Hans/basics/booster_checkpoint.md @@ -0,0 +1,71 @@ +# Booster Checkpoint + +作者: [Hongxin Liu](https://github.com/ver217) + +**前置教程:** +- [Booster API](./booster_api.md) + +## 引言 + +我们在之前的教程中介绍了 [Booster API](./booster_api.md)。在本教程中,我们将介绍如何使用 booster 保存和加载 checkpoint。 + +## 模型 Checkpoint + +{{ autodoc:colossalai.booster.Booster.save_model }} + +模型在保存前必须被 `colossalai.booster.Booster` 封装。 `checkpoint` 是要保存的 checkpoint 的路径。 如果 `shard=False`,它就是文件。 否则, 它就是文件夹。如果 `shard=True`,checkpoint 将以分片方式保存,在 checkpoint 太大而无法保存在单个文件中时会很实用。我们的分片 checkpoint 格式与 [huggingface/transformers](https://github.com/huggingface/transformers) 兼容,所以用户可以使用huggingface的`from_pretrained`方法从分片checkpoint加载模型。 + +{{ autodoc:colossalai.booster.Booster.load_model }} + +模型在加载前必须被 `colossalai.booster.Booster` 封装。它会自动检测 checkpoint 格式,并以相应的方式加载。 + +如果您想从Huggingface加载预训练好的模型,但模型太大以至于无法在单个设备上通过“from_pretrained”直接加载,推荐的方法是将预训练的模型权重下载到本地,并在封装模型后使用`booster.load`直接从本地路径加载。为了避免内存不足,模型需要在`Lazy Initialization`的环境下初始化。以下是示例伪代码: +```python +from colossalai.lazy import LazyInitContext +from huggingface_hub import snapshot_download +... + +# Initialize model under lazy init context +init_ctx = LazyInitContext(default_device=get_current_device) +with init_ctx: + model = LlamaForCausalLM(config) + +... + +# Wrap the model through Booster.boost +model, optimizer, _, _, _ = booster.boost(model, optimizer) + +# download huggingface pretrained model to local directory. +model_dir = snapshot_download(repo_id="lysandre/arxiv-nlp") + +# load model using booster.load +booster.load(model, model_dir) +... +``` + +## 优化器 Checkpoint + + +{{ autodoc:colossalai.booster.Booster.save_optimizer }} + +优化器在保存前必须被 `colossalai.booster.Booster` 封装。 + +{{ autodoc:colossalai.booster.Booster.load_optimizer }} + +优化器在加载前必须被 `colossalai.booster.Booster` 封装。 + +## 学习率调度器 Checkpoint + +{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }} + +学习率调度器在保存前必须被 `colossalai.booster.Booster` 封装。 `checkpoint` 是 checkpoint 文件的本地路径. + +{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }} + +学习率调度器在加载前必须被 `colossalai.booster.Booster` 封装。 `checkpoint` 是 checkpoint 文件的本地路径. + +## Checkpoint 设计 + +有关 Checkpoint 设计的更多详细信息,请参见我们的讨论 [A Unified Checkpoint System Design](https://github.com/hpcaitech/ColossalAI/discussions/3339). + + diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md new file mode 100644 index 0000000000000000000000000000000000000000..70352a7b9af398d100ee325073784cd81560e8ee --- /dev/null +++ b/docs/source/zh-Hans/basics/booster_plugins.md @@ -0,0 +1,88 @@ +# Booster 插件 + +作者: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003), [Pengtai Xu](https://github.com/ppt0011) + + +**前置教程:** +- [Booster API](./booster_api.md) + +## 引言 + +正如 [Booster API](./booster_api.md) 中提到的,我们可以使用 booster 插件来自定义并行训练。在本教程中,我们将介绍如何使用 booster 插件。 + +我们现在提供以下插件: + +- [Torch DDP 插件](#torch-ddp-插件): 它包装了 `torch.nn.parallel.DistributedDataParallel` 并且可用于使用数据并行训练模型。 +- [Torch FSDP 插件](#torch-fsdp-插件): 它包装了 `torch.distributed.fsdp.FullyShardedDataParallel` 并且可用于使用 Zero-dp 训练模型。 +- [Low Level Zero 插件](#low-level-zero-插件): 它包装了 `colossalai.zero.low_level.LowLevelZeroOptimizer`,可用于使用 Zero-dp 训练模型。它仅支持 Zero 阶段1和阶段2。 +- [Gemini 插件](#gemini-插件): 它包装了 [Gemini](../features/zero_with_chunk.md),Gemini 实现了基于Chunk内存管理和异构内存管理的 Zero-3。 +- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 它为Shardformer,流水线管理器,混合精度运算,TorchDDP以及Zero-1/Zero-2功能提供了一个统一且简洁的接口。使用该插件可以简单高效地实现transformer模型在张量并行,流水线并行以及数据并行(DDP, Zero)间任意组合并行训练策略,同时支持多种训练速度和内存的优化工具。有关这些训练策略和优化工具的具体信息将在下一章中阐述。 + +更多插件即将推出。 + +## 插件选择 +- [Torch DDP 插件](#torch-ddp-插件): 适用于参数少于 20 亿的模型(例如 Bert-3m、GPT2-1.5b)。 +- [Torch FSDP 插件](#torch-fsdp-插件) / [Low Level Zero 插件](#low-level-zero-插件): 适用于参数少于 100 亿的模型(例如 GPTJ-6b、MegatronLM-8b)。 +- [Gemini 插件](#gemini-插件): 适合参数超过 100 亿的模型(例如 TuringNLG-17b),且**跨节点带宽高、中小规模集群(千卡以下)**的场景(例如 Llama2-70b)。 +- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 适合参数超过 600 亿的模型、超长序列、超大词表等特殊模型,且**跨节点带宽低、大规模集群(千卡以上)**的场景(例如 GPT3-175b、Bloom-176b)。 + +## 插件 + +### Low Level Zero 插件 + +该插件实现了 Zero-1 和 Zero-2(使用/不使用 CPU 卸载),使用`reduce`和`gather`来同步梯度和权重。 + +Zero-1 可以看作是 Torch DDP 更好的替代品,内存效率更高,速度更快。它可以很容易地用于混合并行。 + +Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累梯度,但不能降低通信成本。也就是说,同时使用流水线并行和 Zero-2 并不是一个好主意。 + +{{ autodoc:colossalai.booster.plugin.LowLevelZeroPlugin }} + +我们已经测试了一些主流模型的兼容性,可能不支持以下模型: + +- `timm.models.convit_base` +- dlrm and deepfm models in `torchrec` + +兼容性问题将在未来修复。 + +### Gemini 插件 + +这个插件实现了基于Chunk内存管理和异构内存管理的 Zero-3。它可以训练大型模型而不会损失太多速度。它也不支持局部梯度累积。更多详细信息,请参阅 [Gemini 文档](../features/zero_with_chunk.md). + +{{ autodoc:colossalai.booster.plugin.GeminiPlugin }} + +### Hybrid Parallel 插件 + +这个插件实现了多种并行训练策略和优化工具的组合。Hybrid Parallel插件支持的功能大致可以被分为以下四个部分: + +1. Shardformer: Shardformer负责在张量并行以及流水线并行下切分模型的逻辑,以及前向/后向方法的重载,这个插件为Shardformer功能提供了一个简单易用的接口。与此同时,Shardformer还负责将包括fused normalization, flash attention (xformers), JIT和序列并行在内的各类优化工具融入重载后的前向/后向方法。更多关于Shardformer的信息请参考 [Shardformer文档](../features/shardformer.md)。 + +2. 混合精度训练:插件支持fp16/bf16的混合精度训练。更多关于混合精度训练的参数配置的详细信息请参考 [混合精度训练文档](../features/mixed_precision_training_with_booster.md)。 + +3. Torch DDP: 当流水线并行和Zero不被使用的时候,插件会自动采用Pytorch DDP作为数据并行的策略。更多关于Torch DDP的参数配置的详细信息请参考 [Pytorch DDP 文档](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel)。 + +4. Zero: 在初始化插件的时候,可以通过将`zero_stage`参数设置为1或2来让插件采用Zero 1/2作为数据并行的策略。Zero 1可以和流水线并行策略同时使用, 而Zero 2则不可以和流水线并行策略同时使用。更多关于Zero的参数配置的详细信息请参考 [Low Level Zero 插件](#low-level-zero-插件). + +> ⚠ 在使用该插件的时候, 只有支持Shardformer的部分Huggingface transformers模型才能够使用张量并行、流水线并行以及优化工具。Llama 1、Llama 2、OPT、Bloom、Bert以及GPT2等主流transformers模型均已支持Shardformer。 + +{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} + +### Torch DDP 插件 + +更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel). + +{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }} + +### Torch FSDP 插件 + +> ⚠ 如果 torch 版本低于 1.12.0,此插件将不可用。 + +> ⚠ 该插件现在还不支持保存/加载分片的模型 checkpoint。 + +> ⚠ 该插件现在还不支持使用了multi params group的optimizer。 + +更多详细信息,请参阅 [Pytorch 文档](https://pytorch.org/docs/main/fsdp.html). + +{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }} + + diff --git a/docs/source/zh-Hans/basics/colotensor_concept.md b/docs/source/zh-Hans/basics/colotensor_concept.md deleted file mode 100644 index d6a332df2e9c4649e8184769e1b5650cf4c3823a..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/basics/colotensor_concept.md +++ /dev/null @@ -1,97 +0,0 @@ -# ColoTensor Concepts - -Author: [Jiarui Fang](https://github.com/feifeibear), [Hongxin Liu](https://github.com/ver217) and [Haichen Huang](https://github.com/1SAA) - -**Prerequisite:** -- [Colossal-AI Overview](../concepts/colossalai_overview.md) -- [Distributed Training](../concepts/distributed_training.md) -- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md) - -## Introduction - -在ColossalAI 0.1.8 版本之后,[ColoTensor](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ColoTensor) 成为 ColossalAI 中张量的基本数据结构。 它是 torch.Tensor 的子类,可以当做 PyTorch Tensor使用。 此外,一些独特的功能使其能够表示一个payload分布在多个 GPU 设备上的Global Tensor,并提供一些列方式操作这个Global Tensor。 在 ColoTensor 的帮助下,用户可以以类似编写串行程序方式,编写的分布式 DNN 训练程序。 - -ColoTensor 包含额外的属性[ColoTensorSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.tensor_spec.html#colossalai.tensor.tensor_spec.ColoTensorSpec) -来描述张量的payload分布和计算模式。 - -- ProcessGroup:如何将进程组织为通信组。 -- Distributed Spec:张量如何在进程组之间分布。 -- Compute Spec:计算过程中如何使用张量。 - -我们一一详述。 - -## ProcessGroup - -[ProcessGroup](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ProcessGroup) 类的一个实例描述了如何在进程组中组织进程。进程组内的进程可以一起参与同一个集合通信,比如allgather, allreduce等。进程组组织方式被张量的并行策略支配。比如,如果用户定义了Tensor的张量并行(TP),数据并行(DP)方式,那么进程组的进程组织方式将被自动推导出来。 进程组设置可能因不同的张量而异。 因此,它使我们能够支持更复杂的混合并行。流水线并行(PP)定义不在ProcessGroup中描述,它需要另一套机制,我们将在未来补充ColoTensor应用于PP的相关内容。 - -目前,ColoTensor 的一个进程组由 tp_degree 和 dp_degree 两种配置定义。 在 DP+TP 混合并行的情况下,可以将设备视为 2D 网格。 我们将 TP 通信组放置在设备网格的前导低维上,然后将数据并行组放置在设备网格的高维上。 原因是张量并行比数据并行具有更大的通信开销。 相邻设备放置在一个 TP 进程组内,并且通常放置在同一个节点中。 - -考虑到8个进程配置为tp_degree=4,dp_degree=2,布局如下图。 进程组 tp0 包含 gpu 0,1,2,3。 进程 dp1 包含 gpu 1 和 5。 - -
              - -
              Process Group using tp_degree=4, dp_degree=2
              -
              - -## Distributed Spec - -[Distributed Spec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html)描述了 ColoTensor 如何在 ProcessGroup 中分布。 - -张量在 DP 进程组之间的分布方式是自动导出的,不需要用户手动指定。 如果这个张量是一个模型参数,它会在 DP 进程组中被复制。 如果是activation张量,则沿tensor最高维度在DP进程组中进行平均分割。 - -因此,在使用 Distributed Spec 时,我们只需要描述张量在 TP 进程组之间的分布方式即可。 TP 进程组目前有两种分布式规范,即 [ShardSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ShardSpec)和[ReplicaSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ReplicaSpec)。 ShardSpec 需要指定分区的维度索引 dim 和分区个数 num_partitions。 目前,我们仅支持在单个dim上进行拆分。 TP进程组上不同的dist spec可以通过set_dist_spec()接口相互转换。这些转化操作可以被记录在PyTorch的自动求导机制中,并在反向传播时候触发对应的反向操作。 - -## Compute Spec - -[ComputeSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.compute_spec.html#colossalai.tensor.compute_spec.ComputeSpec)类描述Tensor如何参与计算。目前,我们将作为module parameter的ColoTensor设置正确的Compute Pattern。可以触发正取的计算模式。具体应用方式我们会在接下来的文档中展示。 - -## ColoParameter - -[ColoParameter](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.colo_parameter.html#colossalai.tensor.colo_parameter.ColoParameter)是ColoTensor的子类。用来声明Parameter。他和ColoTensor关系和Torch.Tensor和torch.Parameter一致。后者可以让tensor出现在module的parameters()和name_parameters() 的返回值中。 - -## Example - -让我们看一个例子。 使用 tp_degree=4, dp_dgree=2 在 8 个 GPU 上初始化并Shard一个ColoTensor。 然后tensor被沿着 TP 进程组中的最后一个维度进行分片。 最后,我们沿着 TP 进程组中的第一个维度(dim 0)对其进行重新Shard。 我们鼓励用户运行代码并观察每个张量的形状。 - - -```python -import torch -import torch.multiprocessing as mp -from colossalai.utils import print_rank_0 -from functools import partial - -import colossalai -from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern -from colossalai.testing import spawn - -import torch - -def run_dist_tests(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(tp_degree=2, dp_degree=2) - - torch.manual_seed(0) - local_tensor = torch.randn(2, 3, 1).cuda() - print_rank_0(f"shape {local_tensor.shape}, {local_tensor.data}") - - spec = ColoTensorSpec(pg, ShardSpec(dims=[-1], num_partitions=[pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - t1 = ColoTensor.from_torch_tensor(local_tensor, spec) - t1 = t1.to_replicate() - print_rank_0(f"shape {t1.shape}, {t1.data}") - - spec2 = ShardSpec([0], [pg.tp_world_size()]) - t1.set_dist_spec(spec2) - print_rank_0(f"shape {t1.shape}, {t1.data}") - -def test_dist_cases(world_size): - spawn(run_dist_tests, world_size) - -if __name__ == '__main__': - test_dist_cases(4) -``` - -:::caution - -The ColoTensor is an experimental feature and may be updated. - -::: diff --git a/docs/source/zh-Hans/basics/command_line_tool.md b/docs/source/zh-Hans/basics/command_line_tool.md index 9b0275a6ceddb3ad6da16b068bb8dcf71eea8907..5c4c18989c179f860916d008b5cc14e0b7d2687e 100644 --- a/docs/source/zh-Hans/basics/command_line_tool.md +++ b/docs/source/zh-Hans/basics/command_line_tool.md @@ -26,22 +26,4 @@ Colossal-AI给用户提供了命令行工具,目前命令行工具可以用来 在分布式训练时,我们可以使用`colossalai run`来启动单节点或者多节点的多进程,详细的内容可以参考[启动 Colossal-AI](./launch_colossalai.md)。 -## 张量并行基准测试 - -Colossal-AI提供了多种张量并行,想要充分理解这些方法需要一定的学习成本,对于新手来说很难靠经验选择一个并行方式。 -所以我们提供了一个简单的基准测试,能够让用户在自己的机器上测试不同张量并行的性能。这个基准测试跑一个并行的MLP模型, -输入数据的维度为`(批大小,序列长度,隐藏层维度)`。通过指定GPU的数量,Colossal-AI会搜索所有可行的并行配置。用户可以通过查看`colossalai benchmark --help`来自定义相关的测试参数。 - -```shell -# 使用4个GPU -colossalai benchmark --gpus 4 - -# 使用8个GPU -colossalai benchmark --gpus 8 -``` - -:::caution - -目前仅支持单节点的基准测试。 - -::: + diff --git a/docs/source/zh-Hans/basics/configure_parallelization.md b/docs/source/zh-Hans/basics/configure_parallelization.md deleted file mode 100644 index eb4b38f48ddb3653735f3dfd49fa263fcdbc9d07..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/basics/configure_parallelization.md +++ /dev/null @@ -1,136 +0,0 @@ -# 并行配置 - -作者: Shenggui Li, Siqi Mai - -**预备知识:** -- [分布式训练](../concepts/distributed_training.md) -- [并行技术](../concepts/paradigms_of_parallelism.md) -- [构建配置文件](./define_your_config.md) - - -## 简介 - -我们在 Colossal-AI 中支持多种并行技术。代码库中的混合并行是指您可以轻松地结合数据并行、流水线并行和张量并行(1D、2D、2.5D、3D)的优势共同来进行并行训练。 - -每种并行方式需要不同的网络拓扑结构,因此要初始化不同的进程组。您可以通过在配置文件中设置 `parallel` 来初始化相应的进程组。 `parallel` 的配置必须遵从以下格式。数据并行度的大小将被根据您对流水线并行和张量并行的输入自动推断。`colossalai.launch` 将根据您的配置自动初始化这些分布式进程组。 - -我们为您提供了一些配置的例子以供参考。 - -```python -# sampler format -parallel = dict( - pipeline=dict("size": int), - tensor=dict("size": int, "mode": '1d' or '2d' or '2.5d' or '3d', "kwargs": Any) -) - -# this is ok -parallel = dict( - pipeline=dict(size=2), - tensor=dict(size=4, mode='2d') -) - -# this is ok -parallel = dict( - pipeline=2, - tensor=dict(size=4, mode='2d') -) - -# this is not ok -# as you need to specify the mode for tensor parallelism -parallel = dict( - pipeline=2, - tensor=4 -) - -# this is ok as well as tensor will be default to size 1 -# and mode None -parallel = dict( - pipeline=2 -) - -# this is ok as well as pipeline will default to size 1 -parallel = dict( - tensor=dict(size=4, mode='2d') -) - -``` - -关键字 `size` 指的是并行维度的并行大小。 例如,流水线大小为2意味着有 -将有2个流水线阶段。张量并行配置中的关键字 `mode` 意味着相应的张量并行技术 -将被初始化,如1D、2D、2.5D、3D。 - -**您也可以选择不在您的配置中使用 "并行",此时流水线和张量的并行度都将默认为大小1。** - -**GPU的总数量必须等于` 数据并行大小 x 张量并行大小 x 流水线并行大小` 。** - -## 数据并行 - -数据并行是最常见的分布式训练方式。它将数据分割成几个碎片分别在每个设备上进行训练。数据并行的配置会自动检测并为您设置。您不需要在您的配置中明确地设置它们。在Colossal-AI 中,有两种方法来处理数据并行的 all-reduce。 - -1. 如果您设置了梯度handler,梯度handler将会all-reduce梯度。 -2. 若没有指定相应的配置,Colossal-AI 将会使用 PyTorch 的 DistributedDataParallel。 - -在大多数情况下,若您对梯度没有复杂的处理的需求,您将会使用第二种模式。 - -## 1D, 2D, 2.5D 和 3D 并行 - -为了实现混合并行,我们提供了一系列张量并行方法。您可以阅读相应的学术论文进行深入的了解。这些并行模式需要和 Colossal-AI 提供的分布式层一同工作。 - -- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) - -- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343) - 2D 并行基于 SUMMA 矩阵乘法,它将输入数据、模型权重和层输出切分成两个不同的维度。 这些张量块分布在 `P = N^2` 设备的二维网格上,其中 `N` 是单一维度上张量块的数量。 - -- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500) - 在 2.5D 矩阵乘法的启发下,2.5D 并行引入了一种新的张量并行,进一步将2D张量并行化。其中,`P = N^2 ∗ d` 个处理器被分配到 `d` 层, 每层独立进行矩阵乘法运算,维度为 `N`。 - -- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450) - 我们还介绍了一种 3D 张量并行方法,在三维处理器立方体上并行化神经网络。这种方法在数量为 `P` 的处理器上实现了最佳的 `O(P^{1/3})` 通信开销,而计算和内存的使用都是通过优化的参数和激活的负载平衡来实现的。同时,通过优化参数和 activations 的负载平衡,计算和内存的使用都是均匀分布的。 - -```python -# 1D parallel -parallel = dict( - tensor=dict(size=4, mode='1d') -) - -# 2D parallel -parallel = dict( - tensor=dict(size=4, mode='2d') -) - -# 2.5D parallel -parallel = dict( - tensor=dict(size=8, mode='2.5d', depth=2) -) - -# 3D parallel -parallel = dict( - tensor=dict(size=8, mode='3d') -) -``` - -当您在配置中指定了张量并行模式,您就可以使用其相应的分布式算子。例如,若您设置模式为 `2d`,那么在模型构建中就能使用 `colossalai.nn.Linear2D` 了。 - - -## 流水线并行 - -流水线并行是将模型按层分成几个部分。例如,假设我们有一个简单的模型,它由两个线性层组成。我们有两个 GPU,我们可以将第一个线性层分配给第一个 GPU 而第二层则分配给第二个 GPU。 - -您可以在您的配置文件中设置流水线并行度的大小。当流水线并行度大于1,Colossal-AI 将会自动地创建流水线并行的 schedule,这将会为您定义好模型训练的 `forward` 和 `backward`。 - -```python -parallel = dict( - pipeline=dict(size=4), # number of pipeline stages -) -``` - -## 序列并行 - -针对处理大图片、视频、长文本、长时间医疗监控等数据的需要,Colossal-AI 还提供了序列并行的方法。该方法是在论文[Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120)中提出的。您可以指定模式为 `sequence` 来初始化进程组。 - - -```python -parallel = dict( - tensor=dict(size=4, mode='sequence') -) -``` diff --git a/docs/source/zh-Hans/basics/define_your_config.md b/docs/source/zh-Hans/basics/define_your_config.md deleted file mode 100644 index d7e49cbf23dee33fcf497feb3754e507d05c2442..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/basics/define_your_config.md +++ /dev/null @@ -1,71 +0,0 @@ -# 构建配置文件 - -作者: Guangyang Lu, Shenggui Li, Siqi Mai - -**预备知识:** -- [分布式训练](../concepts/distributed_training.md) -- [Colossal-AI 总览](../concepts/colossalai_overview.md) - - -## 简介 - -在 Colossal-AI 中,我们需要一个配置文件来指定系统在训练过程中要注入的特征。在本教程中,我们将向您介绍如何构建您的配置文件以及如何使用这个配置文件。使用配置文件有以下一些好处: - -1. 您可以在不同的配置文件中存储您的特征配置和训练超参数。 -2. 对于我们未来发布的新功能,您亦可以在配置中指定,而无需改变训练脚本的代码。 - -在本教程中,我们将向您介绍如何构建您的配置文件。 - -## 配置定义 - -在一个配置文件中,有两种类型的变量。一种是作为特征说明,另一种是作为超参数。所有与特征相关的变量都是保留关键字。例如,如果您想使用混合精度训练,需要在 config 文件中使用变量名`fp16`,并遵循预先定义的格式。 - -### 功能配置 - -Colossal-AI 提供了一系列的功能来加快训练速度。每个功能都是由配置文件中的相应字段定义的。在本教程中,我们不会给出所有功能的配置细节,而是提供一个如何指定一个功能的说明。**每个功能的细节可以在其各自的教程中找到。** - -为了说明配置文件的使用,我们在这里使用混合精度训练作为例子。您需要遵循以下步骤。 - -1. 创建一个配置文件(例如 `config.py`,您可以指定任意的文件名)。 -2. 在配置文件中定义混合精度的配置。例如,为了使用 PyTorch 提供的原始混合精度训练,您只需将下面这几行代码写入您的配置文件中。 - - ```python - from colossalai.amp import AMP_TYPE - - fp16 = dict( - mode=AMP_TYPE.TORCH - ) - ``` - -3. 当启动分布式环境时,向 Colossal-AI 指定您的配置文件的位置。比如下面的例子是配置文件在当前目录下。 - - ```python - import colossalai - - colossalai.launch(config='./config.py', ...) - ``` - -这样,Colossal-AI 便知道您想使用什么功能,并会在 `colossalai.initialize` 期间注入您所需要的功能。 - -### 全局超参数 - -除了功能的配置,您还可以在配置文件中定义训练的超参数。当您想进行多个实验时,这将会变得非常方便。每个实验的细节都可以放在独立的配置文件中,以避免混乱。这些参数将被存储在全局并行环境中,可以在训练脚本中访问。 - -例如,您可以在配置文件中指定批量大小。 - -```python -BATCH_SIZE = 32 -``` - -启动后,您能够通过全局并行上下文访问您的超参数。 - -```python -import colossalai -from colossalai.core import global_context as gpc - -colossalai.launch(config='./config.py', ...) - -# access your parameter -print(gpc.config.BATCH_SIZE) - -``` diff --git a/docs/source/zh-Hans/basics/engine_trainer.md b/docs/source/zh-Hans/basics/engine_trainer.md deleted file mode 100644 index a7519bfca14f7dbddff22ef9dec20ae3cdc17e93..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/basics/engine_trainer.md +++ /dev/null @@ -1,384 +0,0 @@ -# 如何在训练中使用 Engine 和 Trainer - -作者: Shenggui Li, Siqi Mai - -**预备知识:** -- [初始化功能](./initialize_features.md) - -## 简介 - -在本教程中,您将学习如何使用 Colossal-AI 中提供的 Engine 和 Trainer 来训练您的模型。在深入研究细节之前,我们想先解释一下 Engine 和 Trainer 的概念。 - -### Engine - -Engine 本质上是一个模型、优化器和损失函数的封装类。当我们调用 `colossalai.initialize` 时,一个 Engine 对象将被返回,并且配备了在您的配置文件中指定的梯度剪裁、梯度累计和 ZeRO 优化器等功能。 - -Engine 将使用与 PyTorch 训练组件类似的 API,因此您只需对代码进行微小的修改即可。 - -下表展示了Engine的常用API。 - -| 组件 | 功能 | PyTorch | Colossal-AI | -| ------------------------------------- | --------------------------------------------- | ------------------------------- | -------------------------------------- | -| optimizer | 迭代前将所有梯度设置为零 | optimizer.zero_grad() | engine.zero_grad() | -| optimizer | 更新参数 | optimizer.step() | engine.step() | -| model | 进行一次前向计算 | outputs = model(inputs) | outputs = engine(inputs) | -| criterion | 计算loss值 | loss = criterion(output, label) | loss = engine.criterion(output, label) | -| criterion | 反向计算 | loss.backward() | engine.backward(loss) | - -我们需要这样一个 Engine 类的原因是,我们可以添加更多的功能,同时将实现隐藏在 -`colossalai.initialize` 函数中实现。 -假如我们要添加一个新的功能,我们可以在 `colossalai.initialize` 函数中完成对于模型、优化器、数据加载器和损失函数的功能诠释。不管中间的过程有多复杂,最终我们呈现的以及用户需要使用的只有一个 Engine 类,这将十分便捷。 -用户只需要在最小范围内修改他们的代码,将普通的 PyTorch APIs 调整为 Colossal-AI -Engine 的 API。通过这种方式,他们可以享受更多的功能来进行有效的训练。 - -以下是一个简单的例子: - -```python -import colossalai - -# build your model, optimizer, criterion, dataloaders -... - -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader) -for img, label in train_dataloader: - engine.zero_grad() - output = engine(img) - loss = engine.criterion(output, label) - engine.backward(loss) - engine.step() -``` - -### Trainer - -Trainer 是一个更高级的封装器,用户可以用更少的代码行来执行训练。 由于 Trainer 的使用会更加简单,相较于 Engine,它会缺少一点灵活性。 Trainer 被设计为进行前向和反向计算来进行模型权重的更新。通过传递 Engine 对象,我们可以很容易地创建一个 Trainer。 -Trainer 的参数 `schedule` 默认值是 `None` 。在大多数情况下,除非我们想使用流水线并行,否则我们把这个值设为 `None`。如果您想探索更多关于这个参数的内容,您可以前往流水线并行的相关教程。 - -```python -from colossalai.logging import get_dist_logger -from colossalai.trainer import Trainer, hooks - -# build components and initialize with colossalai.initialize -... - -# create a logger so that trainer can log on the console -logger = get_dist_logger() - -# create a trainer object -trainer = Trainer( - engine=engine, - logger=logger -) -``` - -在 Trainer 中,用户可以定制一些 hooks,并将这些 hooks 附加到 Trainer 上。hook 将根据训练方案定期地执行生命周期函数。例如,基于用户是想在每次训练迭代后还是只在整个训练周期后更新学习率, -`LRSchedulerHook` 将会在 `after_train_iter` 或 `after_train_epoch` 阶段执行 `lr_scheduler.step()` 去为用户更新学习率。您可以将 hook 存储在一个列表中并将其传递给 `trainer.fit` 方法。`trainer.fit` 方法将根据您的参数执行训练和测试。如果 `display_process` 为 True,将在您的控制台显示一个进度条,以显示训练的过程。 - - -```python -# define the hooks to attach to the trainer -hook_list = [ - hooks.LossHook(), - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), - hooks.AccuracyHook(accuracy_func=Accuracy()), - hooks.LogMetricByEpochHook(logger), -] - -# start training -trainer.fit( - train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True -) -``` - -如果您想定制您的 hook 类,您可以继承 `hooks.BaseHook` 并重写您想要的生命周期方法。下面提供了一个例子来演示如何创建一个简单的关于日志信息的 hook,以供您参考。 - -```python -from colossalai.logging import get_dist_logger -from colossalai.trainer import hooks - -class LogMessageHook(hooks.BaseHook): - - def __init__(self, priority=10): - self._logger = get_dist_logger() - - def before_train(self, trainer): - self._logger.info('training starts') - - def after_train(self, trainer): - self._logger.info('training finished') - - -... - -# then in your training script -hook_list.append(LogMessageHook()) -``` - - - -在下面的章节中,您将会详细地了解到如何用 Engine 和 Trainer 来训练 ResNet 模型。 - - -## ResNet - -### 总览 - -在本节中,我们将介绍: - -1. 使用一个 Engine 在 CIFAR10 数据集上训练 ResNet34 模型 -2. 使用一个 Trainer 在 CIFAR10 数据集上训练 ResNet34 模型 - -项目结构如下: - -```bash --- config.py --- run_resnet_cifar10_with_engine.py --- run_resnet_cifar10_with_trainer.py -``` - -对于使用 Engine 或 Trainer,步骤 1-4 是通用的。 因此,步骤 1-4 + 步骤 5 将会是对应 `run_resnet_cifar10_with_engine.py` 而 步骤 1-4 + 步骤6 则对应 `run_resnet_cifar10_with_trainer.py`。 - -### 牛刀小试 - -#### 步骤 1. 创建配置文件 - -在你的项目文件夹中,创建一个 `config.py`。这个文件是用来指定一些您可能想用来训练您的模型的特征。下面是一个配置文件的例子。 - -```python -from colossalai.amp import AMP_TYPE - -BATCH_SIZE = 128 -NUM_EPOCHS = 200 - -fp16=dict( - mode=AMP_TYPE.TORCH -) -``` - -在这个配置文件中,我们指定要在每个 GPU 上使用批大小为128,并运行200个 epoch。这两个参数是在 `gpc.config` 中体现的。例如,您可以使用 `gpc.config.BATCH_SIZE` 来访问您存储在配置文件中的批大小值。而 `fp16` 配置则会告诉 `colossalai.initialize` 使用 PyTorch 提供的混合精度训练,以更好的速度和更低的内存消耗来训练模型。 - -#### 步骤 2. 初始化分布式环境 - -我们需要初始化分布式训练环境。这在 [启动 Colossal-AI](./launch_colossalai.md) 中有相应的教程。在当前的演示中,我们使用 `launch_from_torch` 和 PyTorch 启用工具。 - -```python -import colossalai - -# ./config.py refers to the config file we just created in step 1 -colossalai.launch_from_torch(config='./config.py') -``` - -#### 步骤 3. 创建所有的训练组件 - -这时,我们可以创建用于训练的所有组件,包括: - -1. 模型 -2. 优化器 -3. 损失函数 -4. 训练/测试数据加载器 -5. 学习率调度器 -6. 日志记录器 - - - -为了构建这些组件,您需要导入以下模块。 - -```python -from pathlib import Path -from colossalai.logging import get_dist_logger -import torch -import os -from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader -from torchvision import transforms -from colossalai.nn.lr_scheduler import CosineAnnealingLR -from torchvision.datasets import CIFAR10 -from torchvision.models import resnet34 -``` - - - -然后按照通常在PyTorch脚本中构建组件的方式来构建组件。在下面的脚本中,我们将CIFAR10数据集的根路径设置为环境变量 `DATA`。您可以把它改为您想要的任何路径,例如,您可以把 `root=Path(os.environ['DATA'])` 改为 `root='./data'` ,这样就不需要设置环境变量。 - -```python -# build logger -logger = get_dist_logger() - -# build resnet -model = resnet34(num_classes=10) - -# build datasets -train_dataset = CIFAR10( - root='./data', - download=True, - transform=transforms.Compose( - [ - transforms.RandomCrop(size=32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ - 0.2023, 0.1994, 0.2010]), - ] - ) -) - -test_dataset = CIFAR10( - root='./data', - train=False, - transform=transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[ - 0.2023, 0.1994, 0.2010]), - ] - ) -) - -# build dataloaders -train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - -test_dataloader = get_dataloader(dataset=test_dataset, - add_sampler=False, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - -# build criterion -criterion = torch.nn.CrossEntropyLoss() - -# optimizer -optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) - -# lr_scheduler -lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS) -``` - -#### 步骤 4. 用 Colossal-AI 进行初始化 - -接下来,重要的一步是通过调用 `colossalai.initialize` 获得 Engine。正如 `config.py` 中所述,我们将使用混合精度训练来训练 ResNet34 模型。`colossalai.initialize` 将自动检查您的配置文件,并将相关特征分配给您的训练组件。这样一来,我们的 Engine 已经能够进行混合精度训练,而您不需要进行额外的处理。 - -```python -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader, - ) -``` - - - -#### 步骤 5. 用 Engine 进行训练 - -当所有的训练组件都准备好后,我们就可以像使用 PyTorch 一样训练 ResNet34 了。 - -```python -for epoch in range(gpc.config.NUM_EPOCHS): - # execute a training iteration - engine.train() - for img, label in train_dataloader: - img = img.cuda() - label = label.cuda() - - # set gradients to zero - engine.zero_grad() - - # run forward pass - output = engine(img) - - # compute loss value and run backward pass - train_loss = engine.criterion(output, label) - engine.backward(train_loss) - - # update parameters - engine.step() - - # update learning rate - lr_scheduler.step() - - # execute a testing iteration - engine.eval() - correct = 0 - total = 0 - for img, label in test_dataloader: - img = img.cuda() - label = label.cuda() - - # run prediction without back-propagation - with torch.no_grad(): - output = engine(img) - test_loss = engine.criterion(output, label) - - # compute the number of correct prediction - pred = torch.argmax(output, dim=-1) - correct += torch.sum(pred == label) - total += img.size(0) - - logger.info( - f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", ranks=[0]) -``` - -#### 步骤 6. 用 Trainer 进行训练 - -如果您想用 Trainer 进行训练,您可以参考下面的代码进行您的实验。 - - -```python -from colossalai.nn.metric import Accuracy -from colossalai.trainer import Trainer, hooks - - -# create a trainer object -trainer = Trainer( - engine=engine, - logger=logger -) - -# define the hooks to attach to the trainer -hook_list = [ - hooks.LossHook(), - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True), - hooks.AccuracyHook(accuracy_func=Accuracy()), - hooks.LogMetricByEpochHook(logger), - hooks.LogMemoryByEpochHook(logger) -] - -# start training -# run testing every 1 epoch -trainer.fit( - train_dataloader=train_dataloader, - epochs=gpc.config.NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True -) -``` - - - -#### 步骤 7. 开始分布式训练 - -最后,我们可以使用 PyTorch 提供的分布式启动器来调用脚本,因为我们在步骤2中使用了 `launch_from_torch`。您需要把`` 替换成您机器上可用的GPU数量。如果您只想使用一个 GPU,您可以把这个数字设为1。如果您想使用其他的启动器,请您参考如何启动 Colossal-AI 的教程。 - - -```bash -# with engine -python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py -# with trainer -python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py -``` diff --git a/docs/source/zh-Hans/basics/initialize_features.md b/docs/source/zh-Hans/basics/initialize_features.md deleted file mode 100644 index 67ea114b42b29e545c628d4cc9ae0fd77c1da2f8..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/basics/initialize_features.md +++ /dev/null @@ -1,46 +0,0 @@ -# 初始化功能 - -作者: Shenggui Li, Siqi Mai - -**预备知识:** -- [分布式训练](../concepts/distributed_training.md) -- [Colossal-AI 总览](../concepts/colossalai_overview.md) - -## 简介 - -在本教程中,我们将介绍 `colossalai.initialize` 的使用。 它包含了如何将特征(例如,模型、优化器、数据加载器)无缝注入您的训练组件中。 调用 `colossalai.initialize` 是您进入训练循环前的基本操作。 - -在下面一节中,我们将介绍 `colossalai.initialize` 是如何工作的以及使用中我们要注意的细节。 - -## 使用 - -在一个典型的工作流程中,我们将在训练脚本的开始启动分布式环境。 -之后,我们将实例化我们的对象,如模型、优化器、损失函数、数据加载器等。此时,我们可以使用 `colossalai.initialize` 便捷地为这些对象注入特征。 -具体细节请看以下的伪代码例子。 - -```python -import colossalai -import torch -... - - -# launch distributed environment -colossalai.launch(config='./config.py', ...) - -# create your objects -model = MyModel() -optimizer = torch.optim.Adam(model.parameters(), lr=0.001) -criterion = torch.nn.CrossEntropyLoss() -train_dataloader = MyTrainDataloader() -test_dataloader = MyTrainDataloader() - -# initialize features -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader, - test_dataloader) -``` - - `colossalai.initialize` 将返回一个 `Engine` 对象。 该对象把模型、优化器和损失函数封装起来。 **`Engine` 对象会以配置文件中指定的特征运行。** -关于 `Engine` 的更多使用细节可以在 [在训练中使用Engine和Trainer](./engine_trainer.md) 中获取。 diff --git a/docs/source/zh-Hans/basics/launch_colossalai.md b/docs/source/zh-Hans/basics/launch_colossalai.md index ca927de578d5d7db96f6d1717f8aed892da4504a..39b09deae085941e53e87796f2b710f14528282a 100644 --- a/docs/source/zh-Hans/basics/launch_colossalai.md +++ b/docs/source/zh-Hans/basics/launch_colossalai.md @@ -74,7 +74,7 @@ import colossalai args = colossalai.get_default_parser().parse_args() # launch distributed environment -colossalai.launch(config=, +colossalai.launch(config=args.config, rank=args.rank, world_size=args.world_size, host=args.host, @@ -93,12 +93,21 @@ PyTorch自带的启动器需要在每个节点上都启动命令才能启动多 首先,我们需要在代码里指定我们的启动方式。由于这个启动器是PyTorch启动器的封装,那么我们自然而然应该使用`colossalai.launch_from_torch`。 分布式环境所需的参数,如 rank, world size, host 和 port 都是由 PyTorch 启动器设置的,可以直接从环境变量中读取。 +config.py +```python +BATCH_SIZE = 512 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 2 +``` +train.py ```python import colossalai colossalai.launch_from_torch( - config=, + config="./config.py", ) +... ``` 接下来,我们可以轻松地在终端使用`colossalai run`来启动训练。下面的命令可以在当前机器上启动一个4卡的训练任务。 diff --git a/docs/source/zh-Hans/basics/model_checkpoint.md b/docs/source/zh-Hans/basics/model_checkpoint.md deleted file mode 100644 index cec12d45198911c526326590bd15d2617a6b4cf4..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/basics/model_checkpoint.md +++ /dev/null @@ -1,61 +0,0 @@ -# 模型检查点 - -作者 : Guangyang Lu - -**预备知识:** -- [Launch Colossal-AI](./launch_colossalai.md) -- [Initialize Colossal-AI](./initialize_features.md) - -**示例代码:** -- [ColossalAI-Examples Model Checkpoint](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/utils/checkpoint) - -**函数是经验函数.** - -## 简介 - -本教程将介绍如何保存和加载模型检查点。 - -为了充分利用Colossal-AI的强大并行策略,我们需要修改模型和张量,可以直接使用 `torch.save` 或者 `torch.load` 保存或加载模型检查点。在Colossal-AI中,我们提供了应用程序接口实现上述同样的效果。 - -但是,在加载时,你不需要使用与存储相同的保存策略。 - -## 使用方法 - -### 保存 - -有两种方法可以使用Colossal-AI训练模型,即使用engine或使用trainer。 -**注意我们只保存 `state_dict`.** 因此,在加载检查点时,需要首先定义模型。 - -#### 同 engine 保存 - -```python -from colossalai.utils import save_checkpoint -model = ... -engine, _, _, _ = colossalai.initialize(model=model, ...) -for epoch in range(num_epochs): - ... # do some training - save_checkpoint('xxx.pt', epoch, model) -``` - -#### 用 trainer 保存 -```python -from colossalai.trainer import Trainer, hooks -model = ... -engine, _, _, _ = colossalai.initialize(model=model, ...) -trainer = Trainer(engine, ...) -hook_list = [ - hooks.SaveCheckpointHook(1, 'xxx.pt', model) - ...] - -trainer.fit(... - hook=hook_list) -``` - -### 加载 - -```python -from colossalai.utils import load_checkpoint -model = ... -load_checkpoint('xxx.pt', model) -... # train or test -``` diff --git a/docs/source/zh-Hans/concepts/colossalai_overview.md b/docs/source/zh-Hans/concepts/colossalai_overview.md index cfb35e59e64a99924297e5bbdd83c96930a2f93c..8b28baf8e3d568ea1e8fe54de3b3b524059a94b5 100755 --- a/docs/source/zh-Hans/concepts/colossalai_overview.md +++ b/docs/source/zh-Hans/concepts/colossalai_overview.md @@ -19,7 +19,7 @@ Colossal-AI 是一个集成的系统,为用户提供一套综合的训练方 1. 准备一个配置文件,指定您要使用的功能和参数。 2. 用 `colossalai.launch` 初始化分布式后端。 -3. 用 `colossalai.initialize` 将训练特征注入您的训练组件(如模型、优化器)中。 +3. 用 `colossalai.booster` 将训练特征注入您的训练组件(如模型、优化器)中。 4. 进行训练和测试. 我们将在`基本教程`部分介绍整个工作流程。 @@ -34,3 +34,5 @@ Colossal-AI 系统将会进一步拓展和优化,包括但不限于: 4. 拓展现有的并行方法 **我们始终欢迎社区的建议和讨论,如果您遇到任何问题,我们将非常愿意帮助您。您可以在GitHub 提 [issue](https://github.com/hpcaitech/ColossalAI/issues) ,或在[论坛](https://github.com/hpcaitech/ColossalAI/discussions)上创建一个讨论主题。** + + diff --git a/docs/source/zh-Hans/features/1D_tensor_parallel.md b/docs/source/zh-Hans/features/1D_tensor_parallel.md index 2ddc27c7b50f8cef86b444d0b24a1504ebcb8772..fb6fd90ec4c2fae578aa48e7fe88b3774a2bf44c 100644 --- a/docs/source/zh-Hans/features/1D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/1D_tensor_parallel.md @@ -2,12 +2,9 @@ 作者: Zhengda Bian, Yongbin Li -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [并行配置](../basics/configure_parallelization.md) **示例代码** -- [ColossalAI-Examples 1D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_1d.py) +- [Tensor Parallelism with Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples) **相关论文** - [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) @@ -20,15 +17,16 @@ 让我们以一个线性层为例,它包括一个 GEMM $Y = XA$。 给定2个处理器,我们把列 $A$ 划分为 $[A_1 ~ A_2]$, 并在每个处理器上计算 $Y_i = XA_i$ , 然后形成 $[Y_1 ~ Y_2] = [XA_1 ~ XA_2]$. 这被称为列并行方式。 当第二个线性层 $Z=YB$ 跟随上述列并行层的时候, 我们把 $B$ 划分为 -```math +$$ \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] ``` -这就是所谓的行并行方式.
              +这就是所谓的行并行方式. +$$ 为了计算 -```math +$$ Z = [Y_1 ~ Y_2] \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] -``` +$$ 我们首先在每个处理器上计算 $Y_iB_i$ 然后使用一个all-reduce操作将结果汇总为 $Z=Y_1B_1+Y_2B_2$。 我们还需要注意,在后向计算中,列并行线性层需要聚合输入张量 $X$, 因为在每个处理器 $i$ 上,我们只有 $\dot{X_i}=\dot{Y_i}A_i^T$,因此,我们在各处理器之间进行all-reduce,得到 $\dot{X}=\dot{Y}A^T=\dot{Y_1}A_1^T+\dot{Y_2}A_2^T$。 @@ -40,80 +38,10 @@ Z = [Y_1 ~ Y_2] \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] | :-: | :-: | :-: | :-: | :-: | | $O(1/P)$ | $O(1/P)$ | $O(1)$ | $O(2(P-1)/P)$ | $O(2(P-1))$ | -## 使用 - -为了使模型能够实现一维张量并行, 如在2个 GPU 上, 我们需要配置如下的并行设置。 -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=2, mode='1d'), -)) -``` - -然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用1D张量并行。 - -让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.transpose(0, 1).shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.transpose(0, 1).shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` - -在2个 GPU 上启动 Colossal-AI 并建立模型。 -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 -```shell -Weight of the first linear layer: torch.Size([256, 512]) -Weight of the second linear layer: torch.Size([512, 256]) -``` -第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过列-并行分割,它变成了 `[256, 512]`。 -同样地,第二个行并行层将权重 `[1024, 256]` 划分为 `[512, 256]`。 - -我们可以用一些随机输入来运行这个模型。 -```python -from colossalai.utils import get_current_device +## 使用 -x = torch.randn((16, 256), device=get_current_device()) -torch.distributed.broadcast(x, src=0) # synchronize input +在ColossalAI最新的版本中,1D张量并行由`Shardformer`功能实现。 +关于`Shardformer`的原理和用法细节请参考当前目录下的Shardformer文档。 -x = m(x) -``` -然后我们可以看到 activation 结果的形状。 -```shell -Output of the first linear layer: torch.Size([16, 512]) -Output of the second linear layer: torch.Size([16, 256]) -``` -第一个线性层的输出被划分成2块 (每个形状为 `[16, 512]`), 而第二层在整个 GPU 上的输出是相同的。 + diff --git a/docs/source/zh-Hans/features/2D_tensor_parallel.md b/docs/source/zh-Hans/features/2D_tensor_parallel.md index c942f82bf9d2592b981dc77ede304a8fe8a674cb..0cb7968c81030782cfef8c94dbfeed84ba98fa0b 100644 --- a/docs/source/zh-Hans/features/2D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/2D_tensor_parallel.md @@ -3,12 +3,10 @@ 作者: Zhengda Bian, Yongbin Li **前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [并行配置](../basics/configure_parallelization.md) - [1D 张量并行](./1D_tensor_parallel.md) **示例代码** -- [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2d.py) +- [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md) **相关论文** - [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/pdf/2104.05343.pdf) @@ -22,33 +20,33 @@ 给定 $P=q\times q$ 个处理器(必要条件), 如 $q=2$, 我们把输入 $X$ 和权重A $A$ 都划分为 $$ -\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right] +\left[\begin{matrix} X_{00} & X_{01} \\ X_{10} & X_{11} \end{matrix} \right] \text{~and~} -\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right]。 +\left[\begin{matrix} A_{00} & A_{01} \\ A_{10} & A_{11} \end{matrix} \right]. $$ 该计算包括 $q$ 步。 当 $t=1$ 时, $X_{i0}$ 在其行中被广播, 而 $A_{0j}$ 在其列中被广播。因此,我们有 $$ -\left[\begin{matrix} X_{10},A_{00} & X_{10},A_{01} \\ X_{00},A_{00} & X_{00},A_{01} \end{matrix} \right]。 +\left[\begin{matrix} X_{00},A_{00} & X_{00},A_{01} \\ X_{10},A_{00} & X_{10},A_{01} \end{matrix} \right]. $$ 然后我们在每个处理器 $(i, j)$ 上将 $X_{i0}$ 和 $A_{0j}$ 相乘为 $$ -\left[\begin{matrix} X_{10}A_{00} & X_{10}A_{01} \\ X_{00}A_{00} & X_{00}A_{01} \end{matrix} \right] (1)。 +\left[\begin{matrix} X_{00}A_{00} & X_{00}A_{01} \\ X_{10}A_{00} & X_{10}A_{01} \end{matrix} \right] (1). $$ 同样,当 $t=2$ 时, $X_{i1}$ 在其行中被广播, $A_{1j}$ 在其列中被广播, 我们将它们相乘为 $$ -\left[\begin{matrix} X_{11}A_{10} & X_{11}A_{11} \\ X_{01}A_{10} & X_{01}A_{11} \end{matrix} \right] (2)。 +\left[\begin{matrix} X_{01}A_{10} & X_{01}A_{11} \\ X_{11}A_{10} & X_{11}A_{11} \end{matrix} \right] (2). $$ 通过将 $(1)$ 和 $(2)$ 相加,我们有 $$ -Y = XA = \left[\begin{matrix} X_{10}A_{00}+X_{11}A_{10} & X_{10}A_{01}+X_{11}A_{11} \\ X_{00}A_{00}+X_{01}A_{10} & X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right]。 +Y = XA = \left[\begin{matrix} X_{00}A_{00}+X_{01}A_{10} & X_{00}A_{01}+X_{01}A_{11} \\ X_{10}A_{00}+X_{11}A_{10} & X_{10}A_{01}+X_{11}A_{11} \end{matrix} \right]. $$ ## 效率 @@ -60,82 +58,8 @@ $$ ## 使用 -为了使我们的模型能够实现二维张量并行,例如在4个 GPU 上,我们需要配置如下的并行设置。 -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=4, mode='2d'), -)) -``` -然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用2D张量并行。 - -让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -在4个 GPU 上启动 Colossal-AI 并建立模型。 -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 -```shell -Weight of the first linear layer: torch.Size([128, 512]) -Weight of the second linear layer: torch.Size([512, 128]) -``` -第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过2D并行划分后,它在每个 GPU 上变成了 `[128, 512]` 。 -同样地,第二层将权重 `[1024, 256]` 划分为 `[512, 128]`. - -我们可以用一些随机输入来运行这个模型。 -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -然后我们可以看到 activation 结果的形状。 -```shell -Input: torch.Size([8, 128]) -Output of the first linear layer: torch.Size([8, 512]) -Output of the second linear layer: torch.Size([8, 128]) -``` -2D并行中的 activation 张量都是同时在行和列分割的。例如,第一个线性层的输出是 `[8, 512]`, 而第二层的输出为 `[8, 128]`。 +ColossalAI的最新版本还暂不支持2D张量并行,但2D张量并行的功能会在未来的版本被集成入`Shardformer`中。关于`Shardformer`的原理和用法细节请参考当前目录下的Shardformer文档。 + +对于老版本ColossalAI的用户,2D张量并行的用法请参考[ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md)。 + + diff --git a/docs/source/zh-Hans/features/2p5D_tensor_parallel.md b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md index 59a4be02ce47764889fb38b6fa66503a58322baf..308638a359f1ea9c8a4a70f252d646f6441692b7 100644 --- a/docs/source/zh-Hans/features/2p5D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md @@ -3,13 +3,11 @@ 作者: Zhengda Bian, Yongbin Li **前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [并行配置](../basics/configure_parallelization.md) - [1D 张量并行](./1D_tensor_parallel.md) - [2D 张量并行](./2D_tensor_parallel.md) **示例代码** -- [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2p5d.py) +- [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md) **相关论文** - [2.5-dimensional distributed model training](https://arxiv.org/pdf/2105.14500.pdf) @@ -22,29 +20,29 @@ 给定 $P=q \times q \times d$ 个处理器(必要条件), 如 $q=d=2$, 我们把输入 $X$ 划分为 $d\times q$ 行和 $q$ 列 $$ -\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \\ X_{10} & X_{11} \\ X_{00} & X_{01}\end{matrix} \right], +\left[\begin{matrix} X_{00} & X_{01} \\ X_{10} & X_{11} \\ X_{20} & X_{21} \\ X_{30} & X_{31}\end{matrix} \right], $$ 它可以被重塑为 $d$ 层 $$ -\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right] \text{~and~}\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \end{matrix} \right]. +\left[\begin{matrix} X_{00} & X_{01} \\ X_{10} & X_{11} \end{matrix} \right] \text{~and~}\left[\begin{matrix} X_{20} & X_{21} \\ X_{30} & X_{31} \end{matrix} \right]. $$ 另外,权重 $A$ 被分割为 $$ -\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right]. +\left[\begin{matrix} A_{00} & A_{01} \\ A_{10} & A_{11} \end{matrix} \right]. $$ 对于 $X$ 相关的每一层, 我们使用SUMMA算法将 $X$ 与 $A$ 相乘。 然后,我们得到输出 $$ -\left[\begin{matrix} Y_{10}=X_{10}A_{00}+X_{11}A_{10} & Y_{11}=X_{10}A_{01}+X_{11}A_{11} \\ Y_{00}=X_{00}A_{00}+X_{01}A_{10} & Y_{01}=X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right] +\left[\begin{matrix} Y_{00}=X_{00}A_{00}+X_{01}A_{10} & Y_{01}=X_{00}A_{01}+X_{01}A_{11} \\ Y_{10}=X_{10}A_{00}+X_{11}A_{10} & Y_{11}=X_{10}A_{01}+X_{11}A_{11} \end{matrix} \right] \text{~and~} $$ $$ -\left[\begin{matrix} Y_{30}=X_{30}A_{00}+X_{31}A_{10} & Y_{31}=X_{30}A_{01}+X_{31}A_{11} \\ Y_{20}=X_{20}A_{00}+X_{21}A_{10} & Y_{21}=X_{20}A_{01}+X_{21}A_{11} \end{matrix} \right]. +\left[\begin{matrix} Y_{20}=X_{20}A_{00}+X_{21}A_{10} & Y_{21}=X_{20}A_{01}+X_{21}A_{11} \\ Y_{30}=X_{30}A_{00}+X_{31}A_{10} & Y_{31}=X_{30}A_{01}+X_{31}A_{11} \end{matrix} \right]. $$ ## 效率 @@ -57,89 +55,8 @@ $$ ## 使用 -为了使我们的模型能够实现2.5D张量并行,例如在8个 GPU 上,我们需要配置如下的并行设置。 - -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=8, mode='2.5d', depth=2), -)) - -``` - -然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用2.5D张量并行。 - -让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 - -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -在8个 GPU 上启动 Colossal-AI 并建立模型。 -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 -```shell -Weight of the first linear layer: torch.Size([128, 512]) -Weight of the second linear layer: torch.Size([512, 128]) -``` - -第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过2.5D并行划分后,它在每个 GPU 上变成了 `[128, 512]` 。 -同样地,第二层将权重 `[1024, 256]` 划分为 `[512, 128]`. - -我们可以用一些随机输入来运行这个模型。 -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)] -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -然后我们可以看到 activation 结果的形状。 -```shell -Input: torch.Size([4, 128]) -Output of the first linear layer: torch.Size([4, 512]) -Output of the second linear layer: torch.Size([4, 128]) -``` -2.5D并行中的 activation 张量都是同时在$d \times q$行和$q$列分割的。例如,第一个线性层的输出是 `[4, 512]`, 而第二层的输出为 `[4, 128]`。 -注意,2.5D并行使用与2D并行相同的划分方法来处理权重,区别在于对输入的划分。 +ColossalAI的最新版本还暂不支持2.5D张量并行,但2.5D张量并行的功能会在未来的版本被集成入`Shardformer`中。关于`Shardformer`的原理和用法细节请参考当前目录下的Shardformer文档。 + +对于老版本ColossalAI的用户,2.5D张量并行的用法请参考[ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md)。 + + diff --git a/docs/source/zh-Hans/features/3D_tensor_parallel.md b/docs/source/zh-Hans/features/3D_tensor_parallel.md index 440121c942431b48869487f8199a88fb3fa9133a..bf403d2d9636155578ff45c4d3634f67fe582b9e 100644 --- a/docs/source/zh-Hans/features/3D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/3D_tensor_parallel.md @@ -3,13 +3,11 @@ 作者: Zhengda Bian, Yongbin Li **前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [并行配置](../basics/configure_parallelization.md) - [1D 张量并行](./1D_tensor_parallel.md) - [2D 张量并行](./2D_tensor_parallel.md) **示例代码** -- [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_3d.py) +- [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md) **相关论文** - [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/pdf/2105.14450.pdf) @@ -67,88 +65,8 @@ $$ ## 使用 -为了使我们的模型能够实现3D张量并行,例如在8个 GPU 上,我们需要配置如下的并行设置。 - -```python -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=8, mode='3d'), -)) -``` -然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用3D张量并行。 - -让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。 - -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x -``` -在8个 GPU 上启动 Colossal-AI 并建立模型。 -```python -parser = colossalai.get_default_parser() -colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -m = MLP() -``` -我们将会看到 MLP 模型中被划分的参数(如权重)的形状。 -```shell -Weight of the first linear layer: torch.Size([128, 256]) -Weight of the second linear layer: torch.Size([512, 64]) -``` - -第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过3D并行划分后,它在每个 GPU 上变成了 `[128, 256]` 。 -同样地,第二层将权重 `[1024, 256]` 划分为 `[512, 64]`. - -我们可以用一些随机输入来运行这个模型。 - -```python -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -x = torch.randn((16, 256), device=get_current_device()) -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)] -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` -然后我们可以看到 activation 结果的形状。 -```shell -Input: torch.Size([4, 128]) -Output of the first linear layer: torch.Size([4, 512]) -Output of the second linear layer: torch.Size([4, 128]) -``` -3D并行中的 activation 张量都是同时在$q^2$行和$q$列分割的。例如,第一个线性层的输出是 `[4, 512]`, 而第二层的输出为 `[4, 128]`。 -注意,虽然这里3D并行的结果与2.5D并行的结果形状相同,但每个划分的内容是不同的。 +ColossalAI的最新版本还暂不支持3D张量并行,但3D张量并行的功能会在未来的版本被集成入`Shardformer`中。关于`Shardformer`的原理和用法细节请参考当前目录下的Shardformer文档。 + +对于老版本ColossalAI的用户,3D张量并行的用法请参考[ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/blob/main/features/tensor_parallel/README.md)。 + + diff --git a/docs/source/zh-Hans/features/cluster_utils.md b/docs/source/zh-Hans/features/cluster_utils.md new file mode 100644 index 0000000000000000000000000000000000000000..f54a72c63a66955e84aaccb9a3eceaddad87da3a --- /dev/null +++ b/docs/source/zh-Hans/features/cluster_utils.md @@ -0,0 +1,16 @@ +# 集群实用程序 + +作者: [Hongxin Liu](https://github.com/ver217) + +**前置教程:** +- [分布式训练](../concepts/distributed_training.md) + +## 引言 + +我们提供了一个实用程序类 `colossalai.cluster.DistCoordinator` 来协调分布式训练。它对于获取有关集群的各种信息很有用,例如节点数、每个节点的进程数等。 + +## API 参考 + +{{ autodoc:colossalai.cluster.DistCoordinator }} + + diff --git a/docs/source/zh-Hans/features/gradient_accumulation.md b/docs/source/zh-Hans/features/gradient_accumulation.md deleted file mode 100644 index e21e5fcd43d897761df6a5080313edb85d6a2e34..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/features/gradient_accumulation.md +++ /dev/null @@ -1,40 +0,0 @@ -# 梯度累积 - -作者: Shenggui Li, Yongbin Li - -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) - -**示例代码** -- [ColossalAI-Examples Gradient Accumulation](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) - -## 引言 - -梯度累积是一种常见的增大训练 batch size 的方式。 在训练大模型时,内存经常会成为瓶颈,并且 batch size 通常会很小(如2),这导致收敛性无法保证。梯度累积将多次迭代的梯度累加,并仅在达到预设迭代次数时更新参数。 - -## 使用 - -在 Colossal-AI 中使用梯度累积非常简单,仅需将下列配置添加进 config 文件。其中,整数值代表期望梯度累积的次数。 - -```python -gradient_accumulation = -``` - -## 实例 - -我们提供了一个 [运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation) -来展现梯度累积。在这个例子中,梯度累积次数被设置为4,你可以通过一下命令启动脚本 - -```shell -python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py -``` - -你将会看到类似下方的文本输出。这展现了梯度虽然在前3个迭代中被计算,但直到最后一次迭代,参数才被更新。 - -```text -iteration 0, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) -iteration 1, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) -iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) -iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) -``` diff --git a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md new file mode 100644 index 0000000000000000000000000000000000000000..3ad9b2e07a955ec693098d4d304c4470f7e8c05f --- /dev/null +++ b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md @@ -0,0 +1,147 @@ +# 梯度累积 + +作者: [Mingyan Jiang](https://github.com/jiangmingyan) + +**前置教程** +- [训练中使用Booster](../basics/booster_api.md) + +## 引言 + +梯度累积是一种常见的增大训练 batch size 的方式。 在训练大模型时,内存经常会成为瓶颈,并且 batch size 通常会很小(如2),这导致收敛性无法保证。梯度累积将多次迭代的梯度累加,并仅在达到预设迭代次数时更新参数。 + +## 使用 + +在 Colossal-AI 中使用梯度累积非常简单,booster提供no_sync返回一个上下文管理器,在该上下文管理器下取消同步并且累积梯度。 + +## 实例 + +我们将介绍如何使用梯度累积。在这个例子中,梯度累积次数被设置为4。 + +### 步骤 1. 在 train.py 导入相关库 +创建train.py并导入必要依赖。 `torch` 的版本应不低于1.8.1。 + +```python +import os +from pathlib import Path + +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.logging import get_dist_logger +from colossalai.cluster.dist_coordinator import priority_execution +``` + +### 步骤 2. 初始化分布式环境 + +我们需要初始化分布式环境。为了快速演示,我们使用`launch_from_torch`。你可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md)使用其他初始化方法。 + +```python +# initialize distributed setting +parser = colossalai.get_default_parser() +args = parser.parse_args() + +# launch from torch +colossalai.launch_from_torch(config=dict()) + +``` + +### 步骤 3. 创建训练组件 + +构建你的模型、优化器、损失函数、学习率调整器和数据加载器。注意数据集的路径从环境变量`DATA`获得。你可以通过 `export DATA=/path/to/data` 或 `Path(os.environ['DATA'])`,在你的机器上设置路径。数据将会被自动下载到该路径。 + +```python +# define the training hyperparameters +BATCH_SIZE = 128 +GRADIENT_ACCUMULATION = 4 + +# build resnet +model = resnet18(num_classes=10) + +# build dataloaders +with priority_execution(): + train_dataset = CIFAR10(root=Path(os.environ.get('DATA', './data')), + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(size=32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ])) + +# build criterion +criterion = torch.nn.CrossEntropyLoss() + +# optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) +``` + +### 步骤 4. 注入特性 +创建一个`TorchDDPPlugin`对象,并作为参实例化`Booster`, 调用`booster.boost`注入特性。 + +```python +plugin = TorchDDPPlugin() +booster = Booster(plugin=plugin) +train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) +model, optimizer, criterion, train_dataloader, _ = booster.boost(model=model, + optimizer=optimizer, + criterion=criterion, + dataloader=train_dataloader) +``` + +### 步骤 5. 使用booster训练 +使用booster构建一个普通的训练循环,验证梯度累积。 `param_by_iter` 记录分布训练的信息。 +```python +optimizer.zero_grad() +for idx, (img, label) in enumerate(train_dataloader): + sync_context = booster.no_sync(model) + img = img.cuda() + label = label.cuda() + if idx % (GRADIENT_ACCUMULATION - 1) != 0: + with sync_context: + output = model(img) + train_loss = criterion(output, label) + train_loss = train_loss / GRADIENT_ACCUMULATION + booster.backward(train_loss, optimizer) + else: + output = model(img) + train_loss = criterion(output, label) + train_loss = train_loss / GRADIENT_ACCUMULATION + booster.backward(train_loss, optimizer) + optimizer.step() + optimizer.zero_grad() + + ele_1st = next(model.parameters()).flatten()[0] + param_by_iter.append(str(ele_1st.item())) + + if idx != 0 and idx % (GRADIENT_ACCUMULATION - 1) == 0: + break + + for iteration, val in enumerate(param_by_iter): + print(f'iteration {iteration} - value: {val}') + + if param_by_iter[-1] != param_by_iter[0]: + print('The parameter is only updated in the last iteration') + +``` + +### 步骤 6. 启动训练脚本 +为了验证梯度累积,我们可以只检查参数值的变化。当设置梯度累加时,仅在最后一步更新参数。您可以使用以下命令运行脚本: +```shell +colossalai run --nproc_per_node 1 train.py +``` + +你将会看到类似下方的文本输出。这展现了梯度虽然在前3个迭代中被计算,但直到最后一次迭代,参数才被更新。 + +```text +iteration 0, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 1, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=) +iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) +``` + + diff --git a/docs/source/zh-Hans/features/gradient_clipping.md b/docs/source/zh-Hans/features/gradient_clipping.md deleted file mode 100644 index 203f66a3fea247742823e5c24f5940d2bb5bf87a..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/features/gradient_clipping.md +++ /dev/null @@ -1,51 +0,0 @@ -# 梯度裁剪 - -作者: Boxiang Wang, Haichen Huang, Yongbin Li - -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) - -**示例代码** -- [ColossalAI-Examples Gradient Clipping](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) - -**相关论文** -- [On the difficulty of training Recurrent Neural Networks](https://arxiv.org/abs/1211.5063) - -## 引言 - -为了加快训练过程和寻求全局最优以获得更好的性能,越来越多的学习率调度器被提出。人们通过控制学习率来调整训练中的下降速度。这使得梯度向量在每一步都能更好地统一。在这种情况下,下降速度可以按预期被控制。 -因此,梯度裁剪,一种可以将梯度向量归一化,以将其限制在统一长度的技术,对于那些希望模型性能更好的人来说是不可或缺的。 - -在使用 Colossal-AI 时,你不必担心实现梯度剪裁,我们以一种有效而方便的方式支持梯度剪裁。你所需要的只是在你的配置文件中增加一个命令。 - -## 为什么应该使用 Colossal-AI 中的梯度裁剪 - -我们不建议用户自己编写梯度剪裁,因为朴素的梯度剪裁在应用张量并行、流水线并行、MoE 等功能时可能会失败。 - -根据下图,每个 GPU 只拥有线性层中权重的一部分参数。为了得到线性层权重的梯度向量的正确范数,每个 GPU 中的每个梯度向量的范数应该相加。更复杂的是,偏置的分布不同于权重的分布。通信组在求和运算中有所不同。 - -(注: 这种情况是旧版本的 2D 并行,在代码中的实现是不一样的。但这是一个很好的例子,能够说明在梯度剪裁中统一所有通信的困难。) - -
              - -
              参数分布
              -
              - -不用担心它,因为 Colossal-AI 已经为你处理好。 - -### 使用 -要使用梯度裁剪,只需在配置文件中添加梯度裁剪范数即可。 - -```python -clip_grad_norm = 1.0 -``` - -### 实例 - -我们提供了一个展现梯度裁剪的[运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping) -。在本例中,我们将梯度裁剪范数设置为1.0,你可以使用以下命令运行脚本: - -```shell -python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 train_with_engine.py -``` diff --git a/docs/source/zh-Hans/features/gradient_clipping_with_booster.md b/docs/source/zh-Hans/features/gradient_clipping_with_booster.md new file mode 100644 index 0000000000000000000000000000000000000000..fdec09bf128a020071fe94bf90aa2f625662eca3 --- /dev/null +++ b/docs/source/zh-Hans/features/gradient_clipping_with_booster.md @@ -0,0 +1,139 @@ +# 梯度裁剪 + +作者: [Mingyan Jiang](https://github.com/jiangmingyan) + +**前置教程** +- [booster使用](../basics/booster_api.md) + +**相关论文** +- [On the difficulty of training Recurrent Neural Networks](https://arxiv.org/abs/1211.5063) + +## 引言 + +为了加快训练过程和寻求全局最优以获得更好的性能,越来越多的学习率调度器被提出。人们通过控制学习率来调整训练中的下降速度。这使得梯度向量在每一步都能更好地统一。在这种情况下,下降速度可以按预期被控制。 +因此,梯度裁剪,一种可以将梯度向量归一化,以将其限制在统一长度的技术,对于那些希望模型性能更好的人来说是不可或缺的。 + +在使用 Colossal-AI 时,你不必担心实现梯度剪裁,我们以一种有效而方便的方式支持梯度剪裁。你所需要的只是在你的配置文件中增加一个命令。 + +## 为什么应该使用 Colossal-AI 中的梯度裁剪 + +我们不建议用户自己编写梯度剪裁,因为朴素的梯度剪裁在应用张量并行、流水线并行、MoE 等功能时可能会失败。 + +根据下图,每个 GPU 只拥有线性层中权重的一部分参数。为了得到线性层权重的梯度向量的正确范数,每个 GPU 中的每个梯度向量的范数应该相加。更复杂的是,偏置的分布不同于权重的分布。通信组在求和运算中有所不同。 + +(注: 这种情况是旧版本的 2D 并行,在代码中的实现是不一样的。但这是一个很好的例子,能够说明在梯度剪裁中统一所有通信的困难。) + +
              + +
              参数分布
              +
              + +不用担心它,因为 Colossal-AI 已经为你处理好。 + +### 使用 +要使用梯度裁剪,只需在使用booster注入特性之后,调用optimizer的`clip_grad_by_norm`或者`clip_grad_by_value`函数即可进行梯度裁剪。 + +### 实例 + +下面我们将介绍如何使用梯度裁剪,在本例中,我们将梯度裁剪范数设置为1.0。 + +### 步骤 1. 在训练中导入相关库 +创建`train.py`并导入相关库。 + +```python +import os +from pathlib import Path + +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet34 +from tqdm import tqdm + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingLR +``` + +### 步骤 2. 初始化分布式环境 +我们需要初始化分布式环境. 为了快速演示,我们使用`launch_from_torch`. 您可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md) + +```python +colossalai.launch_from_torch(config=dict()) +logger = get_dist_logger() +``` + +### 步骤 3. 创建训练组件 + +构建你的模型、优化器、损失函数、学习率调整器和数据加载器。注意数据集的路径从环境变量`DATA`获得。你可以通过 `export DATA=/path/to/data` 或 `Path(os.environ['DATA'])`在你的机器上设置路径。数据将会被自动下载到该路径。 +```python +# define training hyperparameters +NUM_EPOCHS = 200 +BATCH_SIZE = 128 +GRADIENT_CLIPPING = 0.1 +# build resnet +model = resnet34(num_classes=10) +# build dataloaders +train_dataset = CIFAR10(root=Path(os.environ.get('DATA', './data')), + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(size=32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ])) +# build criterion +criterion = torch.nn.CrossEntropyLoss() + +# optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) + +# lr_scheduler +lr_scheduler = CosineAnnealingLR(optimizer, total_steps=NUM_EPOCHS) + +``` +### 步骤 4. 注入梯度裁剪特性 + +创建`TorchDDPPlugin`对象并初始化`Booster`, 使用booster注入相关特性。 +```python +plugin = TorchDDPPlugin() +booster = Booster(plugin=plugin) +train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) +model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model,optimizer, criterion,train_dataloader, lr_scheduler) + +``` + +### 步骤 5. 使用booster训练 +使用booster进行训练。 +```python +# verify gradient clipping +model.train() +for idx, (img, label) in enumerate(train_dataloader): + img = img.cuda() + label = label.cuda() + + model.zero_grad() + output = model(img) + train_loss = criterion(output, label) + booster.backward(train_loss, optimizer) + optimizer.clip_grad_by_norm(max_norm=GRADIENT_CLIPPING) + optimizer.step() + lr_scheduler.step() + + ele_1st = next(model.parameters()).flatten()[0] + logger.info(f'iteration {idx}, loss: {train_loss}, 1st element of parameters: {ele_1st.item()}') + + # only run for 4 iterations + if idx == 3: + break +``` + +### 步骤 6. 启动训练脚本 +你可以使用以下命令运行脚本: + +```shell +colossalai run --nproc_per_node 1 train.py +``` + diff --git a/docs/source/zh-Hans/features/gradient_handler.md b/docs/source/zh-Hans/features/gradient_handler.md deleted file mode 100644 index 701c60fed57f01edfa9a49c9b68287ff9c62a7f9..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/features/gradient_handler.md +++ /dev/null @@ -1,59 +0,0 @@ -# 梯度 Handler - -作者: Shenggui Li, Yongbin Li - -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) - -**示例代码** -- [ColossalAI-Examples Gradient Handler](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) - -## 引言 - -在分布式训练中,每次迭代结束时都需要梯度同步。这很重要,因为我们需要确保在不同的机器中使用相同的梯度更新参数,以便生成的参数都一样。这通常在数据并行中看到,因为在数据并行中的模型是直接复制的。 - -在 Colossal-AI 中,我们为用户提供了一个接口来定制他们想要如何处理同步。这为实现新的并行方法等情况带来了灵活性。 - -当梯度 Handler 被使用时, PyTorch 的 `DistributedDataParallel` 将不再被使用,因为它会自动同步梯度. - -## 定制你的梯度 Handler - -要实现定制的梯度Handler,需要遵循以下步骤。 -1. 继承Colossal-AI中的 `BaseGradientHandler` -2. 将梯度Handler注册进 `GRADIENT_HANDLER` -3. 实现 `handle_gradient` - -```python -from colossalai.registry import GRADIENT_HANDLER -from colossalai.engine.gradient_handler import BaseGradientHandler - - -@GRADIENT_HANDLER.register_module -class MyGradientHandler(BaseGradientHandler): - - def handle_gradient(self): - do_something() - - -``` - - -## 使用 - -要使用梯度 Handler,需要在配置文件中指定梯度 Handler。梯度 Handler 将自动构建并连接到 Engine。 - -```python -gradient_handler = [dict(type='MyGradientHandler')] -``` - - -### 实例 - -我们提供了一个 [运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler) -展现梯度 Handler 的使用. 在这个例子中,我们使用 `DataParallelGradientHandler` 而不是 PyTorch 的 -`DistributedDataParallel` 实现数据并行. - -```shell -python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py -``` diff --git a/docs/source/zh-Hans/features/lazy_init.md b/docs/source/zh-Hans/features/lazy_init.md new file mode 100644 index 0000000000000000000000000000000000000000..80742a56df2961517af85db5378723828a41ca87 --- /dev/null +++ b/docs/source/zh-Hans/features/lazy_init.md @@ -0,0 +1,76 @@ +# 懒惰初始化 + +作者: [Hongxiu Liu](https://github.com/ver217) + +**前置教程:** +- [Train with booster](../basics/booster_api.md) + +## 简介 + +懒惰初始化延迟了模型的初始化。它能够节省在大模型初始化时的内存占用。 + +如果你的模型有 `N` 十亿个参数并且你的内存(或显存)为 `M` GB, 我们推荐您在 `4N >= M` 时使用懒惰初始化。否则,懒惰初始化不是必须的。 + +## 使用 + +懒惰初始化必须与 booster 一起使用。 + +### API 参考 + +{{ autodoc:colossalai.lazy.LazyInitContext }} + +### 例子 + +```python +import colossalai +from colossalai.lazy import LazyInitContext +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin + +from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining + +colossalai.launch({}) +plugin = GeminiPlugin() +booster = Booster(plugin) + +# 1. Initialize model from scratch +# Initialization on cuda will accelerate the initialization process but take more GPU memory. +with LazyInitContext(default_device="cuda"): + model = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=172, num_hidden_layers=4, num_attention_heads=4)) +model, *_ = booster.boost(model) + +# 2. Initialize model from pretrained +with LazyInitContext(): + model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny") +model, *_ = booster.boost(model) +``` + +> ⚠️ 使用懒惰初始化加载预训练模型在 colossalai>0.3.3 或主分支上支持。 + +## 限制 + +我们提到,懒惰初始化必须与 booster 一起使用。只有几个插件支持它。 + +| 插件 | 支持情况 | 备注 | +|-----------------|---------|--------| +| Gemini | 是 | | +| Hybrid Parallel | 是 | | +| Low Level Zero | 否 | 不需要 | +| Torch DDP | 否 | 不兼容 | +| Torch FSDP | 否 | 不兼容 | + +不是所有的模型都可以懒惰初始化。在某些情况下,一部分参数/缓冲区可能会被提前初始化。但是不用担心,这部分通常只占整个模型的一小部分。 + +并且一些模型完全不支持,会引发错误。我们测试了 torchvision, diffusers, timm, transformers, torchaudio 和 torchrec 中的模型。以下模型不受支持: + +| 模型 | 分类 | +|-------------------------------|--------------| +| wav2vec2_base | torchaudio | +| hubert_base | torchaudio | +| ViTModel | transformers | +| ViTForMaskedImageModeling | transformers | +| ViTForImageClassification | transformers | +| Blip2Model | transformers | +| Blip2ForConditionalGeneration | transformers | + + diff --git a/docs/source/zh-Hans/features/mixed_precision_training.md b/docs/source/zh-Hans/features/mixed_precision_training.md deleted file mode 100644 index c9db3a59c1c3912f256e90bbc57ce5322341e3a9..0000000000000000000000000000000000000000 --- a/docs/source/zh-Hans/features/mixed_precision_training.md +++ /dev/null @@ -1,344 +0,0 @@ -# 自动混合精度训练 (AMP) - -作者: Chuanrui Wang, Shenggui Li, Yongbin Li - -**前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) - -**示例代码** -- [ColossalAI-Examples AMP](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) - -**相关论文** -- [Accelerating Scientific Computations with Mixed Precision Algorithms](https://arxiv.org/abs/0808.2794) - - -## 引言 - -AMP 代表自动混合精度训练。 -在 Colossal-AI 中, 我们结合了混合精度训练的不同实现: - -1. torch.cuda.amp -2. apex.amp -3. naive amp - - -| Colossal-AI | 支持张量并行 | 支持流水并行 | fp16范围 | -| ----------- | ----------------------- | ------------------------- | ----------- | -| AMP_TYPE.TORCH | ✅ | ❌ | 在前向和反向传播期间,模型参数、激活和梯度向下转换至fp16 | -| AMP_TYPE.APEX | ❌ | ❌ | 更细粒度,我们可以选择 opt_level O0, O1, O2, O3 | -| AMP_TYPE.NAIVE | ✅ | ✅ | 模型参数、前向和反向操作,全都向下转换至fp16 | - -前两个依赖于 PyTorch (1.6及以上) 和 NVIDIA Apex 的原始实现。最后一种方法类似 Apex O2。在这些方法中,Apex-AMP 与张量并行不兼容。这是因为张量是以张量并行的方式在设备之间拆分的,因此,需要在不同的进程之间进行通信,以检查整个模型权重中是否出现inf或nan。我们修改了torch amp实现,使其现在与张量并行兼容。 - -> ❌️ fp16与ZeRO配置不兼容 -> -> ⚠️ 流水并行目前仅支持naive amp - -我们建议使用 torch AMP,因为在不使用流水并行时,它通常比 NVIDIA AMP 提供更好的准确性。 - -## 目录 - -在本教程中,我们将介绍: - -1. AMP 介绍 -2. Colossal-AI 中的 AMP -3. 练习实例 - -## AMP 介绍 - -自动混合精度训练是混合 FP16 和 FP32 训练。 - -半精度浮点格式(FP16)具有较低的算法复杂度和较高的计算效率。此外,FP16 仅需要 FP32 所需的一半存储空间,并节省了内存和网络带宽,从而为大 batch size 和大模型提供了更多内存。 - -然而,还有其他操作,如缩减,需要 FP32 的动态范围,以避免数值溢出/下溢。因此,我们引入自动混合精度,尝试将每个操作与其相应的数据类型相匹配,这可以减少内存占用并提高训练效率。 - -
              - -
              AMP 示意图 (图片来自 PatrickStar 论文)
              -
              - -## Colossal-AI 中的 AMP - -我们支持三种 AMP 训练方法,并允许用户在没有改变代码的情况下使用 AMP 进行训练。只需在配置文件中添加'fp16'配置即可使用 AMP。 - -```python -from colossalai.amp import AMP_TYPE - -# 使用 Torch AMP -fp16=dict( - mode = AMP_TYPE.TORCH -) - -# 使用 naive AMP -fp16=dict( - mode = AMP_TYPE.NAIVE -) - -# 使用 Nvidia Apex AMP -fp16=dict( - mode = AMP_TYPE.APEX -) - -``` - -> 这些是最低配置,完整配置将在后面的部分中说明 - -### AMP 模块化 - -AMP 模块设计为完全模块化,可以独立使用。如果你想在你的代码库中只使用 AMP 而不使用`colossalai.initialize`,你可以导入`colossalai.amp.convert_to_amp`。 - -```python -from colossalai.amp import AMP_TYPE - -# 使用torch amp的例子 -model, optimizer, criterion = colossalai.amp.convert_to_amp(model, - optimizer, - criterion, - AMP_TYPE.TORCH) -``` - -### Torch AMP 配置 - -```python -from colossalai.amp import AMP_TYPE - -fp16=dict( - mode=AMP_TYPE.TORCH, - - # 下列是grad scaler的默认值 - init_scale=2.**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True -) -``` - -可选参数: -- init_scale(float, optional, default=2.**16): 初始缩放因子; -- growth_factor(float, optional, default=2.0): 如果在``growth_interval``连续迭代过程中没有出现 inf/NaN 梯度,则在`update`中乘以比例系数; -- backoff_factor(float, optional, default=0.5): 如果在迭代中出现 inf/NaN 梯度,则在`update`中乘以比例系数; -- growth_interval(int, optional, default=2000): 在指定次数的连续迭代中,若没有出现 inf/NaN 梯度,则乘以``growth_factor``. -- enabled(bool, optional, default=True): ``False``则使梯度缩放无效,`step` 仅调用底层的 ``optimizer.step()``, 其他方法成为空操作。 - -### Apex AMP 配置 - -对于这种模式,我们依靠 Apex 实现混合精度训练。我们支持这个插件,因为它允许对混合精度的粒度进行更精细的控制。 -例如, O2 水平 (优化器水平2) 将保持 batch normalization 为 FP32。 - -如果你想了解更多细节,请参考 [Apex Documentation](https://nvidia.github.io/apex/)。 - -```python -from colossalai.amp import AMP_TYPE - -fp16 = dict( - mode=AMP_TYPE.APEX, - - # 下列是默认值 - enabled=True, - opt_level='O1', - cast_model_type=None, - patch_torch_functions=None, - keep_batchnorm_fp32=None, - master_weights=None, - loss_scale=None, - cast_model_outputs=None, - num_losses=1, - verbosity=1, - min_loss_scale=None, - max_loss_scale=16777216.0 -) -``` - -参数: -- enabled(bool, optional, default=True): False 会使所有 AMP 调用成为空操作, 程序将会像没有使用 AMP 一样运行。 - -- opt_level(str, optional, default="O1" ): 纯精度或混合精度优化水平。可选值 “O0”, “O1”, “O2”, and “O3”, 详细解释见上方 Apex AMP 文档。 - -- num_losses(int, optional, default=1): 选择提前告知 AMP 您计划使用多少次损失/反向计算。 -当`amp.scale_loss`与 loss_id 参数一起使用时,使 AMP 在每次损失/反向计算时使用不同的损失比例,这可以提高稳定性。如果 num_losses 被设置为1,AMP 仍支持多次损失/反向计算,但对他们都使用同一个全局损失比例。 - -- verbosity(int, default=1): 设置为0抑制 AMP 相关输出。 - -- min_loss_scale(float, default=None): 为可通过动态损耗比例选择的损耗比例值设置下限。 -默认值“None”意味着不设置任何下限。如果不使用动态损耗比例,则忽略 min_loss_scale 。 - -- max_loss_scale(float, default=2.**24 ): 为可通过动态损耗比例选择的损耗比例值设置上限。如果不使用动态损耗比例,则 max_loss_scale 被忽略. - -目前,管理纯精度或混合精度训练的幕后属性有以下几种: -cast_model_type, patch_torch_functions, keep_batchnorm_fp32, master_weights, loss_scale. -一旦 opt_level 被确定,它们是可选的可覆盖属性 - -- cast_model_type: 将模型的参数和缓冲区强制转换为所需的类型。 -- patch_torch_functions: 补全所有的 Torch 函数和张量方法,以便在FP16中执行张量核心友好的操作,如 GEMMs 和卷积,以及在 FP32 中执行任何受益于 FP32 精度的操作。 -- keep_batchnorm_fp32: 为了提高精度并启用 cudnn batchnorm (这会提高性能),在 FP32 中保留 batchnorm 权重通常是有益的,即使模型的其余部分是 FP16。 -- master_weights: 保持 FP32 主权重以配合任何 FP16 模型权重。 FP32 主权重由优化器分级,以提高精度和捕捉小梯度。 -- loss_scale: 如果 loss_scale 是一个浮点数,则使用这个值作为静态(固定)的损失比例。如果 loss_scale 是字符串 "dynamic",则随着时间的推移自适应地调整损失比例。动态损失比例调整由 AMP 自动执行。 - - -### Naive AMP 配置 - -在 Naive AMP 模式中, 我们实现了混合精度训练,同时保持了与复杂张量和流水并行的兼容性。该 AMP 模式将所有操作转为 FP16 。下列代码块展示了该模式的`config.py`。 - -```python -from colossalai.amp import AMP_TYPE - -fp16 = dict( - mode=AMP_TYPE.NAIVE, - - # below are the default values - log_num_zeros_in_grad=False, - initial_scale=2 ** 32, - min_scale=1, - growth_factor=2, - backoff_factor=0.5, - growth_interval=1000, - hysteresis=2 -) -``` - -Naive AMP 的默认参数: -- log_num_zeros_in_grad(bool): 返回0值梯度的个数. -- initial_scale(int): gradient scaler 的初始值 -- growth_factor(int): loss scale 的增长率 -- backoff_factor(float): loss scale 的下降率 -- hysterisis(int): 动态 loss scaling 的延迟偏移 -- max_scale(int): loss scale 的最大允许值 -- verbose(bool): 如果被设为`True`,将打印调试信息 - -当使用`colossalai.initialize`时, 首先需要实例化一个模型、一个优化器和一个标准。将输出模型转换为内存消耗较小的 AMP 模型。如果您的输入模型已经太大,无法放置在 GPU 中,请使用`dtype=torch.float16`实例化你的模型。或者请尝试更小的模型,或尝试更多的并行化训练技术! - -## 实例 - -我们提供了一个 [运行实例](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) -展现如何在 Colossal-AI 使用 AMP。在该例程中,我们使用 Torch AMP, 但提供的配置文件也适用于所有 AMP 模式. - -### 步骤 1. 创建配置文件 - -创建一个`config.py`文件并添加`fp16`配置. - -```python -# in config.py -from colossalai.amp import AMP_TYPE - -BATCH_SIZE = 128 -DROP_RATE = 0.1 -NUM_EPOCHS = 300 - -fp16 = dict( - mode=AMP_TYPE.TORCH, -) - -clip_grad_norm = 1.0 -``` - -### 步骤 2. 在 train_with_engine.py 导入相关库 - -创建`train_with_engine.py`并导入必要依赖. 请记得通过命令`pip install timm scipy`安装`scipy`和`timm`。 - -```python -import os -import colossalai -import torch -from pathlib import Path -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.utils import get_dataloader -from colossalai.trainer import Trainer, hooks -from colossalai.nn.lr_scheduler import LinearWarmupLR -from timm.models import vit_base_patch16_224 -from torchvision import datasets, transforms - -``` - -### 步骤 3. 初始化分布式环境 - -我们需要初始化分布式环境。为了快速演示,我们使用`launch_from_torch`。你可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md) -使用其他初始化方法。 - -```python -# 初始化分布式设置 -parser = colossalai.get_default_parser() -args = parser.parse_args() - -# launch from torch -colossalai.launch_from_torch(config=args.config) - -``` - -### 步骤 4. 创建训练组件 - -构建你的模型、优化器、损失函数、学习率调整器和数据加载器。注意数据集的路径从环境变量`DATA`获得。你可以通过 `export DATA=/path/to/data` 或 `Path(os.environ['DATA'])` -在你的机器上设置路径。数据将会被自动下载到该路径。 - -```python -# build model - model = vit_base_patch16_224(drop_rate=0.1) - - # build dataloader - train_dataset = datasets.Caltech101( - root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.Resize(256), - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - Gray2RGB(), - transforms.Normalize([0.5, 0.5, 0.5], - [0.5, 0.5, 0.5]) - ])) - - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=gpc.config.BATCH_SIZE, - num_workers=1, - pin_memory=True, - ) - - # build optimizer - optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=0.1) - - # build loss - criterion = torch.nn.CrossEntropyLoss() - - # lr_scheduelr - lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS) -``` - -### 步骤 5. 插入 AMP - -调用 `colossalai.initialize` 将所有训练组件转为为FP16模式. - -```python -engine, train_dataloader, _, _ = colossalai.initialize( - model, optimizer, criterion, train_dataloader, - ) -``` - -### 步骤 6. 使用 Engine 训练 - -使用Engine构建一个普通的训练循环 - -```python -engine.train() -for epoch in range(gpc.config.NUM_EPOCHS): - for img, label in enumerate(train_dataloader): - img = img.cuda() - label = label.cuda() - engine.zero_grad() - output = engine(img) - loss = engine.criterion(output, label) - engine.backward(loss) - engine.step() - lr_scheduler.step() -``` - -### 步骤 7. 启动训练脚本 - -使用下列命令启动训练脚本,你可以改变 `--nproc_per_node` 以使用不同数量的 GPU。 - -```python -python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py --config config/config_AMP_torch.py -``` diff --git a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md new file mode 100644 index 0000000000000000000000000000000000000000..8e9f614a25af032a379a888a21b4664eb4833162 --- /dev/null +++ b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md @@ -0,0 +1,245 @@ +# 自动混合精度训练 + +作者: [Mingyan Jiang](https://github.com/jiangmingyan) + +**前置教程** + +- [booster 使用](../basics/booster_api.md) + +**相关论文** + +- [Accelerating Scientific Computations with Mixed Precision Algorithms](https://arxiv.org/abs/0808.2794) + +## 引言 + +AMP 代表自动混合精度训练。 +在 Colossal-AI 中, 我们结合了混合精度训练的不同实现: + +1. torch.cuda.amp +2. apex.amp +3. naive amp + +| Colossal-AI | 支持张量并行 | 支持流水并行 | fp16 范围 | +| -------------- | ------------ | ------------ | --------------------------------------------------------- | +| AMP_TYPE.TORCH | ✅ | ❌ | 在前向和反向传播期间,模型参数、激活和梯度向下转换至 fp16 | +| AMP_TYPE.APEX | ❌ | ❌ | 更细粒度,我们可以选择 opt_level O0, O1, O2, O3 | +| AMP_TYPE.NAIVE | ✅ | ✅ | 模型参数、前向和反向操作,全都向下转换至 fp16 | + +前两个依赖于 PyTorch (1.6 及以上) 和 NVIDIA Apex 的原始实现。最后一种方法类似 Apex O2。在这些方法中,Apex-AMP 与张量并行不兼容。这是因为张量是以张量并行的方式在设备之间拆分的,因此,需要在不同的进程之间进行通信,以检查整个模型权重中是否出现 inf 或 nan。我们修改了 torch amp 实现,使其现在与张量并行兼容。 + +> ❌️ fp16 与 ZeRO 不兼容 +> +> ⚠️ 流水并行目前仅支持 naive amp + +我们建议使用 torch AMP,因为在不使用流水并行时,它通常比 NVIDIA AMP 提供更好的准确性。 + +## 目录 + +在本教程中,我们将介绍: + +1. [AMP 介绍](#amp-介绍) +2. [Colossal-AI 中的 AMP](#colossal-ai-中的-amp) +3. [练习实例](#实例) + +## AMP 介绍 + +自动混合精度训练是混合 FP16 和 FP32 训练。 + +半精度浮点格式(FP16)具有较低的算法复杂度和较高的计算效率。此外,FP16 仅需要 FP32 所需的一半存储空间,并节省了内存和网络带宽,从而为大 batch size 和大模型提供了更多内存。 + +然而,还有其他操作,如缩减,需要 FP32 的动态范围,以避免数值溢出/下溢。因此,我们引入自动混合精度,尝试将每个操作与其相应的数据类型相匹配,这可以减少内存占用并提高训练效率。 + +
              + +
              AMP 示意图 (图片来自 PatrickStar 论文)
              +
              + +## Colossal-AI 中的 AMP + +我们支持三种 AMP 训练方法,并允许用户在没有改变代码的情况下使用 AMP 进行训练。booster 支持 amp 特性注入,如果您要使用混合精度训练,则在创建 booster 实例时指定`mixed_precision`参数;后续将会拓展`bf16`,`pf8`的混合精度训练. + +#### booster 启动方式 + +您可以在创建 booster 实例时,指定`mixed_precision="fp16"`即使用 torch amp。 + + + +```python +""" + 初始化映射关系如下: + 'fp16': torch amp + 'fp16_apex': apex amp, + 'bf16': bf16, + 'fp8': fp8, + 'fp16_naive': naive amp +""" +from colossalai import Booster +booster = Booster(mixed_precision='fp16',...) +``` + + + +或者您可以自定义一个`FP16TorchMixedPrecision`对象,如 + + + +```python +from colossalai.mixed_precision import FP16TorchMixedPrecision +mixed_precision = FP16TorchMixedPrecision( + init_scale=2.**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000) +booster = Booster(mixed_precision=mixed_precision,...) +``` + + + +其他类型的 amp 使用方式也是一样的。 + +### Torch AMP 配置 + +{{ autodoc:colossalai.booster.mixed_precision.FP16TorchMixedPrecision }} + +### Apex AMP 配置 + +对于这种模式,我们依靠 Apex 实现混合精度训练。我们支持这个插件,因为它允许对混合精度的粒度进行更精细的控制。 +例如, O2 水平 (优化器水平 2) 将保持 batch normalization 为 FP32。 + +如果你想了解更多细节,请参考 [Apex Documentation](https://nvidia.github.io/apex/)。 + +{{ autodoc:colossalai.booster.mixed_precision.FP16ApexMixedPrecision }} + +### Naive AMP 配置 + +在 Naive AMP 模式中, 我们实现了混合精度训练,同时保持了与复杂张量和流水并行的兼容性。该 AMP 模式将所有操作转为 FP16 。下列代码块展示了该模式的 booster 启动方式。 + +{{ autodoc:colossalai.booster.mixed_precision.FP16NaiveMixedPrecision }} + +当使用`colossalai.booster`时, 首先需要实例化一个模型、一个优化器和一个标准。将输出模型转换为内存消耗较小的 AMP 模型。如果您的输入模型已经太大,无法放置在 GPU 中,请使用`dtype=torch.float16`实例化你的模型。或者请尝试更小的模型,或尝试更多的并行化训练技术! + +## 实例 + +下面我们将展现如何在 Colossal-AI 使用 AMP。在该例程中,我们使用 Torch AMP. + +### 步骤 1. 在 train.py 导入相关库 + +创建`train.py`并导入必要依赖. 请记得通过命令`pip install timm scipy`安装`scipy`和`timm`。 + +```python +import os +from pathlib import Path + +import torch +from timm.models import vit_base_patch16_224 +from titans.utils import barrier_context +from torchvision import datasets, transforms + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import LinearWarmupLR +``` + +### 步骤 2. 初始化分布式环境 + +我们需要初始化分布式环境。为了快速演示,我们使用`launch_from_torch`。你可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md) +使用其他初始化方法。 + +```python +# 初始化分布式设置 +parser = colossalai.get_default_parser() +args = parser.parse_args() + +# launch from torch +colossalai.launch_from_torch(config=dict()) + +``` + +### 步骤 3. 创建训练组件 + +构建你的模型、优化器、损失函数、学习率调整器和数据加载器。注意数据集的路径从环境变量`DATA`获得。你可以通过 `export DATA=/path/to/data` 或 `Path(os.environ['DATA'])` +在你的机器上设置路径。数据将会被自动下载到该路径。 + +```python +# define the constants +NUM_EPOCHS = 2 +BATCH_SIZE = 128 +# build model +model = vit_base_patch16_224(drop_rate=0.1) + +# build dataloader +train_dataset = datasets.Caltech101( + root=Path(os.environ['DATA']), + download=True, + transform=transforms.Compose([ + transforms.Resize(256), + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + Gray2RGB(), + transforms.Normalize([0.5, 0.5, 0.5], + [0.5, 0.5, 0.5]) + ])) + +# build optimizer +optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=0.1) + +# build loss +criterion = torch.nn.CrossEntropyLoss() + +# lr_scheduler +lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=NUM_EPOCHS) +``` + +### 步骤 4. 插入 AMP + +创建一个 MixedPrecision 对象(如果需要)及 torchDDPPlugin 对象,调用 `colossalai.boost` 将所有训练组件转为为 FP16 模式. + +```python +plugin = TorchDDPPlugin() +train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) +booster = Booster(mixed_precision='fp16', plugin=plugin) + +# if you need to customize the config, do like this +# >>> from colossalai.mixed_precision import FP16TorchMixedPrecision +# >>> mixed_precision = FP16TorchMixedPrecision( +# >>> init_scale=2.**16, +# >>> growth_factor=2.0, +# >>> backoff_factor=0.5, +# >>> growth_interval=2000) +# >>> plugin = TorchDDPPlugin() +# >>> booster = Booster(mixed_precision=mixed_precision, plugin=plugin) + +# boost model, optimizer, criterion, dataloader, lr_scheduler +model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler) +``` + +### 步骤 5. 使用 booster 训练 + +使用 booster 构建一个普通的训练循环。 + +```python +model.train() +for epoch in range(NUM_EPOCHS): + for img, label in enumerate(train_dataloader): + img = img.cuda() + label = label.cuda() + optimizer.zero_grad() + output = model(img) + loss = criterion(output, label) + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() +``` + +### 步骤 6. 启动训练脚本 + +使用下列命令启动训练脚本,你可以改变 `--nproc_per_node` 以使用不同数量的 GPU。 + +```shell +colossalai run --nproc_per_node 1 train.py +``` + + diff --git a/docs/source/zh-Hans/features/nvme_offload.md b/docs/source/zh-Hans/features/nvme_offload.md index fd75ed1f5b3ecb10b51846c73f6afb2d5a1234c2..1feb9dde572593761fd7780b590d3fe181892964 100644 --- a/docs/source/zh-Hans/features/nvme_offload.md +++ b/docs/source/zh-Hans/features/nvme_offload.md @@ -53,9 +53,8 @@ optimizer = HybridAdam(model.parameters(), lr=1e-3, nvme_offload_fraction=1.0, n > ⚠ 它只会卸载在 CPU 上的优化器状态。这意味着它只会影响 CPU 训练或者使用卸载的 Zero/Gemini。 -## Exampls +## Examples -Let's start from two simple examples -- training GPT with different methods. These examples relies on `transformers`. 首先让我们从两个简单的例子开始 -- 用不同的方法训练 GPT。这些例子依赖`transformers`。 我们首先应该安装依赖: @@ -77,8 +76,9 @@ from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.zero import zero_model_wrapper, zero_optim_wrapper from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin ``` 然后我们定义一个损失函数: @@ -182,16 +182,24 @@ def train_gemini_cpu(nvme_offload_fraction: float = 0.0): criterion = GPTLMLoss() optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction) print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B') - gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(), - placement_policy='cpu', pin_memory=True, hidden_dim=config.n_embd) - model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config) - optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**5) + + plugin = GeminiPlugin( + strict_ddp_mode=True, + device=torch.cuda.current_device(), + placement_policy='cpu', + pin_memory=True, + hidden_dim=config.n_embd, + initial_scale=2**5 + ) + booster = Booster(plugin) + model, optimizer, criterion, _* = booster.boost(model, optimizer, criterion) + start = time.time() for step in range(3): data = get_data(4, 128, config.vocab_size) outputs = model(**data) loss = criterion(outputs.logits, data['input_ids']) - optimizer.backward(loss) + booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() print(f'[{step}] loss: {loss.item():.3f}') diff --git a/docs/source/zh-Hans/features/pipeline_parallel.md b/docs/source/zh-Hans/features/pipeline_parallel.md index 98096b1d7f9378bf178c6da9a5febfdfae67efb3..e688020556d8af1d3bebdba9eabdf8f7bb3e6c62 100644 --- a/docs/source/zh-Hans/features/pipeline_parallel.md +++ b/docs/source/zh-Hans/features/pipeline_parallel.md @@ -1,14 +1,15 @@ # 流水并行 -作者: Guangyang Lu, Hongxin Liu, Yongbin Li +作者: Guangyang Lu, Hongxin Liu, Yongbin Li, Mingyan Jiang **前置教程** -- [定义配置文件](../basics/define_your_config.md) -- [在训练中使用Engine和Trainer](../basics/engine_trainer.md) -- [并行配置](../basics/configure_parallelization.md) +- [并行技术](../concepts/paradigms_of_parallelism.md) +- [Booster API](../basics/booster_api.md) +- [Shardformer](../features/shardformer.md) +- [Booster 插件](../basics/booster_plugins.md) **示例代码** -- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel) +- [使用pipeline并行策略微调Bert](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py) **相关论文** - [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883) @@ -17,7 +18,7 @@ ## 快速预览 -在本教程中,你将学习如何使用流水并行。在 Colossal-AI 中, 我们使用 NVIDIA 推出的 1F1B 流水线。由于在本例中, 使用 ViT 和 ImageNet 太过庞大,因此我们使用 ResNet 和 CIFAR 为例. +在本教程中,你将学习如何使用流水并行。在 Colossal-AI 中, 我们使用 NVIDIA 推出的 1F1B 流水线。由于在本例中, 使用 ViT 和 ImageNet 太过庞大,因此我们使用 Bert 和 Glue数据集 为例. ## 目录 @@ -25,7 +26,7 @@ 1. 介绍 1F1B 流水线; 2. 使用非交错和交错 schedule; -3. 使用流水线训练 ResNet。 +3. 使用流水线微调 Bert ## 认识 1F1B 流水线 @@ -59,100 +60,154 @@ 这种模式既节省内存又节省时间。 -## 使用schedule +## Colossal-AI中的实现 -在 Colossal-AI 中, 我们提供非交错(`PipelineSchedule`) 和交错(`InterleavedPipelineSchedule`)schedule。 +在 Colossal-AI 中,流水线并行依赖于 `scheduler` 和 `Shardformer`。我们提供了非交错的(`OneForwardOneBackwardSchedule`)和交错的(`InterleavedSchedule`)两种调度方式。而 Shardformer 实现了对模型的层分割,并替换了模型的 `forward` 函数,使其与调度器兼容。 -你只需要在配置文件中,设置 `NUM_MICRO_BATCHES` 并在你想使用交错schedule的时候,设置 `NUM_CHUNKS`。 如果你确定性地知道每个管道阶段的输出张量的形状,而且形状都是一样的,你可以设置 `tensor_shape` 以进一步减少通信。否则,你可以忽略 `tensor_shape` , 形状将在管道阶段之间自动交换。 我们将会根据用户提供的配置文件,生成一个合适schedule来支持用户的流水并行训练。 +在 Colossal-AI 中,`HybridParallelPlugin` 封装了流水线执行策略。它管理流水线并行通信组和一个 `scheduler`。当使用此插件增强模型时,模型的层将通过调用 `shardformer.optimize` 函数进行分割,然后调用 `execute_pipeline` 使用 `scheduler` 来分别执行模型的各个部分。 `HybridParallelPlugin`暂时只支持`OneForwardOneBackwardSchedule`, `InterleavedSchedule`将会在不久后支持。 -## 使用流水线训练 ResNet +您可以通过设置 `HybridParallelPlugin` 的参数来自定义您的并行策略。更多使用细节请参考`HybridParallelPlugin`的[使用文档](../basics/booster_plugins.md)。 -我们首先用Colossal PipelinableContext方式建立 `ResNet` 模型: +## 使用流水线微调 Bert模型 + +首先我们定义好需要的训练组件,包括`model`, `dataloader`, `optimizer`, `lr_scheduler`, `criterion` 等: ```python -import os -from typing import Callable, List, Optional, Type, Union +import argparse +from typing import Callable, List, Union + import torch import torch.nn as nn +from data import GLUEDataBuilder +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ( + AlbertForSequenceClassification, + AutoConfig, + BertForSequenceClassification, + get_linear_schedule_with_warmup, +) + import colossalai -import colossalai.nn as col_nn +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.trainer import Trainer, hooks -from colossalai.utils import MultiTimer, get_dataloader -from colossalai.context import ParallelMode -from colossalai.pipeline.pipelinable import PipelinableContext +# Define some config +NUM_EPOCHS = 3 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 + +coordinator = DistCoordinator() + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + +# Define 'criterion' function with two inputs, which will be passed to 'execute_pipeline'. +def _criterion(outputs, inputs): + return outputs.loss + +# Define optimizer +lr = LEARNING_RATE +no_decay = ["bias", "LayerNorm.weight"] +optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, +] -from titans.dataloader.cifar10 import build_cifar -from torchvision.models import resnet50 -from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 +optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) -# Define some config -BATCH_SIZE = 64 -NUM_EPOCHS = 2 -NUM_CHUNKS = 1 -CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2)) - -# Train -disable_existing_loggers() -parser = colossalai.get_default_parser() -args = parser.parse_args() -colossalai.launch_from_torch(backend=args.backend, config=CONFIG) -logger = get_dist_logger() -pipelinable = PipelinableContext() - -# build model -with pipelinable: - model = resnet50() + +# Define lr_scheduler +total_steps = len(train_dataloader) * NUM_EPOCHS +num_warmup_steps = int(WARMUP_FRACTION * total_steps) +lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, +) + + +# Define Bert model +model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=cfg).cuda() + +# Define a dataloader +data_builder = GLUEDataBuilder(model_name, + plugin, + args.task, + train_batch_size=BATCH_SIZE, + eval_batch_size=BATCH_SIZE) +train_dataloader = data_builder.train_dataloader() ``` -给定切分顺序,module直接给出name,部分函数需要手动添加。 +使用`HybridParallelPlugin`初始化一个booster. ```python -exec_seq = [ - 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', - (lambda x: torch.flatten(x, 1), "behind"), 'fc' -] -pipelinable.to_layer_list(exec_seq) +plugin = HybridParallelPlugin(tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision='fp16', + initial_scale=1) +booster = Booster(plugin=plugin) ``` -将模型切分成流水线阶段。 +使用`booster`将优化特性注入到训练组件中。 ```python -model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) +model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, + optimizer, + criterion=_criterion, + lr_scheduler=lr_scheduler) ``` -我们使用`Trainer`训练`ResNet`: +最后训练模型 ```python -# build criterion -criterion = nn.CrossEntropyLoss() - -# optimizer -optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - -# build dataloader -root = os.environ.get('DATA', './data') -train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32) - -lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1) -engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion, - train_dataloader, test_dataloader, - lr_scheduler) -timer = MultiTimer() - -trainer = Trainer(engine=engine, timer=timer, logger=logger) - -hook_list = [ - hooks.LossHook(), - hooks.AccuracyHook(col_nn.metric.Accuracy()), - hooks.LogMetricByEpochHook(logger), - hooks.LRSchedulerHook(lr_scheduler, by_epoch=True) -] - -trainer.fit(train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True) +# Define a train function +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, + train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): + + is_pp_last_stage = booster.plugin.stage_manager.is_last_stage() + total_step = len(train_dataloader) + + model.train() + optimizer.zero_grad() + # convert train_dataloader to a iterator + train_dataloader_iter = iter(train_dataloader) + with tqdm(range(total_step), + desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', + disable=not (is_pp_last_stage)) as pbar: + # Forward pass + for _ in pbar: + outputs = booster.execute_pipeline(train_dataloader_iter, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + # Backward and optimize + if is_pp_last_stage: + loss = outputs['loss'] + pbar.set_postfix({'loss': loss.item()}) + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + +# Train model +for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` -我们使用 `2` 个流水段,并且 batch 将被切分为 `4` 个 micro batches。 +我们使用 `2` 个流水段,并且 batch 将被切分为 `1` 个 micro batches。(这些参数都可根据实际情况设置为合适的值) + diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md new file mode 100644 index 0000000000000000000000000000000000000000..99752a1ce4e08b633cba1b735e72bb0721601001 --- /dev/null +++ b/docs/source/zh-Hans/features/shardformer.md @@ -0,0 +1,333 @@ +# Shardformer + +Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.com/FoolPlayer) + +**预备知识** +- [并行技术](../concepts/paradigms_of_parallelism.md) +- [Booster API](../basics/booster_api.md) +- [Booster 插件](../basics/booster_plugins.md) + +**示例代码** +- [使用Shardformer进行张量并行训练](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/shardformer/examples) +- [通过HybridParallelPlugin使用Shardformer](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) + +**相关论文** +- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) +- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965) +- [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691) +- [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120) +- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198) + + +## 简介 + +在训练LLaMa-2 70B或OPT 175B等大型Transformer模型时,为了满足GPU内存的限制,将大型模型划分为更小的分片的模型并行方法(包括张量并行以及流水线并行)是必不可少的。然而,对于不熟悉分布式训练的用户来说,手动剪切模型并重写其前向/反向逻辑可能很困难。与此同时,Huggingface transformers开源库正在逐渐成为用户模型来源的首选,大部分主流大型模型都已在Huggingface transformers模型库中开源。 + +出于这种动机,ColossalAI团队开发了**Shardformer**,该功能可以自动为HuggingFace中主流的Transformer模型进行封装,用于张量并行以及流水线并行的训练策略。如此一来,对系统了解不多的用户也可以轻松地在transformers模型上进行并行训练:只需几行代码,用户就可以将模型转变为并行训练的状态。此外,Shardformer也包括了多种优化工具,用于在前向/后向的传递过程中实现加速和节省内存。 + +## 支持信息 + +模型/功能 兼容性矩阵: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
              Model/FeatureTensor
              Parallel
              Pipeline
              Parallel
              Lazy
              Initialization
              xFormersFlash
              Attention 2
              JIT Fused
              Operators
              Fused
              LayerNorm
              Sequence
              Parallel
              Sequence
              Overlap
              Llama V1/V2✔️✔️✔️✔️✔️✔️✔️
              OPT✔️✔️✔️✔️✔️✔️✔️
              BLOOM✔️✔️✔️✔️✔️✔️✔️✔️✔️
              ChatGLM 2✔️✔️✔️✔️✔️✔️✔️✔️✔️
              BERT✔️✔️✔️✔️✔️✔️✔️✔️✔️
              GPT 2✔️✔️✔️✔️✔️✔️✔️✔️✔️
              T5✔️✔️✔️✔️✔️✔️✔️
              ViT✔️✔️✔️✔️✔️✔️
              Whisper✔️✔️✔️✔️✔️✔️
              SAM✔️✔️✔️✔️✔️
              Blip2✔️✔️✔️✔️✔️
              + +我们计划在不久后为Shardformer支持的模型: +- RoBERTa +- ALBERT +- ERNIE +- GPT Neo +- GPT-J +- BEiT +- SwinTransformer V1/V2 +- qwen + +随着未来更多模型和优化工具的出现,我们支持的模型/优化工具将会变得越来越多。如果您对我们应该支持的模型/优化工具有任何建议,欢迎在项目的[Issues](https://github.com/hpcaitech/ColossalAI/issues)板块参与讨论。 + +## 用法 + +### Shardformer的参数配置 + +Shardformer的配置由类`ShardConfig`的参数控制: + +{{ autodoc:colossalai.shardformer.ShardConfig }} + +如果您想启用 Apex Fused Layernorm,请安装 `apex`。如果您想启用 flash attention,请安装 `flash_attn`。此外,xFormers 的 `cutlass_op` 可以作为Flash Attention的补充优化方式。 + +### 启动Shardformer + +#### 1. 通过Booster启动Shardformer (推荐) + +通过用`HybridParallelPlugin`初始化的`Booster`来启动`Shardformer`是最推荐的用法。其主要原因是如果不调用`Booster`的`execute_pipeline`方法,流水线并行就无法正常工作。此外,`HybridParallelPlugin`提供了将`Shardformer`的功能与其他功能(例如混合精度训练或Zero)相结合的能力。 + +[这里](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert)是一个通过`HybridParallelPlugin`启动`Shardformer`的示例。 +移动到示例的根目录下,执行命令: +```bash +torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert" +``` +你便可以微调一个被`Shardformer`封装过的Bert模型,而封装的操作是由`HybridParallelPlugin`完成的。 + +接下来一起深入挖掘一下`finetune.py`里的代码: + +在`main`函数中,混合并行的插件通过以下的代码创建 +```python +... +elif args.plugin == "hybrid_parallel": + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, + ) +``` +在这里你可以通过设置不同的`tp_size`, `pp_size` 或 `zero_stage`来改变插件的配置。更多关于插件配置的信息可以在[Booster 插件文档](../basics/booster_plugins.md)中被找到。 + +当流水并行不被启用的时候,训练的流程和其他的插件是一样的 (先用Booster封装模型和优化器,再用正常的方式做前向和后向传递)。然而,当流水线并行被启用的时候,有几处不同于寻常情况的用法: + +1. 在进行前向和后向之前,criterion函数(loss函数)需要被处理以满足流水线并行的传参要求: + ```python + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + ``` + +2. 在 `train_epoch` 函数中, dataloader 在进行流水线的前向后向操作之前需要被转换为 `Iterator` 类: + ```python + train_dataloader_iter = iter(train_dataloader) + ``` + +3. 通过调用`Booster.execute_pipeline` 方法来执行前向和后向传递: + ```python + outputs = booster.execute_pipeline( + train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) + ``` + 该方法会自动执行后向传递,所以在执行该方法后不需要再调用 `loss.backward()`方法。 + 更多关于 `Booster.execute_pipeline` 的信息可以参考 [Booster API 文档](../basics/booster_api.md)。 + +#### 2. 通过Shardformer API启动Shardformer (不推荐) + +您还可以通过手动调用Shardformer API的方式启动Shardformer。然而我们并不推荐这种用法,因为流水线并行在没有`Booster`的情况下无法正常运行。 + +[这里](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py) +是一个通过调用Shardformer的API启动`Shardformer`的示例。 +在示例代码的`train`函数中,模型被以下的几行代码进行封装: +```python +... +if dist.get_world_size() > 1: + tp_group = dist.new_group(backend="nccl") + + # First create configuration for Shardformer + shard_config = ShardConfig( + tensor_parallel_process_group=tp_group, + enable_tensor_parallelism=True, + enable_all_optimization=True + ) + + # Then create ShardFormer object with created config + shard_former = ShardFormer(shard_config=shard_config) + + # Finally shard the model using ShardFormer.optimize method + model, _ = shard_former.optimize(model) +... +``` + +### 注意事项 + +1. 当启用流水线并行时,请不要用常规方式(`model(input)`、`loss.backward()`)进行前向/后向传递,这样会导致未知的错误。这种情形下请通过调用`booster.execute_pipeline`方法来进行前向/后向传递。 + +2. 当使用Shardformer处理`GPT2ForSequenceClassification`、`ViTForImageClassification`等分类模型时,请确保labels的总数为张量并行度的整数倍,否则Shardformer无法正确地处理classifier层。一个简单的修复方法就是在transformers的config中添加虚拟的标签。这一bug将在 Shardformer的未来版本中修复。 + +3. 训练ChatGLM-2 6B的情况有点特殊:由于Huggingface Transformers 目前尚未正式支持ChatGLM。在使用Shardformer训练ChatGLM-2时,请通过以下方式导入config/model的类: + ```python + from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig + from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + ``` + 并且使用这些导入的类初始化模型。 + + +## Shardformer的工作原理 + +### 设计思想 + +通常来说,Shardformer通过以下四种“替换”进行工作: + +1. 用我们设计的分布式模块替换原始的PyTorch模块(例如`nn.Linear`、`nn.Embedding`)。 +分布式模块保持与原始模块相同的属性,但分布式模块会用新的参数替换原始模块的参数。新的前向函数将取代原来的前向函数,用于执行分布式计算,例如在张量并行下执行线性层的split/gather操作。每个分布式模块都应当实现其`from_native_module`静态方法,以将PyTorch模块转换为其相应的分布式模块。 + +2. 将原始Huggingface Transformers中间层的属性为适用于并行训练的属性。例如,当使用并行度为2的张量并行训练LlaMa-2时,`LlamaDecoderLayer` 的属性`num_heads`(每一层注意力头的数量)应替换为`model.config.num_attention_heads // 2`。 + +3. 将原来Huggingface transformers库实现的前向函数替换为我们定制的前向函数。前向函数的替换对于流水线并行性至关重要,因为流水线并行需要特殊的前向函数去在不同的流水线阶段之间传递中间的隐藏状态。此外,可以通过我们定制的前向函数将例如`flash attention`或序列并行的优化方法注入到前向的过程中。 + +4. 将完整的模型参数和优化器状态替换为只由当前设备控制的部分模型参数和优化器状态。通过执行`ModelSharder.shard`方法,当前设备仅会保留它应该处理的那部分模型参数。具体来说,这部分参数可以是使用张量并行时分配到当前机器的参数分片,或者使用流水线并行时当前流水线阶段的模型参数,或者兼而有之。除此之外的所有其他参数都被释放,用于节省内存的空间。 +如此一来,优化器只会计算保留的部分参数对应的状态,从而进一步节省内存的使用。 + +所有这些替换都是通过手动编写的策略和前向函数来实现的。如果您想更深入地研究Shardformer的设计方案,或者定制您自己的Shardformer策略,请参考[Shardformer 开发者文档](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/README.md)和[流水并行设计方案](https://github.com/hpcaitech/ColossalAI/discussions/4050)以获得更多细节。 + +### 序列并行 Sequence Parallelism + +序列并行是`Shardformer`支持的一种特殊的优化方法。在`Shardformer`中,序列并行与[此处](https://colossalai.org/docs/basics/configure_parallelization/#sequence-parallel)稍有不同,后者侧重于ring attention。在`Shardformer`中,序列并行仅与1D张量并行一起使用,以进一步减少计算中activation的内存占用。 + +1. 在普通的[1D张量并行](https://colossalai.org/docs/features/1D_tensor_parallel)中,有两个通信操作$g$和$\vec{g}$,$g$在反向传播中进行一次全局归约以获取来自所有设备的梯度,而$\vec{g}$在正向传播中进行一次All-Reduce以获取来自所有设备的输出。 + +2. 当使用序列并行时,$\vec{g}$需要在正向传播过程中进行All-Gather以获取序列维度上的输入,并在反向传播过程中进行Reduce-Scatter以分割梯度。$\vec{g}$需要进行Reduce-Scatter以将序列维度上的行线性层输出分割到所有设备上,并进行All-Gather以获取完整的梯度。 + +3. 使用NCCL的All-reduce实现采用了`Ring All-Reduce`方法,由一次Reduce-Scatter和一次All-Gather组成,两者的开销相等。因此,与序列并行和张量并行相比,它并不会引入额外的通信开销。 + +4. 需要注意的一点是,在张量并行的 `Column Linear` 层中进行序列并行时,梯度的反向计算过程中需要获取完整的输入。在前向传播过程中,仅保留沿序列维度分割的输入部分,张量的形状例如$(batch, sequence\_len/k, hidden\_states)$。因此,需要进行额外的全局收集操作以获取完整的输入进行梯度计算。但是,在实现中,可以将梯度计算与全局收集通信操作重叠,这不会引入额外的通信开销(对应`Shardformer`中的`enable_sequence_overlap`参数)。 + + + diff --git a/docs/source/zh-Hans/features/zero_with_chunk.md b/docs/source/zh-Hans/features/zero_with_chunk.md index 72403bf610a4f9523f88a5d0791417e0daed3bd8..61290628588bdae92c97f4a41a910a8dbe3e5509 100644 --- a/docs/source/zh-Hans/features/zero_with_chunk.md +++ b/docs/source/zh-Hans/features/zero_with_chunk.md @@ -4,7 +4,7 @@ **前置教程:** -- [定义配置文件](../basics/define_your_config.md) +- [booster使用](../basics/booster_api.md) **示例代码** @@ -53,32 +53,37 @@ 我们将运用`GeminiDDP`的方式来使用基于Chunk内存管理的ZeRO。这是我们新包装的torch.Module ,它使用 ZeRO-DP 和 Gemini,其中ZeRO 用于并行,Gemini 用于内存管理。 -同样需要确保你的模型是在 `ColoInitContext` 的上下文中初始化的。 +Gemini支持惰性初始化, 它可以节省多卡初始化大模型时的显存使用. +如果你的模型有 `N` billion 个参数,你的 GPU 内存为 `M` GB, 当 `4N >= M` 时,我们推荐使用 LazyInitContext。否则,LazyInitContext 是可选的。 + + ```python -with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): +with LazyInitContext(default_device=torch.device('cuda')): model = gpt2_medium(checkpoint=True) ``` + + +我们提供了 `Booster` API,它用户友好。我们推荐你使用 `Booster` API。如果您仍然想使用底层 API,您可以继续阅读本节其他内容。 -定义模型参数如下: +使用 `GeminiDDP` 包装模型。 + ```python -chunk_manager = init_chunk_manager(model=module, - init_device=device, - hidden_dim=hidden_dim, - search_range_mb=search_range_mb, - min_chunk_size_mb=min_chunk_size_mb) -gemini_manager = GeminiManager(placement_policy, chunk_manager) -model = ZeroDDP(model, gemini_manager) +model = GeminiDDP(model, hidden_dim=hidden_dim, min_chunk_size_m=min_chunk_size_m) ``` + -`hidden dim`是DNN的隐藏维度。用户可以提供这个参数来加快搜索速度。如果用户在训练前不知道这个参数也可以。 我们将使用默认值 1024。`min_chunk_size_mb`是以兆字节为单位的最小块大小。如果参数的总大小仍然小于最小块大小,则所有参数将被压缩为一个小块。 +`hidden dim`是DNN的隐藏维度。用户可以提供这个参数来加快搜索速度。如果用户在训练前不知道这个参数也可以。 我们将使用默认值 1024。`min_chunk_size_m`是以兆(2^20)为单位的最小块大小。如果参数的总大小仍然小于最小块大小,则所有参数将被压缩为一个小块。 初始化优化器。 + ```python optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) ``` + + 训练 ```python optimizer.zero_grad() @@ -87,6 +92,7 @@ loss = criterion(outputs, input_ids) optimizer.backward(loss) optimizer.step() ``` + > ⚠️ 注意:请不要使用`loss.backward()`,规范写法是`optimizer.backward(loss)`。 ### 训练GPT @@ -97,6 +103,8 @@ optimizer.step() 首先我们只需要引入`Huggingface transformers` 的 `GPT2LMHeadModel`来定义我们的模型,不需要用户进行模型的定义与修改,方便用户使用。 +定义GPT模型: + ```python class GPTLMModel(nn.Module): @@ -141,75 +149,6 @@ class GPTLMLoss(nn.Module): return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) ``` -定义张量并行和参数分片策略: - -```python -def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): - for mn, module in model.named_modules(): - for pn, param in module.named_parameters(recurse=False): - if hasattr(param, 'visited'): - continue - param.set_dist_spec(ReplicaSpec()) - if 'mlp.c_fc' in mn: - if 'weight' in pn or 'bias' in pn: - split_param_col_tp1d(param, pg) - param.compute_spec.set_output_replicate(False) - else: - param.set_dist_spec(ReplicaSpec()) - elif 'mlp.c_proj' in mn: - if 'weight' in pn: - split_param_row_tp1d(param, pg) - else: - param.set_dist_spec(ReplicaSpec()) - elif 'wte' in mn or 'wpe' in mn: - split_param_col_tp1d(param, pg) - elif 'c_attn' in mn or 'c_proj' in mn: - split_param_col_tp1d(param, pg) - else: - param.set_dist_spec(ReplicaSpec()) - - param.visited = True -def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(-1, param, pg) -``` - -定义一个使用 Gemini + ZeRO DDP 的模型: - -```python -def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): - cai_version = colossalai.__version__ - if version.parse(cai_version) > version.parse("0.1.10"): - from colossalai.nn.parallel import GeminiDDP - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=placememt_policy, - pin_memory=True, - search_range_mb=32) - elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): - from colossalai.gemini import ChunkManager, GeminiManager - chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) - gemini_manager = GeminiManager(placememt_policy, chunk_manager) - chunk_manager = ChunkManager(chunk_size, - pg, - enable_distributed_storage=True, - init_device=GeminiManager.get_default_device(placememt_policy)) - model = ZeroDDP(model, gemini_manager) - else: - raise NotImplemented(f"CAI version {cai_version} is not supported") - return model -``` - -由于我们在这个例子中对GPT进行预训练,因此只使用了一个简单的语言模型损失函数。 - 写一个获得随机输入的函数: ```python @@ -219,9 +158,16 @@ def get_data(batch_size, seq_len, vocab_size): return input_ids, attention_mask ``` -最后,我们可以定义我们的训练循环: + +最后,使用booster注入 Gemini + ZeRO DDP 特性, 并定义训练循环。由于我们在这个例子中对GPT进行预训练,因此只使用了一个简单的语言模型损失函数: ```python +from colossalai.nn.optimizer import HybridAdam + +from colossalai.booster import Booster +from colossalai.lazy import LazyInitContext +from colossalai.booster.plugin import GeminiPlugin + def main(): args = parse_args() BATCH_SIZE = 8 @@ -232,22 +178,19 @@ def main(): # build criterion criterion = GPTLMLoss() + optimizer = HybridAdam(model.parameters(), lr=0.001) torch.manual_seed(123) - default_pg = ProcessGroup(tp_degree=args.tp_degree) - default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None # build GPT model - with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): + with ColoInitContext(default_device=torch.device('cuda')): model = gpt2_medium(checkpoint=True) - pg = default_pg - # Tensor Parallelism (TP) - tensor_parallelize(model, pg) - # Gemini + ZeRO DP, Note it must be used after TP - model = gemini_zero_dpp(model, pg, args.placement) - # build optimizer - optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) - numel = sum([p.numel() for p in model.parameters()]) - get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) + + + # Gemini + ZeRO DP + plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5) + booster = Booster(plugin=plugin) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + torch.cuda.synchronize() model.train() for n in range(NUM_STEPS): @@ -256,10 +199,12 @@ def main(): optimizer.zero_grad() outputs = model(input_ids, attn_mask) loss = criterion(outputs, input_ids) - optimizer.backward(loss) + booster.backward(loss, optimizer) optimizer.step() torch.cuda.synchronize() ``` -> ⚠️ 注意:如果你使用Gemini模块的话,请不要使用我们之前提到过的[梯度累加](../features/gradient_accumulation.md)。 +> ⚠️ 注意:如果你使用Gemini模块的话,请不要使用我们之前提到过的[梯度累加](../features/gradient_accumulation_with_booster.md)。 完整的例子代码可以在 [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). 获得。 + + diff --git a/docs/source/zh-Hans/get_started/installation.md b/docs/source/zh-Hans/get_started/installation.md index 72f85393814fb97eb5679d26c82d74f102e6cd49..a6c88672b90775b0eaafc44c5e828b968e86b2f1 100755 --- a/docs/source/zh-Hans/get_started/installation.md +++ b/docs/source/zh-Hans/get_started/installation.md @@ -28,7 +28,7 @@ CUDA_EXT=1 pip install colossalai ## 从源安装 -> 此文档将与版本库的主分支保持一致。如果您遇到任何问题,欢迎给我们提 issue :) +> 此文档将与版本库的主分支保持一致。如果您遇到任何问题,欢迎给我们提 issue。 ```shell git clone https://github.com/hpcaitech/ColossalAI.git @@ -38,13 +38,29 @@ cd ColossalAI pip install -r requirements/requirements.txt # install colossalai -pip install . +CUDA_EXT=1 pip install . ``` -如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装): +如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装),您可以不添加`CUDA_EXT=1`: ```shell -NO_CUDA_EXT=1 pip install . +pip install . +``` + +如果您在使用CUDA 10.2,您仍然可以从源码安装ColossalAI。但是您需要手动下载cub库并将其复制到相应的目录。 + +```bash +# clone the repository +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI + +# download the cub library +wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip +unzip 1.8.0.zip +cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ + +# install +CUDA_EXT=1 pip install . ``` diff --git a/docs/source/zh-Hans/get_started/run_demo.md b/docs/source/zh-Hans/get_started/run_demo.md index edfc246c22d5672f9d9125042bb4ff4e82b90eba..70ed5ebe251be8cf73c20d31f62c20879e80f033 100755 --- a/docs/source/zh-Hans/get_started/run_demo.md +++ b/docs/source/zh-Hans/get_started/run_demo.md @@ -4,8 +4,8 @@ Colossal-AI 是一个集成的大规模深度学习系统,具有高效的并 ## 单 GPU -Colossal-AI 可以用在只有一个 GPU 的系统上训练深度学习模型,并达到 baseline 的性能。 我们提供了一个 [在CIFAR10数据集上训练ResNet](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/resnet) 的例子,该例子只需要一个 GPU。 -您可以在 [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples) 中获取该例子。详细说明可以在其 `README.md` 中获取。 +Colossal-AI 可以用在只有一个 GPU 的系统上训练深度学习模型,并达到 baseline 的性能。 我们提供了一个 [在 CIFAR10 数据集上训练 ResNet](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/resnet) 的例子,该例子只需要一个 GPU。 +您可以在 [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples) 中获取该例子。详细说明可以在其 `README.md` 中获取。 ## 多 GPU @@ -13,16 +13,20 @@ Colossal-AI 可用于在具有多个 GPU 的分布式系统上训练深度学习 #### 1. 数据并行 -您可以使用与上述单 GPU 演示相同的 [ResNet例子](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/resnet)。 通过设置 `--nproc_per_node` 为您机器上的 GPU 数量,您就能把数据并行应用在您的例子上了。 +您可以使用与上述单 GPU 演示相同的 [ResNet 例子](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/resnet)。 通过设置 `--nproc_per_node` 为您机器上的 GPU 数量,您就能把数据并行应用在您的例子上了。 #### 2. 混合并行 -混合并行包括数据、张量和流水线并行。在 Colossal-AI 中,我们支持不同类型的张量并行(即 1D、2D、2.5D 和 3D)。您可以通过简单地改变 `config.py` 中的配置在不同的张量并行之间切换。您可以参考 [GPT example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt), 更多细节能在它的 `README.md` 中被找到。 +混合并行包括数据、张量和流水线并行。在 Colossal-AI 中,我们支持不同类型的张量并行(即 1D、2D、2.5D 和 3D)。您可以通过简单地改变 `config.py` 中的配置在不同的张量并行之间切换。您可以参考 [GPT example](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt), 更多细节能在它的 `README.md` 中被找到。 -#### 3. MoE并行 +#### 3. MoE 并行 -我们提供了一个 [WideNet例子](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) 来验证 MoE 的并行性。 WideNet 使用 Mixture of Experts(MoE)来实现更好的性能。更多的细节可以在我们的教程中获取:[教会您如何把Mixture of Experts整合到模型中](../advanced_tutorials/integrate_mixture_of_experts_into_your_model.md)。 + + +我们提供了一个 [ViT-MoE 例子](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/moe) 来验证 MoE 的并行性。 WideNet 使用 Mixture of Experts(MoE)来实现更好的性能。更多的细节可以在我们的教程中获取:[教会您如何把 Mixture of Experts 整合到模型中](../advanced_tutorials/integrate_mixture_of_experts_into_your_model.md)。 #### 4. 序列并行 -序列并行是为了解决NLP任务中的内存效率和序列长度限制问题。 我们在 [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples) 中提供了一个 [BERT例子](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/bert/sequene_parallel)。您可以按照 `README.md` 来执行代码。 +序列并行是为了解决 NLP 任务中的内存效率和序列长度限制问题。 我们在 [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples) 中提供了一个 [Sequence Parallelism 例子](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/sequence_parallel)。您可以按照 `README.md` 来执行代码。 + + diff --git a/examples/README.md b/examples/README.md index 142a735c68192ed819b215e36caf2a64755eeb92..b822fb8ff92374ac45edc56bcb7bd03f079d61e6 100644 --- a/examples/README.md +++ b/examples/README.md @@ -36,7 +36,7 @@ You may contact us or participate in the following ways: 1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! 2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). 3. Join the Colossal-AI community on -[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +[Slack](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack), and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. 4. Send your official proposal to email contact@hpcaitech.com diff --git a/examples/community/fp8/mnist/main.py b/examples/community/fp8/mnist/main.py index a534663d380f4b4e00cd0fe80de1d3525bac3e69..2bb912dec247933f420195df8f7425e6f330dfa4 100644 --- a/examples/community/fp8/mnist/main.py +++ b/examples/community/fp8/mnist/main.py @@ -13,13 +13,13 @@ from torchvision import datasets, transforms try: from transformer_engine import pytorch as te + HAVE_TE = True except (ImportError, ModuleNotFoundError): HAVE_TE = False class Net(nn.Module): - def __init__(self, use_te=False): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) @@ -64,10 +64,12 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8): loss.backward() optimizer.step() if batch_idx % args.log_interval == 0: - print(f"Train Epoch: {epoch} " - f"[{batch_idx * len(data)}/{len(train_loader.dataset)} " - f"({100. * batch_idx / len(train_loader):.0f}%)]\t" - f"Loss: {loss.item():.6f}") + print( + f"Train Epoch: {epoch} " + f"[{batch_idx * len(data)}/{len(train_loader.dataset)} " + f"({100. * batch_idx / len(train_loader):.0f}%)]\t" + f"Loss: {loss.item():.6f}" + ) if args.dry_run: break @@ -75,13 +77,11 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8): def calibrate(model, device, test_loader): """Calibration function.""" model.eval() - test_loss = 0 - correct = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) with te.fp8_autocast(enabled=False, calibrating=True): - output = model(data) + model(data) def test(model, device, test_loader, use_fp8): @@ -94,15 +94,17 @@ def test(model, device, test_loader, use_fp8): data, target = data.to(device), target.to(device) with te.fp8_autocast(enabled=use_fp8): output = model(data) - test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print(f"\nTest set: Average loss: {test_loss:.4f}, " - f"Accuracy: {correct}/{len(test_loader.dataset)} " - f"({100. * correct / len(test_loader.dataset):.0f}%)\n") + print( + f"\nTest set: Average loss: {test_loss:.4f}, " + f"Accuracy: {correct}/{len(test_loader.dataset)} " + f"({100. * correct / len(test_loader.dataset):.0f}%)\n" + ) def main(): @@ -163,10 +165,9 @@ def main(): default=False, help="For Saving the current Model", ) - parser.add_argument("--use-fp8", - action="store_true", - default=False, - help="Use FP8 for inference and training without recalibration") + parser.add_argument( + "--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration" + ) parser.add_argument("--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only") parser.add_argument("--use-te", action="store_true", default=False, help="Use Transformer Engine") args = parser.parse_args() @@ -215,7 +216,7 @@ def main(): if args.save_model or args.use_fp8_infer: torch.save(model.state_dict(), "mnist_cnn.pt") - print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer)) + print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer)) weights = torch.load("mnist_cnn.pt") model.load_state_dict(weights) test(model, device, test_loader, args.use_fp8_infer) diff --git a/examples/community/roberta/README.md b/examples/community/roberta/README.md index 8aefa327a4b4bb7496f722b2713b986a5382b848..000fce63f35f093ca638e79224e1152ebba0b841 100644 --- a/examples/community/roberta/README.md +++ b/examples/community/roberta/README.md @@ -44,7 +44,7 @@ following the `README.md`, load the h5py generated by preprocess of step 1 to pr ## 3. Finetune -The checkpoint produced by this repo can replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) directly. Then use transfomers from Hugging Face to finetune downstream application. +The checkpoint produced by this repo can replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) directly. Then use transformers from Hugging Face to finetune downstream application. ## Contributors The example is contributed by AI team from [Moore Threads](https://www.mthreads.com/). If you find any problems for pretraining, please file an issue or send an email to yehua.zhang@mthreads.com. At last, welcome any form of contribution! diff --git a/examples/community/roberta/preprocessing/README.md b/examples/community/roberta/preprocessing/README.md index 17cc2f4dc22c36226560a1ad11c0f3f0493c84f1..2ed74754128089fa8f82b88368d4903d7c3a8f64 100644 --- a/examples/community/roberta/preprocessing/README.md +++ b/examples/community/roberta/preprocessing/README.md @@ -25,10 +25,10 @@ Firstly, each file has multiple documents, and each document contains multiple s In this example, split 200G Corpus into 100 shard, and each shard is about 2G. The size of the shard is memory-dependent, taking into account the number of servers, the memory used by the tokenizer, and the memory used by the multi-process training to read the shard (n data parallel requires n\*shard_size memory). **To sum up, data preprocessing and model pretraining requires fighting with hardware, not just GPU.** ```python -python sentence_split.py --input_path /orginal_corpus --output_path /shard --shard 100 +python sentence_split.py --input_path /original_corpus --output_path /shard --shard 100 # This step takes a short time ``` -* `--input_path`: all original corpus, e.g., /orginal_corpus/0.json /orginal_corpus/1.json ... +* `--input_path`: all original corpus, e.g., /original_corpus/0.json /original_corpus/1.json ... * `--output_path`: all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ... * `--shard`: Number of shard, e.g., 10, 50, or 100 @@ -76,7 +76,7 @@ make * `--input_path`: location of all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ... * `--output_path`: location of all h5 with token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ... -* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenzier.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) +* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenizer.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) * `--backend`: python or c++, **specifies c++ can obtain faster preprocess speed** * `--dupe_factor`: specifies how many times the preprocessor repeats to create the input from the same article/document * `--worker`: number of process diff --git a/examples/community/roberta/preprocessing/get_mask.py b/examples/community/roberta/preprocessing/get_mask.py index 74c97a63a9f3994bd4fd6e953222c83fedb229b3..f0ba8fe38501e516b30a9d0e10bb75ababef6657 100644 --- a/examples/community/roberta/preprocessing/get_mask.py +++ b/examples/community/roberta/preprocessing/get_mask.py @@ -1,13 +1,8 @@ import collections import logging -import os import random -import time -from enum import IntEnum -from random import choice import jieba -import torch jieba.setLogLevel(logging.CRITICAL) import re @@ -23,14 +18,15 @@ def map_to_numpy(data): return np.asarray(data) -class PreTrainingDataset(): - - def __init__(self, - tokenizer, - max_seq_length, - backend='python', - max_predictions_per_seq: int = 80, - do_whole_word_mask: bool = True): +class PreTrainingDataset: + def __init__( + self, + tokenizer, + max_seq_length, + backend="python", + max_predictions_per_seq: int = 80, + do_whole_word_mask: bool = True, + ): self.tokenizer = tokenizer self.max_seq_length = max_seq_length self.masked_lm_prob = 0.15 @@ -38,8 +34,8 @@ class PreTrainingDataset(): self.do_whole_word_mask = do_whole_word_mask self.max_predictions_per_seq = max_predictions_per_seq self.vocab_words = list(tokenizer.vocab.keys()) - self.rec = re.compile('[\u4E00-\u9FA5]') - self.whole_rec = re.compile('##[\u4E00-\u9FA5]') + self.rec = re.compile("[\u4E00-\u9FA5]") + self.whole_rec = re.compile("##[\u4E00-\u9FA5]") self.mlm_p = 0.15 self.mlm_mask_p = 0.8 @@ -64,7 +60,7 @@ class PreTrainingDataset(): original_tokens = [] segment_ids = [] tokens.append("[CLS]") - original_tokens.append('[CLS]') + original_tokens.append("[CLS]") segment_ids.append(0) for index, token in enumerate(tokens_a): tokens.append(token) @@ -72,7 +68,7 @@ class PreTrainingDataset(): segment_ids.append(0) tokens.append("[SEP]") - original_tokens.append('[SEP]') + original_tokens.append("[SEP]") segment_ids.append(0) # for token in tokens_b: @@ -83,11 +79,16 @@ class PreTrainingDataset(): # segment_ids.append(1) # Get Masked LM predictions - if self.backend == 'c++': + if self.backend == "c++": output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions( - tokens, original_tokens, self.vocab_words, self.tokenizer.vocab, self.max_predictions_per_seq, - self.masked_lm_prob) - elif self.backend == 'python': + tokens, + original_tokens, + self.vocab_words, + self.tokenizer.vocab, + self.max_predictions_per_seq, + self.masked_lm_prob, + ) + elif self.backend == "python": output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens) # Convert to Ids @@ -99,20 +100,20 @@ class PreTrainingDataset(): segment_ids.append(PAD) input_mask.append(PAD) masked_lm_output.append(-1) - return ([ + return [ map_to_numpy(input_ids), map_to_numpy(input_mask), map_to_numpy(segment_ids), map_to_numpy(masked_lm_output), - map_to_numpy([is_next]) - ]) + map_to_numpy([is_next]), + ] def create_masked_lm_predictions(self, tokens): cand_indexes = [] for i, token in enumerate(tokens): if token == "[CLS]" or token == "[SEP]": continue - if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")): + if self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##"): cand_indexes[-1].append(i) else: cand_indexes.append([i]) @@ -160,7 +161,7 @@ class PreTrainingDataset(): Input a sentence, return a processed sentence: In order to support the Chinese whole word mask, the words that are separated will be marked with a special mark ("#"), so that the subsequent processing module can know which words belong to the same word. :param segment: a sentence """ - seq_cws = jieba.lcut(''.join(segment)) + seq_cws = jieba.lcut("".join(segment)) seq_cws_dict = {x: 1 for x in seq_cws} new_segment = [] i = 0 @@ -174,10 +175,10 @@ class PreTrainingDataset(): for length in range(3, 0, -1): if i + length > len(segment): continue - if ''.join(segment[i:i + length]) in seq_cws_dict: + if "".join(segment[i : i + length]) in seq_cws_dict: new_segment.append(segment[i]) for l in range(1, length): - new_segment.append('##' + segment[i + l]) + new_segment.append("##" + segment[i + l]) i += length has_add = True break @@ -190,7 +191,7 @@ class PreTrainingDataset(): """Creates the predictions for the masked LM objective.""" cand_indexes = [] - for (i, token) in enumerate(tokens): + for i, token in enumerate(tokens): if token == "[CLS]" or token == "[SEP]": continue # Whole Word Masking means that if we mask all of the wordpieces @@ -202,14 +203,14 @@ class PreTrainingDataset(): # Note that Whole Word Masking does *not* change the training code # at all -- we still predict each WordPiece independently, softmaxed # over the entire vocabulary. - if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")): + if self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##"): cand_indexes[-1].append(i) else: cand_indexes.append([i]) random.shuffle(cand_indexes) - output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens] # 去掉"##" + output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens] # 去掉"##" num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob)))) @@ -239,8 +240,9 @@ class PreTrainingDataset(): else: # 10% of the time, keep original if random.random() < 0.5: - masked_token = tokens[index][2:] if len(self.whole_rec.findall( - tokens[index])) > 0 else tokens[index] # 去掉"##" + masked_token = ( + tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index] + ) # 去掉"##" # 10% of the time, replace with random word else: masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)] @@ -250,7 +252,9 @@ class PreTrainingDataset(): masked_lms.append( MaskedLMInstance( index=index, - label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index])) + label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index], + ) + ) assert len(masked_lms) <= num_to_predict masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lm_output = [-1] * len(output_tokens) diff --git a/examples/community/roberta/preprocessing/sentence_split.py b/examples/community/roberta/preprocessing/sentence_split.py index 76e8bd428723d6de230a2a9df4a9f835c0a8ecf3..8c83ce09558274db2e5d342ac422972a30d368a8 100644 --- a/examples/community/roberta/preprocessing/sentence_split.py +++ b/examples/community/roberta/preprocessing/sentence_split.py @@ -14,17 +14,19 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s sent_list = [] try: if flag == "zh": - document = re.sub('(?P([。?!…](?![”’"\'])))', r'\g\n', document) - document = re.sub('(?P([。?!]|…{1,2})[”’"\'])', r'\g\n', document) + document = re.sub("(?P([。?!…](?![”’\"'])))", r"\g\n", document) + document = re.sub("(?P([。?!]|…{1,2})[”’\"'])", r"\g\n", document) elif flag == "en": - document = re.sub('(?P([.?!](?![”’"\'])))', r'\g\n', document) - document = re.sub('(?P([?!.]["\']))', r'\g\n', - document) # Special quotation marks + document = re.sub("(?P([.?!](?![”’\"'])))", r"\g\n", document) + document = re.sub( + "(?P([?!.][\"']))", r"\g\n", document + ) # Special quotation marks else: - document = re.sub('(?P([。?!….?!](?![”’"\'])))', r'\g\n', document) + document = re.sub("(?P([。?!….?!](?![”’\"'])))", r"\g\n", document) - document = re.sub('(?P(([。?!.!?]|…{1,2})[”’"\']))', r'\g\n', - document) # Special quotation marks + document = re.sub( + "(?P(([。?!.!?]|…{1,2})[”’\"']))", r"\g\n", document + ) # Special quotation marks sent_list_ori = document.splitlines() for sent in sent_list_ori: @@ -46,36 +48,35 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None: - workers = 32 - if input_path[-1] == '/': + if input_path[-1] == "/": input_path = input_path[:-1] - cur_path = os.path.join(output_path, str(host) + '.txt') + cur_path = os.path.join(output_path, str(host) + ".txt") new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2) - with open(cur_path, 'w', encoding='utf-8') as f: + with open(cur_path, "w", encoding="utf-8") as f: for fi, fin_path in enumerate(fin_list): if not os.path.exists(os.path.join(input_path, fin_path[0])): continue - if '.json' not in fin_path[0]: + if ".json" not in fin_path[0]: continue print("Processing ", fin_path[0], " ", fi) - with open(os.path.join(input_path, fin_path[0]), 'r') as fin: - f_data = [l['content'] for l in json.load(fin)] + with open(os.path.join(input_path, fin_path[0]), "r") as fin: + f_data = [l["content"] for l in json.load(fin)] pool = multiprocessing.Pool(workers) all_sent = pool.imap_unordered(new_split_sentence, f_data, 32) pool.close() - print('finished..') + print("finished..") cnt = 0 for d in tqdm(all_sent): for i in d: - f.write(i.strip() + '\n') - f.write(']]' + '\n') + f.write(i.strip() + "\n") + f.write("]]" + "\n") cnt += 1 # if cnt >= 2: # exit() @@ -86,7 +87,7 @@ def getFileSize(filepath, shard): for i in os.listdir(filepath): all_data.append(os.path.join(filepath, i)) all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data]) - ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data] + ans = [[f.split("/")[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data] ans = sorted(ans, key=lambda x: x[1], reverse=True) per_size = all_size / shard real_shard = [] @@ -106,24 +107,24 @@ def getFileSize(filepath, shard): return real_shard -def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'): +def get_start_end(real_shard, base=0, server_num=10, server_name="GPU"): import socket + host = int(socket.gethostname().split(server_name)[-1]) fin_list = real_shard[server_num * base + host - 1] print(fin_list) - print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}') + print(f"I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}") return fin_list, host -if __name__ == '__main__': - +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--server_num', type=int, default=10, help='number of servers') - parser.add_argument('--seq_len', type=int, default=512, help='sequence length') - parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100') - parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus') - parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence') + parser.add_argument("--server_num", type=int, default=10, help="number of servers") + parser.add_argument("--seq_len", type=int, default=512, help="sequence length") + parser.add_argument("--shard", type=int, default=100, help="number of shards, e.g., 10, 50, or 100") + parser.add_argument("--input_path", type=str, required=True, help="input path of original corpus") + parser.add_argument("--output_path", type=str, required=True, help="output path of shard which has split sentence") args = parser.parse_args() server_num = args.server_num @@ -137,7 +138,7 @@ if __name__ == '__main__': start = time.time() for index, shard in enumerate(real_shard): get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len) - print(f'cost {str(time.time() - start)}') + print(f"cost {str(time.time() - start)}") # if you have multiple server, you can use code below or modify code to openmpi diff --git a/examples/community/roberta/preprocessing/tokenize_mask.py b/examples/community/roberta/preprocessing/tokenize_mask.py index f3d49c3d965fc680daf0532a5714122b2bf03678..19dbaf5384de2a3ad96d372ad0598d78ec1df95b 100644 --- a/examples/community/roberta/preprocessing/tokenize_mask.py +++ b/examples/community/roberta/preprocessing/tokenize_mask.py @@ -1,7 +1,6 @@ import argparse import multiprocessing import os -import socket import time from random import shuffle @@ -29,8 +28,7 @@ def get_raw_instance(document, max_sequence_length=512): curr_seq = [] sz_idx = 0 while sz_idx < len(sizes): - - if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0: + if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0: curr_seq += document[sz_idx] sz_idx += 1 elif sizes[sz_idx] >= max_sequence_length_allowed: @@ -43,7 +41,7 @@ def get_raw_instance(document, max_sequence_length=512): result_list.append(curr_seq) curr_seq = [] - if len(curr_seq) > max_sequence_length_allowed / 2: # /2 + if len(curr_seq) > max_sequence_length_allowed / 2: # /2 result_list.append(curr_seq) # num_instance=int(len(big_list)/max_sequence_length_allowed)+1 @@ -58,33 +56,30 @@ def get_raw_instance(document, max_sequence_length=512): def split_numpy_chunk(path, tokenizer, pretrain_data, host): - documents = [] instances = [] s = time.time() - with open(path, encoding='utf-8') as fd: + with open(path, encoding="utf-8") as fd: document = [] for i, line in enumerate(tqdm(fd)): line = line.strip() # document = line # if len(document.split("")) <= 3: # continue - if len(line) > 0 and line[:2] == "]]": # This is end of document + if len(line) > 0 and line[:2] == "]]": # This is end of document documents.append(document) document = [] elif len(line) >= 2: document.append(line) if len(document) > 0: documents.append(document) - print('read_file ', time.time() - s) + print("read_file ", time.time() - s) # documents = [x for x in documents if x] # print(len(documents)) # print(len(documents[0])) # print(documents[0][0:10]) - import multiprocessing - from typing import List ans = [] for docs in tqdm(documents): @@ -98,7 +93,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): instances.extend(raw_ins) del ans - print('len instance', len(instances)) + print("len instance", len(instances)) sen_num = len(instances) seq_len = 512 @@ -114,7 +109,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): segment_ids[index] = mask_dict[2] masked_lm_output[index] = mask_dict[3] - with h5py.File(f'/output/{host}.h5', 'w') as hf: + with h5py.File(f"/output/{host}.h5", "w") as hf: hf.create_dataset("input_ids", data=input_ids) hf.create_dataset("input_mask", data=input_ids) hf.create_dataset("segment_ids", data=segment_ids) @@ -124,45 +119,44 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name): - - if os.path.exists(os.path.join(output_path, f'{file_name}.h5')): - print(f'{file_name}.h5 exists') + if os.path.exists(os.path.join(output_path, f"{file_name}.h5")): + print(f"{file_name}.h5 exists") return documents = [] instances = [] s = time.time() - with open(input_path, 'r', encoding='utf-8') as fd: + with open(input_path, "r", encoding="utf-8") as fd: document = [] for i, line in enumerate(tqdm(fd)): line = line.strip() - if len(line) > 0 and line[:2] == "]]": # This is end of document + if len(line) > 0 and line[:2] == "]]": # This is end of document documents.append(document) document = [] elif len(line) >= 2: document.append(line) if len(document) > 0: documents.append(document) - print(f'read_file cost {time.time() - s}, length is {len(documents)}') + print(f"read_file cost {time.time() - s}, length is {len(documents)}") ans = [] s = time.time() pool = multiprocessing.Pool(worker) encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100) - for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'): + for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour="cyan"): ans.append(res) pool.close() print((time.time() - s) / 60) del documents instances = [] - for a in tqdm(ans, colour='MAGENTA'): + for a in tqdm(ans, colour="MAGENTA"): raw_ins = get_raw_instance(a, max_sequence_length=seq_len) instances.extend(raw_ins) del ans - print('len instance', len(instances)) + print("len instance", len(instances)) new_instances = [] for _ in range(dupe_factor): @@ -171,7 +165,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ shuffle(new_instances) instances = new_instances - print('after dupe_factor, len instance', len(instances)) + print("after dupe_factor, len instance", len(instances)) sentence_num = len(instances) input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32) @@ -182,7 +176,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ s = time.time() pool = multiprocessing.Pool(worker) encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32) - for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'): + for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour="blue"): input_ids[index] = mask_dict[0] input_mask[index] = mask_dict[1] segment_ids[index] = mask_dict[2] @@ -190,7 +184,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ pool.close() print((time.time() - s) / 60) - with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf: + with h5py.File(os.path.join(output_path, f"{file_name}.h5"), "w") as hf: hf.create_dataset("input_ids", data=input_ids) hf.create_dataset("input_mask", data=input_mask) hf.create_dataset("segment_ids", data=segment_ids) @@ -199,50 +193,48 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ del instances -if __name__ == '__main__': - +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer') - parser.add_argument('--seq_len', type=int, default=512, help='sequence length') - parser.add_argument('--max_predictions_per_seq', - type=int, - default=80, - help='number of shards, e.g., 10, 50, or 100') - parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence') - parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id') - parser.add_argument('--backend', - type=str, - default='python', - help='backend of mask token, python, c++, numpy respectively') + parser.add_argument("--tokenizer_path", type=str, required=True, default=10, help="path of tokenizer") + parser.add_argument("--seq_len", type=int, default=512, help="sequence length") + parser.add_argument( + "--max_predictions_per_seq", type=int, default=80, help="number of shards, e.g., 10, 50, or 100" + ) + parser.add_argument("--input_path", type=str, required=True, help="input path of shard which has split sentence") + parser.add_argument("--output_path", type=str, required=True, help="output path of h5 contains token id") + parser.add_argument( + "--backend", type=str, default="python", help="backend of mask token, python, c++, numpy respectively" + ) parser.add_argument( - '--dupe_factor', + "--dupe_factor", type=int, default=1, - help='specifies how many times the preprocessor repeats to create the input from the same article/document') - parser.add_argument('--worker', type=int, default=32, help='number of process') - parser.add_argument('--server_num', type=int, default=10, help='number of servers') + help="specifies how many times the preprocessor repeats to create the input from the same article/document", + ) + parser.add_argument("--worker", type=int, default=32, help="number of process") + parser.add_argument("--server_num", type=int, default=10, help="number of servers") args = parser.parse_args() tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) - pretrain_data = PreTrainingDataset(tokenizer, - args.seq_len, - args.backend, - max_predictions_per_seq=args.max_predictions_per_seq) + pretrain_data = PreTrainingDataset( + tokenizer, args.seq_len, args.backend, max_predictions_per_seq=args.max_predictions_per_seq + ) data_len = len(os.listdir(args.input_path)) for i in range(data_len): - input_path = os.path.join(args.input_path, f'{i}.txt') + input_path = os.path.join(args.input_path, f"{i}.txt") if os.path.exists(input_path): start = time.time() - print(f'process {input_path}') - split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor, - args.seq_len, i) + print(f"process {input_path}") + split_numpy_chunk_pool( + input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor, args.seq_len, i + ) end_ = time.time() - print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)) - print(f'has cost {(end_ - start) / 60}') - print('-' * 100) - print('') + print("memory:%.4f GB" % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)) + print(f"has cost {(end_ - start) / 60}") + print("-" * 100) + print("") # if you have multiple server, you can use code below or modify code to openmpi diff --git a/examples/community/roberta/pretraining/README.md b/examples/community/roberta/pretraining/README.md index c248fc1f570831070b3b7d07e56ee64572405224..8abe48aa6c0ee72f57da74116d0209596b0458b8 100644 --- a/examples/community/roberta/pretraining/README.md +++ b/examples/community/roberta/pretraining/README.md @@ -13,7 +13,7 @@ bash run_pretrain.sh * `--bert_config`: config.json which represent model * `--mlm`: model type of backbone, bert or deberta_v2 -2. if resume training from earylier checkpoint, run the script below. +2. if resume training from earlier checkpoint, run the script below. ```shell bash run_pretrain_resume.sh diff --git a/examples/community/roberta/pretraining/arguments.py b/examples/community/roberta/pretraining/arguments.py index 40210c4b1be779e7cd029e7deb70b7ed7acad2d3..3428db4cb9c5df1371396deeb8c88320057403fd 100644 --- a/examples/community/roberta/pretraining/arguments.py +++ b/examples/community/roberta/pretraining/arguments.py @@ -1,17 +1,15 @@ -from numpy import require +import argparse -import colossalai - -__all__ = ['parse_args'] +__all__ = ["parse_args"] def parse_args(): - parser = colossalai.get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument( "--distplan", type=str, - default='CAI_Gemini', + default="CAI_Gemini", help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", ) parser.add_argument( @@ -23,65 +21,66 @@ def parse_args(): parser.add_argument( "--placement", type=str, - default='cpu', + default="cpu", help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", ) parser.add_argument( "--shardinit", - action='store_true', - help= - "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + action="store_true", + help="Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", ) - parser.add_argument('--lr', type=float, required=True, help='initial learning rate') - parser.add_argument('--epoch', type=int, required=True, help='number of epoch') - parser.add_argument('--data_path_prefix', type=str, required=True, help="location of the train data corpus") - parser.add_argument('--eval_data_path_prefix', - type=str, - required=True, - help='location of the evaluation data corpus') - parser.add_argument('--tokenizer_path', type=str, required=True, help='location of the tokenizer') - parser.add_argument('--max_seq_length', type=int, default=512, help='sequence length') - parser.add_argument('--refresh_bucket_size', - type=int, - default=1, - help="This param makes sure that a certain task is repeated for this time steps to \ - optimise on the back propogation speed with APEX's DistributedDataParallel") - parser.add_argument("--max_predictions_per_seq", - "--max_pred", - default=80, - type=int, - help="The maximum number of masked tokens in a sequence to be predicted.") + parser.add_argument("--lr", type=float, required=True, help="initial learning rate") + parser.add_argument("--epoch", type=int, required=True, help="number of epoch") + parser.add_argument("--data_path_prefix", type=str, required=True, help="location of the train data corpus") + parser.add_argument( + "--eval_data_path_prefix", type=str, required=True, help="location of the evaluation data corpus" + ) + parser.add_argument("--tokenizer_path", type=str, required=True, help="location of the tokenizer") + parser.add_argument("--max_seq_length", type=int, default=512, help="sequence length") + parser.add_argument( + "--refresh_bucket_size", + type=int, + default=1, + help="This param makes sure that a certain task is repeated for this time steps to \ + optimize on the back propagation speed with APEX's DistributedDataParallel", + ) + parser.add_argument( + "--max_predictions_per_seq", + "--max_pred", + default=80, + type=int, + help="The maximum number of masked tokens in a sequence to be predicted.", + ) parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="accumulation_steps") parser.add_argument("--train_micro_batch_size_per_gpu", default=2, type=int, required=True, help="train batch size") parser.add_argument("--eval_micro_batch_size_per_gpu", default=2, type=int, required=True, help="eval batch size") parser.add_argument("--num_workers", default=8, type=int, help="") - parser.add_argument("--async_worker", action='store_true', help="") + parser.add_argument("--async_worker", action="store_true", help="") parser.add_argument("--bert_config", required=True, type=str, help="location of config.json") - parser.add_argument("--wandb", action='store_true', help="use wandb to watch model") - parser.add_argument("--wandb_project_name", default='roberta', help="wandb project name") + parser.add_argument("--wandb", action="store_true", help="use wandb to watch model") + parser.add_argument("--wandb_project_name", default="roberta", help="wandb project name") parser.add_argument("--log_interval", default=100, type=int, help="report interval") parser.add_argument("--log_path", type=str, required=True, help="log file which records train step") parser.add_argument("--tensorboard_path", type=str, required=True, help="location of tensorboard file") - parser.add_argument("--colossal_config", - type=str, - required=True, - help="colossal config, which contains zero config and so on") - parser.add_argument("--ckpt_path", - type=str, - required=True, - help="location of saving checkpoint, which contains model and optimizer") - parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") - parser.add_argument('--vscode_debug', action='store_true', help="use vscode to debug") - parser.add_argument('--load_pretrain_model', default='', type=str, help="location of model's checkpoin") parser.add_argument( - '--load_optimizer_lr', - default='', + "--colossal_config", type=str, required=True, help="colossal config, which contains zero config and so on" + ) + parser.add_argument( + "--ckpt_path", type=str, required=True, help="location of saving checkpoint, which contains model and optimizer" + ) + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument("--vscode_debug", action="store_true", help="use vscode to debug") + parser.add_argument("--load_pretrain_model", default="", type=str, help="location of model's checkpoint") + parser.add_argument( + "--load_optimizer_lr", + default="", type=str, - help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step") - parser.add_argument('--resume_train', action='store_true', help="whether resume training from a early checkpoint") - parser.add_argument('--mlm', default='bert', type=str, help="model type, bert or deberta") - parser.add_argument('--checkpoint_activations', action='store_true', help="whether to use gradient checkpointing") + help="location of checkpoint, which contains optimizer, learning rate, epoch, shard and global_step", + ) + parser.add_argument("--resume_train", action="store_true", help="whether resume training from a early checkpoint") + parser.add_argument("--mlm", default="bert", type=str, help="model type, bert or deberta") + parser.add_argument("--checkpoint_activations", action="store_true", help="whether to use gradient checkpointing") args = parser.parse_args() return args diff --git a/examples/community/roberta/pretraining/bert_dataset_provider.py b/examples/community/roberta/pretraining/bert_dataset_provider.py index eaf165ed18f4022a218671f4846d9e92c287f911..1d8cf2a910e9e41275edc480c25e86fa166fd0ee 100644 --- a/examples/community/roberta/pretraining/bert_dataset_provider.py +++ b/examples/community/roberta/pretraining/bert_dataset_provider.py @@ -1,5 +1,4 @@ class BertDatasetProviderInterface: - def get_shard(self, index, shuffle=True): raise NotImplementedError diff --git a/examples/community/roberta/pretraining/evaluation.py b/examples/community/roberta/pretraining/evaluation.py index 009242cd1cf5fcc48d0224394979de743e4f9f0b..e1bce48023c39a255556244c36f682930a2fc747 100644 --- a/examples/community/roberta/pretraining/evaluation.py +++ b/examples/community/roberta/pretraining/evaluation.py @@ -19,23 +19,27 @@ def evaluate(model, args, logger, global_step, criterion): world_size = torch.distributed.get_world_size() with torch.no_grad(): - for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))): - - timers('eval_shard_time').start() + timers("eval_shard_time").start() dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard) # evaluate_dataset_provider.prefetch_shard(shard + 1) if torch.distributed.get_rank() == 0: - iterator_data = tqdm(enumerate(dataset_iterator), - total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), - colour='MAGENTA', - smoothing=1) + iterator_data = tqdm( + enumerate(dataset_iterator), + total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), + colour="MAGENTA", + smoothing=1, + ) else: iterator_data = enumerate(dataset_iterator) - for step, batch_data in iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): - + for ( + step, + batch_data, + ) in ( + iterator_data + ): # tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): # batch_data = pretrain_dataset_provider.get_batch(batch_index) eval_step += 1 input_ids = batch_data[0].cuda() @@ -46,7 +50,7 @@ def evaluate(model, args, logger, global_step, criterion): output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - loss = criterion(output.logits, mlm_label) #prediction_scores + loss = criterion(output.logits, mlm_label) # prediction_scores evaluate_dataset_provider.prefetch_batch() eval_loss += loss.float().item() @@ -58,18 +62,18 @@ def evaluate(model, args, logger, global_step, criterion): if args.wandb and torch.distributed.get_rank() == 0: tensorboard_log = get_tensorboard_writer() - tensorboard_log.log_eval({ - 'loss': cur_loss, - 'ppl': ppl, - 'mins_batch': elapsed_time_per_iteration - }, global_step) + tensorboard_log.log_eval( + {"loss": cur_loss, "ppl": ppl, "mins_batch": elapsed_time_per_iteration}, global_step + ) - eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ - f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}' + eval_log_str = ( + f"evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes " + + f"| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}" + ) logger.info(eval_log_str) - logger.info('-' * 100) - logger.info('') + logger.info("-" * 100) + logger.info("") evaluate_dataset_provider.release_shard() model.train() diff --git a/examples/community/roberta/pretraining/loss.py b/examples/community/roberta/pretraining/loss.py index 989c2bd5c450462e130fe947a28472db32ea6f49..636246292809b9d574c1404b7b15404cc5a14b91 100644 --- a/examples/community/roberta/pretraining/loss.py +++ b/examples/community/roberta/pretraining/loss.py @@ -1,10 +1,9 @@ import torch -__all__ = ['LossForPretraining'] +__all__ = ["LossForPretraining"] class LossForPretraining(torch.nn.Module): - def __init__(self, vocab_size): super(LossForPretraining, self).__init__() self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1) @@ -13,5 +12,5 @@ class LossForPretraining(torch.nn.Module): def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None): masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1)) # next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1)) - total_loss = masked_lm_loss #+ next_sentence_loss + total_loss = masked_lm_loss # + next_sentence_loss return total_loss diff --git a/examples/community/roberta/pretraining/model/bert.py b/examples/community/roberta/pretraining/model/bert.py index a5da1bea6f655b3e2e36168a8339f848f48fc2ff..31e3d7075a0c2f0a1ee7debbdaccc87379d49935 100644 --- a/examples/community/roberta/pretraining/model/bert.py +++ b/examples/community/roberta/pretraining/model/bert.py @@ -59,7 +59,8 @@ _TOKENIZER_FOR_DOC = "BertTokenizer" # TokenClassification docstring _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english" _TOKEN_CLASS_EXPECTED_OUTPUT = ( - "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] ") + "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] " +) _TOKEN_CLASS_EXPECTED_LOSS = 0.01 # QuestionAnswering docstring @@ -109,8 +110,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path): import numpy as np import tensorflow as tf except ImportError: - logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions.") + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) raise tf_path = os.path.abspath(tf_checkpoint_path) logger.info(f"Converting TensorFlow checkpoint from {tf_path}") @@ -128,8 +131,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path): name = name.split("/") # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model - if any(n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name): + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): logger.info(f"Skipping {'/'.join(name)}") continue pointer = model @@ -209,7 +214,7 @@ class BertEmbeddings(nn.Module): seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length] + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves @@ -236,12 +241,13 @@ class BertEmbeddings(nn.Module): class BertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})") + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) @@ -320,14 +326,14 @@ class BertSelfAttention(nn.Module): position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == "relative_key_query": relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhld,lrd->bhlr", key_layer, positional_embedding) attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key attention_scores = attention_scores / math.sqrt(self.attention_head_size) @@ -360,7 +366,6 @@ class BertSelfAttention(nn.Module): class BertSelfOutput(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -375,7 +380,6 @@ class BertSelfOutput(nn.Module): class BertAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): super().__init__() self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) @@ -385,8 +389,9 @@ class BertAttention(nn.Module): def prune_heads(self, heads): if len(heads) == 0: return - heads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads, - self.self.attention_head_size, self.pruned_heads) + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) @@ -419,12 +424,11 @@ class BertAttention(nn.Module): output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs class BertIntermediate(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) @@ -440,7 +444,6 @@ class BertIntermediate(nn.Module): class BertOutput(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -455,7 +458,6 @@ class BertOutput(nn.Module): class BertLayer(nn.Module): - def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -496,14 +498,15 @@ class BertLayer(nn.Module): outputs = self_attention_outputs[1:-1] present_key_value = self_attention_outputs[-1] else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`") + " by setting `config.add_cross_attention=True`" + ) # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None @@ -517,14 +520,15 @@ class BertLayer(nn.Module): output_attentions, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights # add cross-attn cache to positions 3,4 of present_key_value tuple cross_attn_present_key_value = cross_attention_outputs[-1] present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward, - self.seq_len_dim, attention_output) + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) outputs = (layer_output,) + outputs # if decoder, return the attn key/values as the last output @@ -540,7 +544,6 @@ class BertLayer(nn.Module): class BertEncoder(nn.Module): - def __init__(self, config): super().__init__() self.config = config @@ -573,14 +576,13 @@ class BertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - if use_cache: logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False def create_custom_forward(module): - def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) @@ -617,13 +619,17 @@ class BertEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, @@ -634,7 +640,6 @@ class BertEncoder(nn.Module): class BertPooler(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -650,7 +655,6 @@ class BertPooler(nn.Module): class BertPredictionHeadTransform(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -668,7 +672,6 @@ class BertPredictionHeadTransform(nn.Module): class BertLMPredictionHead(nn.Module): - def __init__(self, config): super().__init__() self.transform = BertPredictionHeadTransform(config) @@ -689,7 +692,6 @@ class BertLMPredictionHead(nn.Module): class BertOnlyMLMHead(nn.Module): - def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) @@ -700,7 +702,6 @@ class BertOnlyMLMHead(nn.Module): class BertOnlyNSPHead(nn.Module): - def __init__(self, config): super().__init__() self.seq_relationship = nn.Linear(config.hidden_size, 2) @@ -711,7 +712,6 @@ class BertOnlyNSPHead(nn.Module): class BertPreTrainingHeads(nn.Module): - def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) @@ -943,8 +943,9 @@ class BertModel(BertPreTrainedModel): `past_key_values`). """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.config.is_decoder: @@ -1043,7 +1044,6 @@ class BertModel(BertPreTrainedModel): BERT_START_DOCSTRING, ) class BertForPreTraining(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1144,10 +1144,10 @@ class BertForPreTraining(BertPreTrainedModel): ) -@add_start_docstrings("""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", - BERT_START_DOCSTRING) +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING +) class BertLMHeadModel(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] @@ -1282,7 +1282,6 @@ class BertLMHeadModel(BertPreTrainedModel): @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) class BertForMaskedLM(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] @@ -1290,8 +1289,10 @@ class BertForMaskedLM(BertPreTrainedModel): super().__init__(config) if config.is_decoder: - logger.warning("If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention.") + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) @@ -1357,7 +1358,7 @@ class BertForMaskedLM(BertPreTrainedModel): masked_lm_loss = None if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -1380,10 +1381,9 @@ class BertForMaskedLM(BertPreTrainedModel): raise ValueError("The PAD token should be defined for generation") attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) - dummy_token = torch.full((effective_batch_size, 1), - self.config.pad_token_id, - dtype=torch.long, - device=input_ids.device) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) input_ids = torch.cat([input_ids, dummy_token], dim=1) return {"input_ids": input_ids, "attention_mask": attention_mask} @@ -1394,7 +1394,6 @@ class BertForMaskedLM(BertPreTrainedModel): BERT_START_DOCSTRING, ) class BertForNextSentencePrediction(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1500,15 +1499,15 @@ class BertForNextSentencePrediction(BertPreTrainedModel): BERT_START_DOCSTRING, ) class BertForSequenceClassification(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.config = config self.bert = BertModel(config) - classifier_dropout = (config.classifier_dropout - if config.classifier_dropout is not None else config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) @@ -1604,13 +1603,13 @@ class BertForSequenceClassification(BertPreTrainedModel): BERT_START_DOCSTRING, ) class BertForMultipleChoice(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) self.bert = BertModel(config) - classifier_dropout = (config.classifier_dropout - if config.classifier_dropout is not None else config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, 1) @@ -1650,8 +1649,11 @@ class BertForMultipleChoice(BertPreTrainedModel): attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None else None) + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) outputs = self.bert( input_ids, @@ -1696,7 +1698,6 @@ class BertForMultipleChoice(BertPreTrainedModel): BERT_START_DOCSTRING, ) class BertForTokenClassification(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): @@ -1704,8 +1705,9 @@ class BertForTokenClassification(BertPreTrainedModel): self.num_labels = config.num_labels self.bert = BertModel(config, add_pooling_layer=False) - classifier_dropout = (config.classifier_dropout - if config.classifier_dropout is not None else config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) @@ -1782,7 +1784,6 @@ class BertForTokenClassification(BertPreTrainedModel): BERT_START_DOCSTRING, ) class BertForQuestionAnswering(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): diff --git a/examples/community/roberta/pretraining/model/deberta_v2.py b/examples/community/roberta/pretraining/model/deberta_v2.py index 5fc284911e38723ea9eb6ae2036521096a4323dc..c7457942e1641bb91ae9e8f2b3949cf986e38e30 100644 --- a/examples/community/roberta/pretraining/model/deberta_v2.py +++ b/examples/community/roberta/pretraining/model/deberta_v2.py @@ -23,7 +23,6 @@ import torch import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss -from transformers import FillMaskPipeline, T5ForConditionalGeneration, T5Tokenizer from transformers.activations import ACT2FN from transformers.modeling_outputs import ( BaseModelOutput, @@ -59,7 +58,6 @@ DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [ # Copied from transformers.models.deberta.modeling_deberta.ContextPooler class ContextPooler(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) @@ -138,15 +136,15 @@ class XSoftmax(torch.autograd.Function): g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx["Byte"], ) - output = masked_fill(g, self, r_mask, - g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))) + output = masked_fill( + g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) + ) output = softmax(g, output, dim) return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) # Copied from transformers.models.deberta.modeling_deberta.DropoutContext class DropoutContext(object): - def __init__(self): self.dropout = 0 self.mask = None @@ -249,7 +247,6 @@ class StableDropout(nn.Module): # Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm class DebertaV2SelfOutput(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -265,7 +262,6 @@ class DebertaV2SelfOutput(nn.Module): # Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 class DebertaV2Attention(nn.Module): - def __init__(self, config): super().__init__() self.self = DisentangledSelfAttention(config) @@ -303,7 +299,6 @@ class DebertaV2Attention(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 class DebertaV2Intermediate(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) @@ -320,7 +315,6 @@ class DebertaV2Intermediate(nn.Module): # Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm class DebertaV2Output(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -337,7 +331,6 @@ class DebertaV2Output(nn.Module): # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 class DebertaV2Layer(nn.Module): - def __init__(self, config): super().__init__() self.attention = DebertaV2Attention(config) @@ -372,17 +365,14 @@ class DebertaV2Layer(nn.Module): class ConvLayer(nn.Module): - def __init__(self, config): super().__init__() kernel_size = getattr(config, "conv_kernel_size", 3) groups = getattr(config, "conv_groups", 1) self.conv_act = getattr(config, "conv_act", "tanh") - self.conv = nn.Conv1d(config.hidden_size, - config.hidden_size, - kernel_size, - padding=(kernel_size - 1) // 2, - groups=groups) + self.conv = nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) self.dropout = StableDropout(config.hidden_dropout_prob) self.config = config @@ -465,10 +455,9 @@ class DebertaV2Encoder(nn.Module): def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): if self.relative_attention and relative_pos is None: q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) - relative_pos = build_relative_position(q, - hidden_states.size(-2), - bucket_size=self.position_buckets, - max_position=self.max_relative_positions) + relative_pos = build_relative_position( + q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) return relative_pos def forward( @@ -498,14 +487,12 @@ class DebertaV2Encoder(nn.Module): rel_embeddings = self.get_rel_embedding() output_states = next_kv for i, layer_module in enumerate(self.layer): - if output_hidden_states: all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): return module(*inputs, output_attentions) @@ -550,9 +537,9 @@ class DebertaV2Encoder(nn.Module): if not return_dict: return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) - return BaseModelOutput(last_hidden_state=output_states, - hidden_states=all_hidden_states, - attentions=all_attentions) + return BaseModelOutput( + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions + ) def make_log_bucket_position(relative_pos, bucket_size, max_position): @@ -625,8 +612,10 @@ class DisentangledSelfAttention(nn.Module): def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0: - raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})") + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) self.num_attention_heads = config.num_attention_heads _attention_head_size = config.hidden_size // config.num_attention_heads self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) @@ -719,22 +708,28 @@ class DisentangledSelfAttention(nn.Module): attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) - rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, - scale_factor) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) if rel_att is not None: attention_scores = attention_scores + rel_att attention_scores = attention_scores - attention_scores = attention_scores.view(-1, self.num_attention_heads, attention_scores.size(-2), - attention_scores.size(-1)) + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) # bsz x height x length x dimension attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) attention_probs = self.dropout(attention_probs) - context_layer = torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), - value_layer) - context_layer = (context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), - context_layer.size(-1)).permute(0, 2, 1, 3).contiguous()) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(new_context_layer_shape) if output_attentions: @@ -745,10 +740,9 @@ class DisentangledSelfAttention(nn.Module): def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): if relative_pos is None: q = query_layer.size(-2) - relative_pos = build_relative_position(q, - key_layer.size(-2), - bucket_size=self.position_buckets, - max_position=self.max_relative_positions) + relative_pos = build_relative_position( + q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) if relative_pos.dim() == 2: relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) elif relative_pos.dim() == 3: @@ -766,22 +760,25 @@ class DisentangledSelfAttention(nn.Module): # rel_embeddings = rel_embeddings.unsqueeze(0) # rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) if self.share_att_key: - pos_query_layer = self.transpose_for_scores(self.query_proj(rel_embeddings), - self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1) + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) else: if "c2p" in self.pos_att_type: - pos_key_layer = self.transpose_for_scores(self.pos_key_proj(rel_embeddings), - self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, - 1) # .split(self.all_head_size, dim=-1) + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) if "p2c" in self.pos_att_type: - pos_query_layer = self.transpose_for_scores(self.pos_query_proj(rel_embeddings), - self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, - 1) # .split(self.all_head_size, dim=-1) + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) score = 0 # content->position @@ -792,9 +789,7 @@ class DisentangledSelfAttention(nn.Module): c2p_att = torch.gather( c2p_att, dim=-1, - index=c2p_pos.squeeze(0).expand([query_layer.size(0), - query_layer.size(1), - relative_pos.size(-1)]), + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), ) score += c2p_att / scale @@ -817,9 +812,7 @@ class DisentangledSelfAttention(nn.Module): p2c_att = torch.gather( p2c_att, dim=-1, - index=p2c_pos.squeeze(0).expand([query_layer.size(0), - key_layer.size(-2), - key_layer.size(-2)]), + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), ).transpose(-1, -2) score += p2c_att / scale @@ -999,7 +992,6 @@ DEBERTA_INPUTS_DOCSTRING = r""" ) # Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 class DebertaV2Model(DebertaV2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1042,8 +1034,9 @@ class DebertaV2Model(DebertaV2PreTrainedModel): return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: @@ -1100,7 +1093,7 @@ class DebertaV2Model(DebertaV2PreTrainedModel): sequence_output = encoded_layers[-1] if not return_dict: - return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2):] + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] return BaseModelOutput( last_hidden_state=sequence_output, @@ -1174,7 +1167,7 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): masked_lm_loss = None if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -1191,7 +1184,6 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): # copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta class DebertaV2PredictionHeadTransform(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -1210,7 +1202,6 @@ class DebertaV2PredictionHeadTransform(nn.Module): # copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta class DebertaV2LMPredictionHead(nn.Module): - def __init__(self, config): super().__init__() self.transform = DebertaV2PredictionHeadTransform(config) @@ -1232,7 +1223,6 @@ class DebertaV2LMPredictionHead(nn.Module): # copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta class DebertaV2OnlyMLMHead(nn.Module): - def __init__(self, config): super().__init__() self.predictions = DebertaV2LMPredictionHead(config) @@ -1251,7 +1241,6 @@ class DebertaV2OnlyMLMHead(nn.Module): ) # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1331,8 +1320,9 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): label_index = (labels >= 0).nonzero() labels = labels.long() if label_index.size(0) > 0: - labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), - logits.size(1))) + labeled_logits = torch.gather( + logits, 0, label_index.expand(label_index.size(0), logits.size(1)) + ) labels = torch.gather(labels, 0, label_index.view(-1)) loss_fct = CrossEntropyLoss() loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) @@ -1357,10 +1347,9 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutput(loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions) + return SequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) @add_start_docstrings( @@ -1435,10 +1424,9 @@ class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput(loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions) + return TokenClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) @add_start_docstrings( @@ -1550,7 +1538,6 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): DEBERTA_START_DOCSTRING, ) class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1606,8 +1593,11 @@ class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - flat_inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None else None) + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) outputs = self.deberta( flat_input_ids, diff --git a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py index 72c7bd852a401daaf41a6ef9d5b123c7faf4085f..09677a6195cb88ee74b3b9d5e2312fa1df6db662 100644 --- a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py +++ b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py @@ -1,5 +1,3 @@ -import json -import logging import os import random import time @@ -12,14 +10,10 @@ import torch.distributed as dist from bert_dataset_provider import BertDatasetProviderInterface from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import RandomSampler - -import colossalai.utils as utils # Workaround because python functions are not picklable class WorkerInitObj(object): - def __init__(self, seed): self.seed = seed @@ -28,44 +22,46 @@ class WorkerInitObj(object): random.seed(self.seed + id) -def create_pretraining_dataset(input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init, - data_sampler): +def create_pretraining_dataset( + input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init, data_sampler +): train_data = pretraining_dataset(input_file=input_file, max_predictions_per_seq=max_predictions_per_seq) - train_dataloader = DataLoader(train_data, - sampler=data_sampler(train_data), - batch_size=train_batch_size, - num_workers=num_workers, - worker_init_fn=worker_init, - pin_memory=True) + train_dataloader = DataLoader( + train_data, + sampler=data_sampler(train_data), + batch_size=train_batch_size, + num_workers=num_workers, + worker_init_fn=worker_init, + pin_memory=True, + ) return train_dataloader, len(train_data) class pretraining_dataset(Dataset): - def __init__(self, input_file, max_predictions_per_seq): self.input_file = input_file self.max_predictions_per_seq = max_predictions_per_seq f = h5py.File(input_file, "r") - keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions'] + keys = ["input_ids", "input_mask", "segment_ids", "masked_lm_positions"] self.inputs = [np.asarray(f[key][:]) for key in keys] f.close() def __len__(self): - 'Denotes the total number of samples' + "Denotes the total number of samples" return len(self.inputs[0]) def __getitem__(self, index): - [input_ids, input_mask, segment_ids, masked_lm_labels] = [ - torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy( - np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs) + torch.from_numpy(input[index].astype(np.int64)) + if indice < 5 + else torch.from_numpy(np.asarray(input[index].astype(np.int64))) + for indice, input in enumerate(self.inputs) ] return [input_ids, input_mask, segment_ids, masked_lm_labels] class NvidiaBertDatasetProvider(BertDatasetProviderInterface): - def __init__(self, args, evaluate=False): self.num_workers = args.num_workers self.max_seq_length = args.max_seq_length @@ -86,13 +82,13 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface): self.dataset_files = [ os.path.join(args.data_path_prefix, f) for f in os.listdir(args.data_path_prefix) - if os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f + if os.path.isfile(os.path.join(args.data_path_prefix, f)) and "h5" in f ] else: self.dataset_files = [ os.path.join(args.eval_data_path_prefix, f) for f in os.listdir(args.eval_data_path_prefix) - if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f + if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and "h5" in f ] self.dataset_files.sort() @@ -120,7 +116,8 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface): num_workers=self.num_workers, train_batch_size=self.train_micro_batch_size_per_gpu, worker_init=self.worker_init, - data_sampler=self.data_sampler) + data_sampler=self.data_sampler, + ) else: self.train_dataloader, sample_count = self.dataset_future.result(timeout=None) @@ -136,9 +133,15 @@ class NvidiaBertDatasetProvider(BertDatasetProviderInterface): def prefetch_shard(self, index): self.data_file = self._get_shard_file(index) - self.dataset_future = self.pool.submit(create_pretraining_dataset, self.data_file, self.max_predictions_per_seq, - self.num_workers, self.train_micro_batch_size_per_gpu, self.worker_init, - self.data_sampler) + self.dataset_future = self.pool.submit( + create_pretraining_dataset, + self.data_file, + self.max_predictions_per_seq, + self.num_workers, + self.train_micro_batch_size_per_gpu, + self.worker_init, + self.data_sampler, + ) def get_batch(self, batch_iter): return batch_iter diff --git a/examples/community/roberta/pretraining/pretrain_utils.py b/examples/community/roberta/pretraining/pretrain_utils.py index cea6ac2c36e5f0225ce4b50a6af4a551986d0951..1370b413b7120f2734365bc04f5e246ccbaac46d 100644 --- a/examples/community/roberta/pretraining/pretrain_utils.py +++ b/examples/community/roberta/pretraining/pretrain_utils.py @@ -1,24 +1,12 @@ -import logging import os import sys import torch import transformers -from torch.optim import AdamW -from transformers import ( - AutoModelForMaskedLM, - AutoTokenizer, - BertForPreTraining, - GPT2Config, - GPT2LMHeadModel, - RobertaConfig, - RobertaForMaskedLM, - get_linear_schedule_with_warmup, -) - -from colossalai.core import global_context as gpc -from colossalai.nn.lr_scheduler import LinearWarmupLR -from colossalai.nn.optimizer import FusedAdam, HybridAdam +from transformers import get_linear_schedule_with_warmup + +from colossalai.legacy.core import global_context as gpc +from colossalai.nn.optimizer import HybridAdam sys.path.append(os.getcwd()) from collections import OrderedDict @@ -27,7 +15,7 @@ import torch.nn as nn from model.bert import BertForMaskedLM from model.deberta_v2 import DebertaV2ForMaskedLM -__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining'] +__all__ = ["get_model", "get_optimizer", "get_lr_scheduler", "get_dataloader_for_pretraining"] def get_new_state_dict(state_dict, start_index=13): @@ -39,7 +27,6 @@ def get_new_state_dict(state_dict, start_index=13): class LMModel(nn.Module): - def __init__(self, model, config, args): super().__init__() @@ -55,11 +42,10 @@ class LMModel(nn.Module): def get_model(args, logger): - - if args.mlm == 'bert': + if args.mlm == "bert": config = transformers.BertConfig.from_json_file(args.bert_config) model = BertForMaskedLM(config) - elif args.mlm == 'deberta_v2': + elif args.mlm == "deberta_v2": config = transformers.DebertaV2Config.from_json_file(args.bert_config) model = DebertaV2ForMaskedLM(config) else: @@ -68,11 +54,13 @@ def get_model(args, logger): if len(args.load_pretrain_model) > 0: assert os.path.exists(args.load_pretrain_model) # load_checkpoint(args.load_pretrain_model, model, strict=False) - m_state_dict = torch.load(args.load_pretrain_model, - map_location=torch.device(f"cuda:{torch.cuda.current_device()}")) + m_state_dict = torch.load( + args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}") + ) # new_state_dict = get_new_state_dict(m_state_dict) - model.load_state_dict(m_state_dict, - strict=True) # must insure that every process have identical parameters !!!!!!! + model.load_state_dict( + m_state_dict, strict=True + ) # must insure that every process have identical parameters !!!!!!! logger.info("load model success") numel = sum([p.numel() for p in model.parameters()]) @@ -85,40 +73,36 @@ def get_model(args, logger): def get_optimizer(model, lr): param_optimizer = list(model.named_parameters()) - no_decay = ['bias', 'gamma', 'beta', 'LayerNorm'] + no_decay = ["bias", "gamma", "beta", "LayerNorm"] # configure the weight decay for bert models - optimizer_grouped_parameters = [{ - 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], - 'weight_decay': 0.1 - }, { - 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], - 'weight_decay': 0.0 - }] + optimizer_grouped_parameters = [ + {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.1}, + {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, + ] optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95]) return optimizer def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1): # warmup_steps = int(total_steps * warmup_ratio) - lr_scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps=warmup_steps, - num_training_steps=total_steps, - last_epoch=last_epoch) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, last_epoch=last_epoch + ) # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps) return lr_scheduler def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step): - model_path = path + '_pytorch_model.bin' - optimizer_lr_path = path + '.op_lrs' + model_path = path + "_pytorch_model.bin" + optimizer_lr_path = path + ".op_lrs" checkpoint = {} - checkpoint['optimizer'] = optimizer.state_dict() - checkpoint['lr_scheduler'] = lr_scheduler.state_dict() - checkpoint['epoch'] = epoch - checkpoint['shard'] = shard - checkpoint['global_step'] = global_step - model_state = model.state_dict() #each process must run model.state_dict() + checkpoint["optimizer"] = optimizer.state_dict() + checkpoint["lr_scheduler"] = lr_scheduler.state_dict() + checkpoint["epoch"] = epoch + checkpoint["shard"] = shard + checkpoint["global_step"] = global_step + model_state = model.state_dict() # each process must run model.state_dict() if gpc.get_global_rank() == 0: torch.save(checkpoint, optimizer_lr_path) torch.save(model_state, model_path) diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py index 9a6ffc1c566165ac63365864740e7dce5b2c09fa..5396de6935cbab6bf23ff9ff868d1724cb465cb2 100644 --- a/examples/community/roberta/pretraining/run_pretraining.py +++ b/examples/community/roberta/pretraining/run_pretraining.py @@ -17,16 +17,13 @@ from utils.logger import Logger import colossalai from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper -from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper +from colossalai.tensor import ProcessGroup, ShardSpec from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import ZeroOptimizer def main(): - args = parse_args() launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) @@ -37,20 +34,17 @@ def main(): logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) if args.vscode_debug: - colossalai.launch(config={}, - rank=args.rank, - world_size=args.world_size, - host=args.host, - port=args.port, - backend=args.backend) + colossalai.launch( + config={}, rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend + ) args.local_rank = -1 args.log_interval = 1 else: - colossalai.launch_from_torch(config={}) #args.colossal_config + colossalai.launch_from_torch(config={}) # args.colossal_config args.local_rank = int(os.environ["LOCAL_RANK"]) logger.info( - f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + - f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}' + f"launch_from_torch, world size: {torch.distributed.get_world_size()} | " + + f"ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}" ) log_args(logger, args) @@ -59,7 +53,7 @@ def main(): set_global_variables(launch_time, args.tensorboard_path) world_size = torch.distributed.get_world_size() - init_dev = get_current_device() + get_current_device() # build model, optimizer and criterion if args.distplan.startswith("CAI"): @@ -72,24 +66,25 @@ def main(): raise RuntimeError("You can only use shardinit with CAI_Gemini") # build GPT model - with ColoInitContext(device=get_current_device(), - dtype=torch.half, - default_dist_spec=default_dist_spec, - default_pg=shard_pg): + with ColoInitContext( + device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg + ): config, model, numel = get_model(args, logger) - # asign running configurations + # assign running configurations gemini_config = None if args.distplan.startswith("CAI_ZeRO"): optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) elif args.distplan == "CAI_Gemini": - gemini_config = dict(strict_ddp_mode=args.tp_degree == 1, - device=get_current_device(), - placement_policy=args.placement, - pin_memory=True, - hidden_dim=model.config.hidden_size, - search_range_mb=128) - optim_config = dict(gpu_margin_mem_ratio=0.) + gemini_config = dict( + strict_ddp_mode=args.tp_degree == 1, + device=get_current_device(), + placement_policy=args.placement, + pin_memory=True, + hidden_dim=model.config.hidden_size, + search_range_m=128, + ) + optim_config = dict(gpu_margin_mem_ratio=0.0) else: raise RuntimeError @@ -109,7 +104,7 @@ def main(): model = zero_model_wrapper(model, zero_stage, gemini_config) optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config) - logger.info(get_mem_info(prefix='After init optim, ')) + logger.info(get_mem_info(prefix="After init optim, ")) else: config, model, numel = get_model(args, logger) @@ -118,12 +113,19 @@ def main(): if torch.distributed.get_rank() == 0: os.mkdir(os.path.join(args.ckpt_path, launch_time)) - logger.info(f'Model numel: {numel}') + logger.info(f"Model numel: {numel}") get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) # 144003367 is is the length of the entire dataset - steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) + # len(dataloader) + steps_per_epoch = ( + 144003367 + // world_size + // args.train_micro_batch_size_per_gpu + // args.gradient_accumulation_steps + // args.refresh_bucket_size + ) total_steps = steps_per_epoch * args.epoch lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) @@ -133,25 +135,25 @@ def main(): global_step = 0 if args.resume_train: assert os.path.exists(args.load_optimizer_lr) - o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu') - o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 - optimizer.load_state_dict(o_l_state_dict['optimizer']) + o_l_state_dict = torch.load(args.load_optimizer_lr, map_location="cpu") + o_l_state_dict["lr_scheduler"]["last_epoch"] = o_l_state_dict["lr_scheduler"]["last_epoch"] - 1 + optimizer.load_state_dict(o_l_state_dict["optimizer"]) # o_l_state_dict['lr_scheduler']['last_epoch'] - lr_scheduler = get_lr_scheduler(optimizer, - total_steps=total_steps, - last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) + lr_scheduler = get_lr_scheduler( + optimizer, total_steps=total_steps, last_epoch=o_l_state_dict["lr_scheduler"]["last_epoch"] + ) for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") # if you want delete the above three code, must move the model to gpu. Because in optimizer.step() - lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) + lr_scheduler.load_state_dict(o_l_state_dict["lr_scheduler"]) - start_epoch = o_l_state_dict['epoch'] - start_shard = o_l_state_dict['shard'] + 1 + start_epoch = o_l_state_dict["epoch"] + start_shard = o_l_state_dict["shard"] + 1 # global_step = o_l_state_dict['global_step'] + 1 logger.info( - f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}' + f"resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}" ) criterion = LossForPretraining(config.vocab_size) @@ -159,34 +161,32 @@ def main(): # build dataloader pretrain_dataset_provider = NvidiaBertDatasetProvider(args) - logger.info(get_mem_info(prefix='After init model, ')) + logger.info(get_mem_info(prefix="After init model, ")) - best_loss = None eval_loss = 0 train_loss = 0 timers = get_timers() - timers('interval_time').start() - timers('epoch_time').start() - timers('shard_time').start() + timers("interval_time").start() + timers("epoch_time").start() + timers("shard_time").start() for epoch in range(start_epoch, args.epoch): - for shard in range(start_shard, len(os.listdir(args.data_path_prefix))): - dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload if torch.distributed.get_rank() == 0: - iterator_data = tqdm(enumerate(dataset_iterator), - total=(total_length // args.train_micro_batch_size_per_gpu // world_size), - colour='cyan', - smoothing=1) + iterator_data = tqdm( + enumerate(dataset_iterator), + total=(total_length // args.train_micro_batch_size_per_gpu // world_size), + colour="cyan", + smoothing=1, + ) else: iterator_data = enumerate(dataset_iterator) model.train() for step, batch_data in iterator_data: - # batch_data = pretrain_dataset_provider.get_batch(batch_index) input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") attention_mask = batch_data[1].cuda(f"cuda:{torch.cuda.current_device()}") @@ -208,56 +208,70 @@ def main(): global_step += 1 - if global_step % args.log_interval == 0 and global_step != 0 \ - and torch.distributed.get_rank() == 0: - elapsed_time = timers('interval_time').elapsed(reset=False) + if global_step % args.log_interval == 0 and global_step != 0 and torch.distributed.get_rank() == 0: + elapsed_time = timers("interval_time").elapsed(reset=False) elapsed_time_per_iteration = elapsed_time / global_step samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator( - numel, args, config, elapsed_time, global_step, world_size) + numel, args, config, elapsed_time, global_step, world_size + ) cur_loss = train_loss / args.log_interval current_lr = lr_scheduler.get_last_lr()[0] - log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ - f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}' + log_str = ( + f"| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes " + + f"| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}" + ) logger.info(log_str, print_=False) if args.wandb: tensorboard_log = get_tensorboard_writer() tensorboard_log.log_train( { - 'lr': current_lr, - 'loss': cur_loss, - 'ppl': math.exp(cur_loss), - 'mins_batch': elapsed_time_per_iteration - }, global_step) + "lr": current_lr, + "loss": cur_loss, + "ppl": math.exp(cur_loss), + "mins_batch": elapsed_time_per_iteration, + }, + global_step, + ) train_loss = 0 logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins') - logger.info('*' * 100) + logger.info("*" * 100) eval_loss += evaluate(model, args, logger, global_step, criterion) - save_ckpt(model, optimizer, lr_scheduler, - os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, - shard, global_step) + save_ckpt( + model, + optimizer, + lr_scheduler, + os.path.join(args.ckpt_path, launch_time, f"epoch-{epoch}_shard-{shard}_" + launch_time), + epoch, + shard, + global_step, + ) eval_loss /= len(os.listdir(args.data_path_prefix)) logger.info( f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' - + f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') - logger.info('-' * 100) + + f"eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}" + ) + logger.info("-" * 100) if args.wandb and torch.distributed.get_rank() == 0: tensorboard_log = get_tensorboard_writer() - tensorboard_log.log_eval({ - 'all_eval_shard_loss': eval_loss, - }, epoch) + tensorboard_log.log_eval( + { + "all_eval_shard_loss": eval_loss, + }, + epoch, + ) start_shard = 0 eval_loss = 0 pretrain_dataset_provider.release_shard() - logger.info('Congratulation, training has finished!!!') + logger.info("Congratulation, training has finished!!!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/community/roberta/pretraining/utils/WandbLog.py b/examples/community/roberta/pretraining/utils/WandbLog.py index b68ba8387dcdb8534d957268dd5606337ba1f74d..d73393c348d80987f82197a782ffd6e469ff2a6f 100644 --- a/examples/community/roberta/pretraining/utils/WandbLog.py +++ b/examples/community/roberta/pretraining/utils/WandbLog.py @@ -6,7 +6,6 @@ from torch.utils.tensorboard import SummaryWriter class WandbLog: - @classmethod def init_wandb(cls, project, notes=None, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): wandb.init(project=project, notes=notes, name=name, config=config) @@ -23,7 +22,6 @@ class WandbLog: class TensorboardLog: - def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): if not os.path.exists(location): os.mkdir(location) @@ -31,12 +29,12 @@ class TensorboardLog: def log_train(self, result, step): for k, v in result.items(): - self.writer.add_scalar(f'{k}/train', v, step) + self.writer.add_scalar(f"{k}/train", v, step) def log_eval(self, result, step): for k, v in result.items(): - self.writer.add_scalar(f'{k}/eval', v, step) + self.writer.add_scalar(f"{k}/eval", v, step) def log_zeroshot(self, result, step): for k, v in result.items(): - self.writer.add_scalar(f'{k}_acc/eval', v, step) + self.writer.add_scalar(f"{k}_acc/eval", v, step) diff --git a/examples/community/roberta/pretraining/utils/exp_util.py b/examples/community/roberta/pretraining/utils/exp_util.py index 0cdb56bad03117ddfd181a4bbb313b1d6783dada..e95b6efda4c8cb235b3ef64ade84c873e6cc00cb 100644 --- a/examples/community/roberta/pretraining/utils/exp_util.py +++ b/examples/community/roberta/pretraining/utils/exp_util.py @@ -5,15 +5,15 @@ import shutil import psutil import torch -from colossalai.core import global_context as gpc +from colossalai.legacy.core import global_context as gpc def logging(s, log_path, print_=True, log_=True): if print_: print(s) if log_: - with open(log_path, 'a+') as f_log: - f_log.write(s + '\n') + with open(log_path, "a+") as f_log: + f_log.write(s + "\n") def get_logger(log_path, **kwargs): @@ -22,22 +22,22 @@ def get_logger(log_path, **kwargs): def create_exp_dir(dir_path, scripts_to_save=None, debug=False): if debug: - print('Debug Mode : no experiment dir created') + print("Debug Mode : no experiment dir created") return functools.partial(logging, log_path=None, log_=False) if not os.path.exists(dir_path): os.makedirs(dir_path) - print('Experiment dir : {}'.format(dir_path)) + print("Experiment dir : {}".format(dir_path)) if scripts_to_save is not None: - script_path = os.path.join(dir_path, 'scripts') + script_path = os.path.join(dir_path, "scripts") if not os.path.exists(script_path): os.makedirs(script_path) for script in scripts_to_save: - dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) + dst_file = os.path.join(dir_path, "scripts", os.path.basename(script)) shutil.copyfile(script, dst_file) - return get_logger(log_path=os.path.join(dir_path, 'log.txt')) + return get_logger(log_path=os.path.join(dir_path, "log.txt")) def get_cpu_mem(): @@ -48,8 +48,8 @@ def get_gpu_mem(): return torch.cuda.memory_allocated() / 1024**2 -def get_mem_info(prefix=''): - return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' +def get_mem_info(prefix=""): + return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" def get_tflops(model_numel, batch_size, seq_len, step_time): @@ -59,11 +59,12 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): def get_parameters_in_billions(model, world_size=1): gpus_per_model = world_size - approx_parameters_in_billions = sum([ - sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement() - for p in model_module.parameters()]) - for model_module in model - ]) + approx_parameters_in_billions = sum( + [ + sum([p.ds_numel if hasattr(p, "ds_id") else p.nelement() for p in model_module.parameters()]) + for model_module in model + ] + ) return approx_parameters_in_billions * gpus_per_model / (1e9) @@ -71,13 +72,13 @@ def get_parameters_in_billions(model, world_size=1): def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1): gpus_per_model = 1 batch_size = args.train_micro_batch_size_per_gpu - samples_per_model = batch_size * args.max_seq_length - model_replica_count = world_size / gpus_per_model + batch_size * args.max_seq_length + world_size / gpus_per_model approx_parameters_in_billions = numel elapsed_time_per_iter = iteration_time / total_iterations samples_per_second = batch_size / elapsed_time_per_iter - #flops calculator + # flops calculator hidden_size = config.hidden_size num_layers = config.num_hidden_layers vocab_size = config.vocab_size @@ -87,9 +88,9 @@ def throughput_calculator(numel, args, config, iteration_time, total_iterations, # The factor of 4 is when used with activation check-pointing, # otherwise it will be 3. checkpoint_activations_factor = 4 if args.checkpoint_activations else 3 - flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * - (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + - (vocab_size / (16. * num_layers * hidden_size))) + flops_per_iteration = ( + 24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * (hidden_size**2) + ) * (1.0 + (args.max_seq_length / (6.0 * hidden_size)) + (vocab_size / (16.0 * num_layers * hidden_size))) tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12)) return samples_per_second, tflops, approx_parameters_in_billions @@ -97,7 +98,7 @@ def throughput_calculator(numel, args, config, iteration_time, total_iterations, def synchronize(): if not torch.distributed.is_available(): return - if not torch.distributed.is_intialized(): + if not torch.distributed.is_initialized(): return world_size = torch.distributed.get_world_size() if world_size == 1: @@ -106,9 +107,9 @@ def synchronize(): def log_args(logger, args): - logger.info('--------args----------') - message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()]) - message += '\n' - message += '\n'.join([f'{k:<30}: {v}' for k, v in gpc.config.items()]) + logger.info("--------args----------") + message = "\n".join([f"{k:<30}: {v}" for k, v in vars(args).items()]) + message += "\n" + message += "\n".join([f"{k:<30}: {v}" for k, v in gpc.config.items()]) logger.info(message) - logger.info('--------args----------\n') + logger.info("--------args----------\n") diff --git a/examples/community/roberta/pretraining/utils/global_vars.py b/examples/community/roberta/pretraining/utils/global_vars.py index 7b0c5a2be73d914bc6bccdc9d383432e2bfac1f9..176c0a5b34747554e9d16c658dd048e0055a721e 100644 --- a/examples/community/roberta/pretraining/utils/global_vars.py +++ b/examples/community/roberta/pretraining/utils/global_vars.py @@ -16,21 +16,21 @@ def set_global_variables(launch_time, tensorboard_path): def _set_timers(): """Initialize timers.""" global _GLOBAL_TIMERS - _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') + _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers") _GLOBAL_TIMERS = Timers() def _set_tensorboard_writer(launch_time, tensorboard_path): """Set tensorboard writer.""" global _GLOBAL_TENSORBOARD_WRITER - _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, 'tensorboard writer') + _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, "tensorboard writer") if torch.distributed.get_rank() == 0: - _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time) + _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f"/{launch_time}", launch_time) def get_timers(): """Return timers.""" - _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') + _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers") return _GLOBAL_TIMERS @@ -42,12 +42,12 @@ def get_tensorboard_writer(): def _ensure_var_is_initialized(var, name): """Make sure the input variable is not None.""" - assert var is not None, '{} is not initialized.'.format(name) + assert var is not None, "{} is not initialized.".format(name) def _ensure_var_is_not_initialized(var, name): """Make sure the input variable is not None.""" - assert var is None, '{} is already initialized.'.format(name) + assert var is None, "{} is already initialized.".format(name) class _Timer: @@ -68,9 +68,9 @@ class _Timer: def stop(self): """Stop the timer.""" - assert self.started_, 'timer is not started' + assert self.started_, "timer is not started" torch.cuda.synchronize() - self.elapsed_ += (time.time() - self.start_time) + self.elapsed_ += time.time() - self.start_time self.started_ = False def reset(self): @@ -110,19 +110,19 @@ class Timers: """Write timers to a tensorboard writer""" # currently when using add_scalars, # torch.utils.add_scalars makes each timer its own run, which - # polutes the runs list, so we just add each as a scalar + # pollutes the runs list, so we just add each as a scalar assert normalizer > 0.0 for name in names: value = self.timers[name].elapsed(reset=reset) / normalizer - writer.add_scalar(name + '-time', value, iteration) + writer.add_scalar(name + "-time", value, iteration) def log(self, names, normalizer=1.0, reset=True): """Log a group of timers.""" assert normalizer > 0.0 - string = 'time (ms)' + string = "time (ms)" for name in names: elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer - string += ' | {}: {:.2f}'.format(name, elapsed_time) + string += " | {}: {:.2f}".format(name, elapsed_time) if torch.distributed.is_initialized(): if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): print(string, flush=True) diff --git a/examples/community/roberta/pretraining/utils/logger.py b/examples/community/roberta/pretraining/utils/logger.py index 75c9bf4bef251f4b5914ed55f87d1ddc55cb2504..9913892b89e9ff21415f8e096339f5a5e29c8fb1 100644 --- a/examples/community/roberta/pretraining/utils/logger.py +++ b/examples/community/roberta/pretraining/utils/logger.py @@ -1,16 +1,14 @@ import logging -import os import torch.distributed as dist -logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt='%m/%d/%Y %H:%M:%S', - level=logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO +) logger = logging.getLogger(__name__) -class Logger(): - +class Logger: def __init__(self, log_path, cuda=False, debug=False): self.logger = logging.getLogger(__name__) self.cuda = cuda @@ -23,8 +21,8 @@ class Logger(): self.logger.info(message, *args, **kwargs) if log_: - with open(self.log_path, 'a+') as f_log: - f_log.write(message + '\n') + with open(self.log_path, "a+") as f_log: + f_log.write(message + "\n") def error(self, message, *args, **kwargs): self.logger.error(message, *args, **kwargs) diff --git a/examples/community/roberta/test_ci.sh b/examples/community/roberta/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md index 0c7f42ded318774a05ee78345285fb81f8fba4a8..d6a1c47d6b870ea9bd9dd16ae42fd3738904ea8a 100644 --- a/examples/images/diffusion/README.md +++ b/examples/images/diffusion/README.md @@ -132,7 +132,7 @@ bash train_colossalai.sh ``` It is important for you to configure your volume mapping in order to get the best training experience. -1. **Mandatory**, mount your prepared data to `/data/scratch` via `-v :/data/scratch`, where you need to replace `` with the actual data path on your machine. Notice that within docker we need to transform the Windows path to a Linux one, e.g. `C:\User\Desktop` into `/mnt/c/User/Desktop`. +1. **Mandatory**, mount your prepared data to `/data/scratch` via `-v :/data/scratch`, where you need to replace `` with the actual data path on your machine. Notice that within docker we need to transform the Windows path to a Linux one, e.g. `C:\User\Desktop` into `/mnt/c/User/Desktop`. 2. **Recommended**, store the downloaded model weights to your host machine instead of the container directory via `-v :/root/.cache/huggingface`, where you need to replace the `` with the actual path. In this way, you don't have to repeatedly download the pretrained weights for every `docker run`. 3. **Optional**, if you encounter any problem stating that shared memory is insufficient inside container, please add `-v /dev/shm:/dev/shm` to your `docker run` command. @@ -254,7 +254,7 @@ You may contact us or participate in the following ways: 1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! 2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). 3. Join the Colossal-AI community on -[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +[Slack](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack), and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. 4. Send your official proposal to email contact@hpcaitech.com diff --git a/examples/images/diffusion/configs/train_ddp.yaml b/examples/images/diffusion/configs/train_ddp.yaml index f3ae3ddb5ff6d5847795f9d7a43afa243c0a85ed..72dc05b649a4eb129a298deee0dc61a4639501cf 100644 --- a/examples/images/diffusion/configs/train_ddp.yaml +++ b/examples/images/diffusion/configs/train_ddp.yaml @@ -80,7 +80,7 @@ data: lightning: trainer: - accelerator: 'gpu' + accelerator: 'gpu' devices: 8 log_gpu_memory: all max_epochs: 2 diff --git a/examples/images/diffusion/ldm/data/base.py b/examples/images/diffusion/ldm/data/base.py index a12492c95a162e53ba17903d9bfeb99ccd4a623f..11bd0c5954a28bfc3dd928b2364960f72d92d9a8 100644 --- a/examples/images/diffusion/ldm/data/base.py +++ b/examples/images/diffusion/ldm/data/base.py @@ -1,17 +1,15 @@ -import math import os -from abc import abstractmethod import cv2 import numpy as np import torch -from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset +from torch.utils.data import IterableDataset class Txt2ImgIterableBaseDataset(IterableDataset): - ''' + """ Define an interface to make the IterableDatasets for text2img data chainable - ''' + """ def __init__(self, file_path: str, rank, world_size): super().__init__() @@ -20,8 +18,8 @@ class Txt2ImgIterableBaseDataset(IterableDataset): self.file_list = [] self.txt_list = [] self.info = self._get_file_info(file_path) - self.start = self.info['start'] - self.end = self.info['end'] + self.start = self.info["start"] + self.end = self.info["end"] self.rank = rank self.world_size = world_size @@ -33,7 +31,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset): self.num_records = self.end - self.start self.valid_ids = [i for i in range(self.end)] - print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.") def __len__(self): # return self.iter_end - self.iter_start @@ -48,7 +46,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset): for idx in range(start, end): file_name = self.file_list[idx] txt_name = self.txt_list[idx] - f_ = open(txt_name, 'r') + f_ = open(txt_name, "r") txt_ = f_.read() f_.close() image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1) @@ -57,18 +55,17 @@ class Txt2ImgIterableBaseDataset(IterableDataset): yield {"txt": txt_, "image": image} def _get_file_info(self, file_path): - info = \ - { + info = { "start": 1, "end": 0, } - self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i] + self.folder_list = [file_path + i for i in os.listdir(file_path) if "." not in i] for folder in self.folder_list: - files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i] - txts = [k.replace('jpg', 'txt') for k in files] + files = [folder + "/" + i for i in os.listdir(folder) if "jpg" in i] + txts = [k.replace("jpg", "txt") for k in files] self.file_list.extend(files) self.txt_list.extend(txts) - info['end'] = len(self.file_list) + info["end"] = len(self.file_list) # with open(file_path, 'r') as fin: # for _ in enumerate(fin): # info['end'] += 1 diff --git a/examples/images/diffusion/ldm/data/cifar10.py b/examples/images/diffusion/ldm/data/cifar10.py index 53cd61263b472d37c6b2b896cfd8ba2f89477b9a..85c6e1b5dd38d88c98fddd59a78b3b9f9158ce3e 100644 --- a/examples/images/diffusion/ldm/data/cifar10.py +++ b/examples/images/diffusion/ldm/data/cifar10.py @@ -1,15 +1,16 @@ +import json +from pathlib import Path from typing import Dict -import numpy as np -from omegaconf import DictConfig, ListConfig + import torch -from torch.utils.data import Dataset -from pathlib import Path -import json -from PIL import Image -from torchvision import transforms +from datasets import load_dataset from einops import rearrange from ldm.util import instantiate_from_config -from datasets import load_dataset +from omegaconf import DictConfig, ListConfig +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + def make_multi_folder_data(paths, caption_files=None, **kwargs): """Make a concat dataset from multiple folders @@ -19,10 +20,9 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): """ list_of_paths = [] if isinstance(paths, (Dict, DictConfig)): - assert caption_files is None, \ - "Caption files not yet supported for repeats" + assert caption_files is None, "Caption files not yet supported for repeats" for folder_path, repeats in paths.items(): - list_of_paths.extend([folder_path]*repeats) + list_of_paths.extend([folder_path] * repeats) paths = list_of_paths if caption_files is not None: @@ -31,8 +31,10 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): datasets = [FolderData(p, **kwargs) for p in paths] return torch.utils.data.ConcatDataset(datasets) + class FolderData(Dataset): - def __init__(self, + def __init__( + self, root_dir, caption_file=None, image_transforms=[], @@ -40,7 +42,7 @@ class FolderData(Dataset): default_caption="", postprocess=None, return_paths=False, - ) -> None: + ) -> None: """Create a dataset from a folder of images. If you pass in a root directory it will be searched for images ending in ext (ext can be a list) @@ -75,12 +77,12 @@ class FolderData(Dataset): self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) if isinstance(image_transforms, ListConfig): image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))] + ) image_transforms = transforms.Compose(image_transforms) self.tform = image_transforms - def __len__(self): if self.captions is not None: return len(self.captions.keys()) @@ -94,7 +96,7 @@ class FolderData(Dataset): caption = self.captions.get(chosen, None) if caption is None: caption = self.default_caption - filename = self.root_dir/chosen + filename = self.root_dir / chosen else: filename = self.paths[index] @@ -119,22 +121,23 @@ class FolderData(Dataset): im = im.convert("RGB") return self.tform(im) + def hf_dataset( name, image_transforms=[], image_column="img", label_column="label", text_column="txt", - split='train', - image_key='image', - caption_key='txt', - ): - """Make huggingface dataset with appropriate list of transforms applied - """ + split="train", + image_key="image", + caption_key="txt", +): + """Make huggingface dataset with appropriate list of transforms applied""" ds = load_dataset(name, split=split) image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))] + ) tform = transforms.Compose(image_transforms) assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" @@ -144,7 +147,18 @@ def hf_dataset( processed = {} processed[image_key] = [tform(im) for im in examples[image_column]] - label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"} + label_to_text_dict = { + 0: "airplane", + 1: "automobile", + 2: "bird", + 3: "cat", + 4: "deer", + 5: "dog", + 6: "frog", + 7: "horse", + 8: "ship", + 9: "truck", + } processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]] @@ -153,6 +167,7 @@ def hf_dataset( ds.set_transform(pre_process) return ds + class TextOnly(Dataset): def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): """Returns only captions with dummy images""" @@ -166,7 +181,7 @@ class TextOnly(Dataset): if n_gpus > 1: # hack to make sure that all the captions appear on each gpu - repeated = [n_gpus*[x] for x in self.captions] + repeated = [n_gpus * [x] for x in self.captions] self.captions = [] [self.captions.extend(x) for x in repeated] @@ -175,10 +190,10 @@ class TextOnly(Dataset): def __getitem__(self, index): dummy_im = torch.zeros(3, self.output_size, self.output_size) - dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c') + dummy_im = rearrange(dummy_im * 2.0 - 1.0, "c h w -> h w c") return {self.image_key: dummy_im, self.caption_key: self.captions[index]} def _load_caption_file(self, filename): - with open(filename, 'rt') as f: + with open(filename, "rt") as f: captions = f.readlines() - return [x.strip('\n') for x in captions] \ No newline at end of file + return [x.strip("\n") for x in captions] diff --git a/examples/images/diffusion/ldm/data/imagenet.py b/examples/images/diffusion/ldm/data/imagenet.py index 1c473f9c6965b22315dbb289eff8247c71bdc790..8483e16ab23a44156148b278c36665265da843b1 100644 --- a/examples/images/diffusion/ldm/data/imagenet.py +++ b/examples/images/diffusion/ldm/data/imagenet.py @@ -1,32 +1,35 @@ -import os, yaml, pickle, shutil, tarfile, glob -import cv2 +import glob +import os +import pickle +import shutil +import tarfile +from functools import partial + import albumentations -import PIL +import cv2 import numpy as np +import PIL +import taming.data.utils as tdu import torchvision.transforms.functional as TF +import yaml +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light from omegaconf import OmegaConf -from functools import partial from PIL import Image -from tqdm import tqdm +from taming.data.imagenet import ImagePaths, download, give_synsets_from_indices, retrieve, str_to_indices from torch.utils.data import Dataset, Subset - -import taming.data.utils as tdu -from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve -from taming.data.imagenet import ImagePaths - -from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light +from tqdm import tqdm def synset2idx(path_to_yaml="data/index_synset.yaml"): with open(path_to_yaml) as f: di2s = yaml.load(f) - return dict((v,k) for k,v in di2s.items()) + return dict((v, k) for k, v in di2s.items()) class ImageNetBase(Dataset): def __init__(self, config=None): self.config = config or OmegaConf.create() - if not type(self.config)==dict: + if not type(self.config) == dict: self.config = OmegaConf.to_container(self.config) self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) self.process_images = True # if False we skip loading & processing images and self.data contains filepaths @@ -46,9 +49,11 @@ class ImageNetBase(Dataset): raise NotImplementedError() def _filter_relpaths(self, relpaths): - ignore = set([ - "n06596364_9591.JPEG", - ]) + ignore = set( + [ + "n06596364_9591.JPEG", + ] + ) relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] if "sub_indices" in self.config: indices = str_to_indices(self.config["sub_indices"]) @@ -67,20 +72,19 @@ class ImageNetBase(Dataset): SIZE = 2655750 URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" self.human_dict = os.path.join(self.root, "synset_human.txt") - if (not os.path.exists(self.human_dict) or - not os.path.getsize(self.human_dict)==SIZE): + if not os.path.exists(self.human_dict) or not os.path.getsize(self.human_dict) == SIZE: download(URL, self.human_dict) def _prepare_idx_to_synset(self): URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" self.idx2syn = os.path.join(self.root, "index_synset.yaml") - if (not os.path.exists(self.idx2syn)): + if not os.path.exists(self.idx2syn): download(URL, self.idx2syn) def _prepare_human_to_integer_label(self): URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") - if (not os.path.exists(self.human2integer)): + if not os.path.exists(self.human2integer): download(URL, self.human2integer) with open(self.human2integer, "r") as f: lines = f.read().splitlines() @@ -122,11 +126,12 @@ class ImageNetBase(Dataset): if self.process_images: self.size = retrieve(self.config, "size", default=256) - self.data = ImagePaths(self.abspaths, - labels=labels, - size=self.size, - random_crop=self.random_crop, - ) + self.data = ImagePaths( + self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) else: self.data = self.abspaths @@ -157,8 +162,7 @@ class ImageNetTrain(ImageNetBase): self.datadir = os.path.join(self.root, "data") self.txt_filelist = os.path.join(self.root, "filelist.txt") self.expected_length = 1281167 - self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", - default=True) + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", default=True) if not tdu.is_prepared(self.root): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) @@ -166,8 +170,9 @@ class ImageNetTrain(ImageNetBase): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -179,7 +184,7 @@ class ImageNetTrain(ImageNetBase): print("Extracting sub-tars.") subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) for subpath in tqdm(subpaths): - subdir = subpath[:-len(".tar")] + subdir = subpath[: -len(".tar")] os.makedirs(subdir, exist_ok=True) with tarfile.open(subpath, "r:") as tar: tar.extractall(path=subdir) @@ -187,7 +192,7 @@ class ImageNetTrain(ImageNetBase): filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + "\n" with open(self.txt_filelist, "w") as f: f.write(filelist) @@ -222,8 +227,7 @@ class ImageNetValidation(ImageNetBase): self.datadir = os.path.join(self.root, "data") self.txt_filelist = os.path.join(self.root, "filelist.txt") self.expected_length = 50000 - self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", - default=False) + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", default=False) if not tdu.is_prepared(self.root): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) @@ -231,8 +235,9 @@ class ImageNetValidation(ImageNetBase): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -242,7 +247,7 @@ class ImageNetValidation(ImageNetBase): tar.extractall(path=datadir) vspath = os.path.join(self.root, self.FILES[1]) - if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]: download(self.VS_URL, vspath) with open(vspath, "r") as f: @@ -261,18 +266,15 @@ class ImageNetValidation(ImageNetBase): filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + "\n" with open(self.txt_filelist, "w") as f: f.write(filelist) tdu.mark_prepared(self.root) - class ImageNetSR(Dataset): - def __init__(self, size=None, - degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., - random_crop=True): + def __init__(self, size=None, degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.0, random_crop=True): """ Imagenet Superresolution Dataloader Performs following ops in order: @@ -296,12 +298,12 @@ class ImageNetSR(Dataset): self.LR_size = int(size / downscale_f) self.min_crop_f = min_crop_f self.max_crop_f = max_crop_f - assert(max_crop_f <= 1.) + assert max_crop_f <= 1.0 self.center_crop = not random_crop self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) - self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow if degradation == "bsrgan": self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) @@ -311,17 +313,17 @@ class ImageNetSR(Dataset): else: interpolation_fn = { - "cv_nearest": cv2.INTER_NEAREST, - "cv_bilinear": cv2.INTER_LINEAR, - "cv_bicubic": cv2.INTER_CUBIC, - "cv_area": cv2.INTER_AREA, - "cv_lanczos": cv2.INTER_LANCZOS4, - "pil_nearest": PIL.Image.NEAREST, - "pil_bilinear": PIL.Image.BILINEAR, - "pil_bicubic": PIL.Image.BICUBIC, - "pil_box": PIL.Image.BOX, - "pil_hamming": PIL.Image.HAMMING, - "pil_lanczos": PIL.Image.LANCZOS, + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, }[degradation] self.pil_interpolation = degradation.startswith("pil_") @@ -330,8 +332,9 @@ class ImageNetSR(Dataset): self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) else: - self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, - interpolation=interpolation_fn) + self.degradation_process = albumentations.SmallestMaxSize( + max_size=self.LR_size, interpolation=interpolation_fn + ) def __len__(self): return len(self.base) @@ -366,8 +369,8 @@ class ImageNetSR(Dataset): else: LR_image = self.degradation_process(image=image)["image"] - example["image"] = (image/127.5 - 1.0).astype(np.float32) - example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32) return example @@ -379,7 +382,9 @@ class ImageNetSRTrain(ImageNetSR): def get_base(self): with open("data/imagenet_train_hr_indices.p", "rb") as f: indices = pickle.load(f) - dset = ImageNetTrain(process_images=False,) + dset = ImageNetTrain( + process_images=False, + ) return Subset(dset, indices) @@ -390,5 +395,7 @@ class ImageNetSRValidation(ImageNetSR): def get_base(self): with open("data/imagenet_val_hr_indices.p", "rb") as f: indices = pickle.load(f) - dset = ImageNetValidation(process_images=False,) + dset = ImageNetValidation( + process_images=False, + ) return Subset(dset, indices) diff --git a/examples/images/diffusion/ldm/data/lsun.py b/examples/images/diffusion/ldm/data/lsun.py index f5bf26c1425413f2d8ecb39913c2baccc42c5631..e5c374aa2d5124eeda9b9e31705172ef5f7fb372 100644 --- a/examples/images/diffusion/ldm/data/lsun.py +++ b/examples/images/diffusion/ldm/data/lsun.py @@ -1,47 +1,49 @@ import os + import numpy as np import PIL from PIL import Image from torch.utils.data import Dataset from torchvision import transforms + # This class is used to create a dataset of images from LSUN dataset for training class LSUNBase(Dataset): - def __init__(self, - txt_file, # path to the text file containing the list of image paths - data_root, # root directory of the LSUN dataset - size=None, # the size of images to resize to - interpolation="bicubic", # interpolation method to be used while resizing - flip_p=0.5 # probability of random horizontal flipping - ): - self.data_paths = txt_file # store path to text file containing list of images - self.data_root = data_root # store path to root directory of the dataset - with open(self.data_paths, "r") as f: # open and read the text file - self.image_paths = f.read().splitlines() # read the lines of the file and store as list - self._length = len(self.image_paths) # store the number of images - + def __init__( + self, + txt_file, # path to the text file containing the list of image paths + data_root, # root directory of the LSUN dataset + size=None, # the size of images to resize to + interpolation="bicubic", # interpolation method to be used while resizing + flip_p=0.5, # probability of random horizontal flipping + ): + self.data_paths = txt_file # store path to text file containing list of images + self.data_root = data_root # store path to root directory of the dataset + with open(self.data_paths, "r") as f: # open and read the text file + self.image_paths = f.read().splitlines() # read the lines of the file and store as list + self._length = len(self.image_paths) # store the number of images + # create dictionary to hold image path information self.labels = { "relative_file_path_": [l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, l) - for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) for l in self.image_paths], } # set the image size to be resized - self.size = size + self.size = size # set the interpolation method for resizing the image - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] + self.interpolation = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] # randomly flip the image horizontally with a given probability self.flip = transforms.RandomHorizontalFlip(p=flip_p) def __len__(self): # return the length of dataset return self._length - def __getitem__(self, i): # get the image path for the given index @@ -52,59 +54,71 @@ class LSUNBase(Dataset): image = image.convert("RGB") # default to score-sde preprocessing - - img = np.array(image).astype(np.uint8) # convert image to numpy array - crop = min(img.shape[0], img.shape[1]) # crop the image to a square shape - h, w, = img.shape[0], img.shape[1] # get the height and width of image - img = img[(h - crop) // 2:(h + crop) // 2, - (w - crop) // 2:(w + crop) // 2] # crop the image to a square shape - - image = Image.fromarray(img) # create an image from numpy array - if self.size is not None: # if image size is provided, resize the image + + img = np.array(image).astype(np.uint8) # convert image to numpy array + crop = min(img.shape[0], img.shape[1]) # crop the image to a square shape + ( + h, + w, + ) = ( + img.shape[0], + img.shape[1], + ) # get the height and width of image + img = img[ + (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2 + ] # crop the image to a square shape + + image = Image.fromarray(img) # create an image from numpy array + if self.size is not None: # if image size is provided, resize the image image = image.resize((self.size, self.size), resample=self.interpolation) - image = self.flip(image) # flip the image horizontally with the given probability - image = np.array(image).astype(np.uint8) + image = self.flip(image) # flip the image horizontally with the given probability + image = np.array(image).astype(np.uint8) example["image"] = (image / 127.5 - 1.0).astype(np.float32) # normalize the image values and convert to float32 - return example # return the example dictionary containing the image and its file paths + return example # return the example dictionary containing the image and its file paths + -#A dataset class for LSUN Churches training set. -# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. +# A dataset class for LSUN Churches training set. +# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. # The text file containing the paths to the images and the root directory where the images are stored are passed as arguments. Any additional keyword arguments passed to this class will be forwarded to the constructor of the parent class. class LSUNChurchesTrain(LSUNBase): def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) -#A dataset class for LSUN Churches validation set. + +# A dataset class for LSUN Churches validation set. # It is similar to LSUNChurchesTrain except that it uses a different text file and sets the flip probability to zero by default. class LSUNChurchesValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", - flip_p=flip_p, **kwargs) + def __init__(self, flip_p=0.0, **kwargs): + super().__init__( + txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", flip_p=flip_p, **kwargs + ) + -# A dataset class for LSUN Bedrooms training set. -# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. +# A dataset class for LSUN Bedrooms training set. +# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. class LSUNBedroomsTrain(LSUNBase): def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) -# A dataset class for LSUN Bedrooms validation set. + +# A dataset class for LSUN Bedrooms validation set. # It is similar to LSUNBedroomsTrain except that it uses a different text file and sets the flip probability to zero by default. class LSUNBedroomsValidation(LSUNBase): def __init__(self, flip_p=0.0, **kwargs): - super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", - flip_p=flip_p, **kwargs) + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", flip_p=flip_p, **kwargs) -# A dataset class for LSUN Cats training set. -# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. + +# A dataset class for LSUN Cats training set. +# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. # The text file containing the paths to the images and the root directory where the images are stored are passed as arguments. class LSUNCatsTrain(LSUNBase): def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) -# A dataset class for LSUN Cats validation set. + +# A dataset class for LSUN Cats validation set. # It is similar to LSUNCatsTrain except that it uses a different text file and sets the flip probability to zero by default. class LSUNCatsValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", - flip_p=flip_p, **kwargs) + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", flip_p=flip_p, **kwargs) diff --git a/examples/images/diffusion/ldm/data/teyvat.py b/examples/images/diffusion/ldm/data/teyvat.py index eb5d3ea469d4c68186f15008accbcc1a9ae8c8b7..4a50a78f2dbc5728f236f57d873ce7691d6445da 100644 --- a/examples/images/diffusion/ldm/data/teyvat.py +++ b/examples/images/diffusion/ldm/data/teyvat.py @@ -1,15 +1,16 @@ +import json +from pathlib import Path from typing import Dict -import numpy as np -from omegaconf import DictConfig, ListConfig + import torch -from torch.utils.data import Dataset -from pathlib import Path -import json -from PIL import Image -from torchvision import transforms +from datasets import load_dataset from einops import rearrange from ldm.util import instantiate_from_config -from datasets import load_dataset +from omegaconf import DictConfig, ListConfig +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + def make_multi_folder_data(paths, caption_files=None, **kwargs): """Make a concat dataset from multiple folders @@ -19,10 +20,9 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): """ list_of_paths = [] if isinstance(paths, (Dict, DictConfig)): - assert caption_files is None, \ - "Caption files not yet supported for repeats" + assert caption_files is None, "Caption files not yet supported for repeats" for folder_path, repeats in paths.items(): - list_of_paths.extend([folder_path]*repeats) + list_of_paths.extend([folder_path] * repeats) paths = list_of_paths if caption_files is not None: @@ -31,8 +31,10 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): datasets = [FolderData(p, **kwargs) for p in paths] return torch.utils.data.ConcatDataset(datasets) + class FolderData(Dataset): - def __init__(self, + def __init__( + self, root_dir, caption_file=None, image_transforms=[], @@ -40,7 +42,7 @@ class FolderData(Dataset): default_caption="", postprocess=None, return_paths=False, - ) -> None: + ) -> None: """Create a dataset from a folder of images. If you pass in a root directory it will be searched for images ending in ext (ext can be a list) @@ -75,12 +77,12 @@ class FolderData(Dataset): self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) if isinstance(image_transforms, ListConfig): image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))] + ) image_transforms = transforms.Compose(image_transforms) self.tform = image_transforms - def __len__(self): if self.captions is not None: return len(self.captions.keys()) @@ -94,7 +96,7 @@ class FolderData(Dataset): caption = self.captions.get(chosen, None) if caption is None: caption = self.default_caption - filename = self.root_dir/chosen + filename = self.root_dir / chosen else: filename = self.paths[index] @@ -119,23 +121,26 @@ class FolderData(Dataset): im = im.convert("RGB") return self.tform(im) + def hf_dataset( - path = "Fazzie/Teyvat", + path="Fazzie/Teyvat", image_transforms=[], image_column="image", text_column="text", - image_key='image', - caption_key='txt', - ): - """Make huggingface dataset with appropriate list of transforms applied - """ + image_key="image", + caption_key="txt", +): + """Make huggingface dataset with appropriate list of transforms applied""" ds = load_dataset(path, name="train") ds = ds["train"] image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.Resize((256, 256)), - transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))] - ) + image_transforms.extend( + [ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c")), + ] + ) tform = transforms.Compose(image_transforms) assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" @@ -149,4 +154,4 @@ def hf_dataset( return processed ds.set_transform(pre_process) - return ds \ No newline at end of file + return ds diff --git a/examples/images/diffusion/ldm/lr_scheduler.py b/examples/images/diffusion/ldm/lr_scheduler.py index be39da9ca6dacc22bf3df9c7389bbb403a4a3ade..f4efb12f28b85d87c0f5622917a29807b2cba049 100644 --- a/examples/images/diffusion/ldm/lr_scheduler.py +++ b/examples/images/diffusion/ldm/lr_scheduler.py @@ -5,18 +5,20 @@ class LambdaWarmUpCosineScheduler: """ note: use with a base_lr of 1.0 """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): self.lr_warm_up_steps = warm_up_steps self.lr_start = lr_start self.lr_min = lr_min self.lr_max = lr_max self.lr_max_decay_steps = max_decay_steps - self.last_lr = 0. + self.last_lr = 0.0 self.verbosity_interval = verbosity_interval def schedule(self, n, **kwargs): if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") if n < self.lr_warm_up_steps: lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start self.last_lr = lr @@ -24,13 +26,12 @@ class LambdaWarmUpCosineScheduler: else: t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) t = min(t, 1.0) - lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( - 1 + np.cos(t * np.pi)) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi)) self.last_lr = lr return lr def __call__(self, n, **kwargs): - return self.schedule(n,**kwargs) + return self.schedule(n, **kwargs) class LambdaWarmUpCosineScheduler2: @@ -38,6 +39,7 @@ class LambdaWarmUpCosineScheduler2: supports repeated iterations, configurable via lists note: use with a base_lr of 1.0. """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) self.lr_warm_up_steps = warm_up_steps @@ -46,7 +48,7 @@ class LambdaWarmUpCosineScheduler2: self.f_max = f_max self.cycle_lengths = cycle_lengths self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) - self.last_f = 0. + self.last_f = 0.0 self.verbosity_interval = verbosity_interval def find_in_interval(self, n): @@ -60,8 +62,8 @@ class LambdaWarmUpCosineScheduler2: cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f @@ -69,8 +71,7 @@ class LambdaWarmUpCosineScheduler2: else: t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( - 1 + np.cos(t * np.pi)) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) self.last_f = f return f @@ -79,20 +80,20 @@ class LambdaWarmUpCosineScheduler2: class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): - def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f return f else: - f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( + self.cycle_lengths[cycle] + ) self.last_f = f return f - diff --git a/examples/images/diffusion/ldm/models/autoencoder.py b/examples/images/diffusion/ldm/models/autoencoder.py index f0a69fe63a8ce73f6b95540d0cdd9e0adfdbf170..1c54dfe74f74abf03b3d6fbde0b90efb71ab2add 100644 --- a/examples/images/diffusion/ldm/models/autoencoder.py +++ b/examples/images/diffusion/ldm/models/autoencoder.py @@ -1,29 +1,28 @@ -import torch -import lightning.pytorch as pl - -from torch import nn -from torch.nn import functional as F -from torch.nn import Identity from contextlib import contextmanager -from ldm.modules.diffusionmodules.model import Encoder, Decoder +import lightning.pytorch as pl +import torch +from ldm.modules.diffusionmodules.model import Decoder, Encoder from ldm.modules.distributions.distributions import DiagonalGaussianDistribution from ldm.modules.ema import LitEma +from torch.nn import Identity +from torch.nn import functional as F class AutoencoderKL(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ema_decay=None, - learn_logvar=False - ): + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False, + ): super().__init__() self.learn_logvar = learn_logvar self.image_key = image_key @@ -31,11 +30,11 @@ class AutoencoderKL(pl.LightningModule): self.decoder = Decoder(**ddconfig) self.loss = Identity() assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim if colorize_nlabels is not None: - assert type(colorize_nlabels)==int + assert type(colorize_nlabels) == int self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor @@ -43,7 +42,7 @@ class AutoencoderKL(pl.LightningModule): self.use_ema = ema_decay is not None if self.use_ema: self.ema_decay = ema_decay - assert 0. < ema_decay < 1. + assert 0.0 < ema_decay < 1.0 self.model_ema = LitEma(self, decay=ema_decay) print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") @@ -113,16 +112,30 @@ class AutoencoderKL(pl.LightningModule): if optimizer_idx == 0: # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) return aeloss if optimizer_idx == 1: # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) @@ -137,11 +150,25 @@ class AutoencoderKL(pl.LightningModule): def _validation_step(self, batch, batch_idx, postfix=""): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) self.log_dict(log_dict_ae) @@ -150,15 +177,17 @@ class AutoencoderKL(pl.LightningModule): def configure_optimizers(self): lr = self.learning_rate - ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( - self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + ae_params_list = ( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()) + ) if self.learn_logvar: print(f"{self.__class__.__name__}: Learning logvar") ae_params_list.append(self.loss.logvar) - opt_ae = torch.optim.Adam(ae_params_list, - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) return [opt_ae, opt_disc], [] def get_last_layer(self): @@ -195,7 +224,7 @@ class AutoencoderKL(pl.LightningModule): if not hasattr(self, "colorize"): self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x @@ -217,4 +246,3 @@ class IdentityFirstStage(torch.nn.Module): def forward(self, x, *args, **kwargs): return x - diff --git a/examples/images/diffusion/ldm/models/diffusion/classifier.py b/examples/images/diffusion/ldm/models/diffusion/classifier.py index 3cf12f093beaa34d52119f4634a1922a35ee00eb..73aba26c9d894a90549ff8fce3c6e3c9e46b1be6 100644 --- a/examples/images/diffusion/ldm/models/diffusion/classifier.py +++ b/examples/images/diffusion/ldm/models/diffusion/classifier.py @@ -1,23 +1,21 @@ import os -import torch +from copy import deepcopy +from glob import glob + import lightning.pytorch as pl +import torch +from einops import rearrange +from ldm.lr_scheduler import LambdaLinearScheduler +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import default, ismap, log_txt_as_img +from natsort import natsorted from omegaconf import OmegaConf from torch.nn import functional as F from torch.optim import AdamW from torch.optim.lr_scheduler import LambdaLR -from copy import deepcopy -from einops import rearrange -from glob import glob -from natsort import natsorted -from ldm.models.diffusion.ddpm import LatentDiffusion -from ldm.lr_scheduler import LambdaLinearScheduler -from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel -from ldm.util import log_txt_as_img, default, ismap -__models__ = { - 'class_label': EncoderUNetModel, - 'segmentation': UNetModel -} +__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel} def disabled_train(self, mode=True): @@ -27,24 +25,25 @@ def disabled_train(self, mode=True): class NoisyLatentImageClassifier(pl.LightningModule): - - def __init__(self, - diffusion_path, - num_classes, - ckpt_path=None, - pool='attention', - label_key=None, - diffusion_ckpt_path=None, - scheduler_config=None, - weight_decay=1.e-2, - log_steps=10, - monitor='val/loss', - *args, - **kwargs): + def __init__( + self, + diffusion_path, + num_classes, + ckpt_path=None, + pool="attention", + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.0e-2, + log_steps=10, + monitor="val/loss", + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.num_classes = num_classes # get latest config of diffusion model - diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + diffusion_config = natsorted(glob(os.path.join(diffusion_path, "configs", "*-project.yaml")))[-1] self.diffusion_config = OmegaConf.load(diffusion_config).model self.diffusion_config.params.ckpt_path = diffusion_ckpt_path self.load_diffusion() @@ -54,10 +53,11 @@ class NoisyLatentImageClassifier(pl.LightningModule): self.log_time_interval = self.diffusion_model.num_timesteps // log_steps self.log_steps = log_steps - self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ - else self.diffusion_model.cond_stage_key + self.label_key = ( + label_key if not hasattr(self.diffusion_model, "cond_stage_key") else self.diffusion_model.cond_stage_key + ) - assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + assert self.label_key is not None, "label_key neither in diffusion model nor in model.params" if self.label_key not in __models__: raise NotImplementedError() @@ -78,8 +78,9 @@ class NoisyLatentImageClassifier(pl.LightningModule): if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: print(f"Missing Keys: {missing}") @@ -87,7 +88,7 @@ class NoisyLatentImageClassifier(pl.LightningModule): print(f"Unexpected Keys: {unexpected}") def load_diffusion(self): - model = LatentDiffusion(**self.diffusion_config.get('params',dict())) + model = LatentDiffusion(**self.diffusion_config.get("params", dict())) self.diffusion_model = model.eval() self.diffusion_model.train = disabled_train for param in self.diffusion_model.parameters(): @@ -97,14 +98,14 @@ class NoisyLatentImageClassifier(pl.LightningModule): model_config = deepcopy(self.diffusion_config.params.unet_config.params) model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels model_config.out_channels = self.num_classes - if self.label_key == 'class_label': + if self.label_key == "class_label": model_config.pool = pool self.model = __models__[self.label_key](**model_config) if ckpt_path is not None: - print('#####################################################################') + print("#####################################################################") print(f'load from ckpt "{ckpt_path}"') - print('#####################################################################') + print("#####################################################################") self.init_from_ckpt(ckpt_path) @torch.no_grad() @@ -115,8 +116,9 @@ class NoisyLatentImageClassifier(pl.LightningModule): continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) # todo: make sure t+1 is correct here - return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, - continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + return self.diffusion_model.q_sample( + x_start=x, t=t, noise=noise, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod + ) def forward(self, x_noisy, t, *args, **kwargs): return self.model(x_noisy, t) @@ -126,7 +128,7 @@ class NoisyLatentImageClassifier(pl.LightningModule): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') + x = rearrange(x, "b h w c -> b c h w") x = x.to(memory_format=torch.contiguous_format).float() return x @@ -134,15 +136,15 @@ class NoisyLatentImageClassifier(pl.LightningModule): def get_conditioning(self, batch, k=None): if k is None: k = self.label_key - assert k is not None, 'Needs to provide label key' + assert k is not None, "Needs to provide label key" targets = batch[k].to(self.device) - if self.label_key == 'segmentation': - targets = rearrange(targets, 'b h w c -> b c h w') + if self.label_key == "segmentation": + targets = rearrange(targets, "b h w c -> b c h w") for down in range(self.numd): h, w = targets.shape[-2:] - targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest") # targets = rearrange(targets,'b c h w -> b h w c') @@ -157,25 +159,21 @@ class NoisyLatentImageClassifier(pl.LightningModule): def on_train_epoch_start(self): # save some memory - self.diffusion_model.model.to('cpu') + self.diffusion_model.model.to("cpu") @torch.no_grad() def write_logs(self, loss, logits, targets): - log_prefix = 'train' if self.training else 'val' + log_prefix = "train" if self.training else "val" log = {} log[f"{log_prefix}/loss"] = loss.mean() - log[f"{log_prefix}/acc@1"] = self.compute_top_k( - logits, targets, k=1, reduction="mean" - ) - log[f"{log_prefix}/acc@5"] = self.compute_top_k( - logits, targets, k=5, reduction="mean" - ) + log[f"{log_prefix}/acc@1"] = self.compute_top_k(logits, targets, k=1, reduction="mean") + log[f"{log_prefix}/acc@5"] = self.compute_top_k(logits, targets, k=5, reduction="mean") self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) - self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) - self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) - lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log("global_step", self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]["lr"] + self.log("lr_abs", lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) def shared_step(self, batch, t=None): x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) @@ -189,7 +187,7 @@ class NoisyLatentImageClassifier(pl.LightningModule): x_noisy = self.get_x_noisy(x, t) logits = self(x_noisy, t) - loss = F.cross_entropy(logits, targets, reduction='none') + loss = F.cross_entropy(logits, targets, reduction="none") self.write_logs(loss.detach(), logits.detach(), targets.detach()) @@ -201,8 +199,10 @@ class NoisyLatentImageClassifier(pl.LightningModule): return loss def reset_noise_accs(self): - self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in - range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + self.noisy_acc = { + t: {"acc@1": [], "acc@5": []} + for t in range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t) + } def on_validation_start(self): self.reset_noise_accs() @@ -213,8 +213,8 @@ class NoisyLatentImageClassifier(pl.LightningModule): for t in self.noisy_acc: _, logits, _, targets = self.shared_step(batch, t) - self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) - self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + self.noisy_acc[t]["acc@1"].append(self.compute_top_k(logits, targets, k=1, reduction="mean")) + self.noisy_acc[t]["acc@5"].append(self.compute_top_k(logits, targets, k=5, reduction="mean")) return loss @@ -222,15 +222,12 @@ class NoisyLatentImageClassifier(pl.LightningModule): optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) if self.use_scheduler: - scheduler = LambdaLinearScheduler(**self.scheduler_config.get('params',dict())) + scheduler = LambdaLinearScheduler(**self.scheduler_config.get("params", dict())) print("Setting up LambdaLR scheduler...") scheduler = [ - { - 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }] + {"scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1} + ] return [optimizer], scheduler return optimizer @@ -239,28 +236,28 @@ class NoisyLatentImageClassifier(pl.LightningModule): def log_images(self, batch, N=8, *args, **kwargs): log = dict() x = self.get_input(batch, self.diffusion_model.first_stage_key) - log['inputs'] = x + log["inputs"] = x y = self.get_conditioning(batch) - if self.label_key == 'class_label': + if self.label_key == "class_label": y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) - log['labels'] = y + log["labels"] = y if ismap(y): - log['labels'] = self.diffusion_model.to_rgb(y) + log["labels"] = self.diffusion_model.to_rgb(y) for step in range(self.log_steps): current_time = step * self.log_time_interval _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) - log[f'inputs@t{current_time}'] = x_noisy + log[f"inputs@t{current_time}"] = x_noisy pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) - pred = rearrange(pred, 'b h w c -> b c h w') + pred = rearrange(pred, "b h w c -> b c h w") - log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred) for key in log: log[key] = log[key][:N] diff --git a/examples/images/diffusion/ldm/models/diffusion/ddim.py b/examples/images/diffusion/ldm/models/diffusion/ddim.py index 27ead0ea914c64c747b64e690662899fb3801144..a9e28792f86494b7dfbde1bd737f32c0f583c453 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddim.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddim.py @@ -1,11 +1,15 @@ """SAMPLING ONLY.""" -import torch import numpy as np +import torch +from ldm.modules.diffusionmodules.util import ( + extract_into_tensor, + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, +) from tqdm import tqdm -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor - class DDIMSampler(object): def __init__(self, model, schedule="linear", **kwargs): @@ -20,67 +24,75 @@ class DDIMSampler(object): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep" to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1))) # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - ucg_schedule=None, - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] + while isinstance(ctmp, list): + ctmp = ctmp[0] cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") @@ -98,35 +110,53 @@ class DDIMSampler(object): # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule - ) + print(f"Data shape for DDIM sampling is {size}, eta {eta}") + + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule, + ) return samples, intermediates @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None): + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ucg_schedule=None, + ): device = self.model.betas.device b = shape[0] if x_T is None: @@ -140,12 +170,12 @@ class DDIMSampler(object): subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -154,37 +184,60 @@ class DDIMSampler(object): if mask is not None: assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img + img = img_orig * mask + (1.0 - mask) * img if ucg_schedule is not None: assert len(ucg_schedule) == len(time_range) unconditional_guidance_scale = ucg_schedule[i] - outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold) + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) return img, intermediates @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): b, *_, device = *x.shape, x.device - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: model_output = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) @@ -194,13 +247,9 @@ class DDIMSampler(object): c_in = dict() for k in c: if isinstance(c[k], list): - c_in[k] = [torch.cat([ - unconditional_conditioning[k][i], - c[k][i]]) for i in range(len(c[k]))] + c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))] else: - c_in[k] = torch.cat([ - unconditional_conditioning[k], - c[k]]) + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) elif isinstance(c, list): c_in = list() assert isinstance(unconditional_conditioning, list) @@ -217,18 +266,20 @@ class DDIMSampler(object): e_t = model_output if score_corrector is not None: - assert self.model.parameterization == "eps", 'not implemented' + assert self.model.parameterization == "eps", "not implemented" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + ) sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) # current prediction for x_0 if self.model.parameterization != "v": @@ -243,16 +294,25 @@ class DDIMSampler(object): raise NotImplementedError() # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 @torch.no_grad() - def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, - unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + def encode( + self, + x0, + c, + t_enc, + use_original_steps=False, + return_intermediates=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + callback=None, + ): num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] assert t_enc <= num_reference_steps @@ -268,33 +328,37 @@ class DDIMSampler(object): x_next = x0 intermediates = [] inter_steps = [] - for i in tqdm(range(num_steps), desc='Encoding Image'): + for i in tqdm(range(num_steps), desc="Encoding Image"): t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) - if unconditional_guidance_scale == 1.: + if unconditional_guidance_scale == 1.0: noise_pred = self.model.apply_model(x_next, t, c) else: assert unconditional_conditioning is not None e_t_uncond, noise_pred = torch.chunk( - self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), - torch.cat((unconditional_conditioning, c))), 2) + self.model.apply_model( + torch.cat((x_next, x_next)), torch.cat((t, t)), torch.cat((unconditional_conditioning, c)) + ), + 2, + ) noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next - weighted_noise_pred = alphas_next[i].sqrt() * ( - (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + weighted_noise_pred = ( + alphas_next[i].sqrt() * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + ) x_next = xt_weighted + weighted_noise_pred - if return_intermediates and i % ( - num_steps // return_intermediates) == 0 and i < num_steps - 1: + if return_intermediates and i % (num_steps // return_intermediates) == 0 and i < num_steps - 1: intermediates.append(x_next) inter_steps.append(i) elif return_intermediates and i >= num_steps - 2: intermediates.append(x_next) inter_steps.append(i) - if callback: callback(i) + if callback: + callback(i) - out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + out = {"x_encoded": x_next, "intermediate_steps": inter_steps} if return_intermediates: - out.update({'intermediates': intermediates}) + out.update({"intermediates": intermediates}) return x_next, out @torch.no_grad() @@ -310,13 +374,22 @@ class DDIMSampler(object): if noise is None: noise = torch.randn_like(x0) - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + - extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False, callback=None): - + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + callback=None, + ): timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps timesteps = timesteps[:t_start] @@ -324,13 +397,20 @@ class DDIMSampler(object): total_steps = timesteps.shape[0] print(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) x_dec = x_latent for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - if callback: callback(i) - return x_dec \ No newline at end of file + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + if callback: + callback(i) + return x_dec diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py index 842ec1371ea09fdb023f584e0525b815111ad2be..20e26256e18e950d01d241b86754b3def5cf92c4 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddpm.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -27,23 +27,22 @@ from ldm.models.autoencoder import * from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage from ldm.models.diffusion.ddim import * from ldm.models.diffusion.ddim import DDIMSampler -from ldm.modules.midas.api import MiDaSInference from ldm.modules.diffusionmodules.model import * -from ldm.modules.diffusionmodules.model import Decoder, Encoder, Model from ldm.modules.diffusionmodules.openaimodel import * -from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d, UNetModel +from ldm.modules.diffusionmodules.openaimodel import UNetModel +from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl -from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from ldm.modules.ema import LitEma from ldm.modules.encoders.modules import * +from ldm.modules.midas.api import MiDaSInference from ldm.util import count_params, default, exists, isimage, ismap, log_txt_as_img, mean_flat from omegaconf import ListConfig from torch.optim.lr_scheduler import LambdaLR from torchvision.utils import make_grid from tqdm import tqdm -__conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'} +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} def disabled_train(self, mode=True): @@ -78,15 +77,15 @@ class DDPM(pl.LightningModule): linear_end=2e-2, cosine_s=8e-3, given_betas=None, - original_elbo_weight=0., - v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1., + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules + parameterization="eps", # all assuming fixed variance schedules scheduler_config=None, use_positional_encodings=False, learn_logvar=False, - logvar_init=0., + logvar_init=0.0, use_fp16=True, make_it_fit=False, ucg_training=None, @@ -133,9 +132,9 @@ class DDPM(pl.LightningModule): if reset_ema: assert exists(ckpt) - ''' + """ Uncomment if you Use DDP Strategy - ''' + """ # if ckpt is not None: # self.init_from_ckpt(ckpt, ignore_keys=ignore_keys, only_model=load_only_unet) # if reset_ema: @@ -155,12 +154,14 @@ class DDPM(pl.LightningModule): self.linear_end = linear_end self.cosine_s = cosine_s - self.register_schedule(given_betas=given_betas, - beta_schedule=beta_schedule, - timesteps=timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s) + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) self.loss_type = loss_type @@ -174,67 +175,73 @@ class DDPM(pl.LightningModule): if self.ucg_training: self.ucg_prng = np.random.RandomState() - def register_schedule(self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3): + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): if exists(given_betas): betas = given_betas else: - betas = make_beta_schedule(beta_schedule, - timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. - betas + betas = make_beta_schedule( + beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s + ) + alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( - 1. - alphas_cumprod) + self.v_posterior * betas + posterior_variance = (1 - self.v_posterior) * betas * (1.0 - alphas_cumprod_prev) / ( + 1.0 - alphas_cumprod + ) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer('posterior_variance', to_torch(posterior_variance)) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) - self.register_buffer('posterior_mean_coef1', - to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) - self.register_buffer('posterior_mean_coef2', - to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer( + "posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + ) + self.register_buffer( + "posterior_mean_coef2", to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)) + ) if self.parameterization == "eps": - lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + lvlb_weights = self.betas**2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod) + ) elif self.parameterization == "x0": - lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) elif self.parameterization == "v": - lvlb_weights = torch.ones_like(self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * - (1 - self.alphas_cumprod))) + lvlb_weights = torch.ones_like( + self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + ) else: raise NotImplementedError("mu not supported") lvlb_weights[0] = lvlb_weights[1] - self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) assert not torch.isnan(self.lvlb_weights).all() @contextmanager @@ -265,9 +272,11 @@ class DDPM(pl.LightningModule): del sd[k] if self.make_it_fit: n_params = len([name for name, _ in itertools.chain(self.named_parameters(), self.named_buffers())]) - for name, param in tqdm(itertools.chain(self.named_parameters(), self.named_buffers()), - desc="Fitting old weights to new weights", - total=n_params): + for name, param in tqdm( + itertools.chain(self.named_parameters(), self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params, + ): if not name in sd: continue old_shape = sd[name].shape @@ -302,8 +311,9 @@ class DDPM(pl.LightningModule): sd[name] = new_param - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: rank_zero_info(f"Missing Keys:\n {missing}") @@ -317,28 +327,36 @@ class DDPM(pl.LightningModule): :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ - mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): - return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise) + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) def predict_start_from_z_and_v(self, x_t, t, v): # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) def predict_eps_from_z_and_v(self, x_t, t, v): - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) def q_posterior(self, x_start, x_t, t): - posterior_mean = (extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t) + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped @@ -350,7 +368,7 @@ class DDPM(pl.LightningModule): elif self.parameterization == "x0": x_recon = model_out if clip_denoised: - x_recon.clamp_(-1., 1.) + x_recon.clamp_(-1.0, 1.0) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @@ -370,10 +388,10 @@ class DDPM(pl.LightningModule): b = shape[0] img = torch.randn(shape, device=device) intermediates = [img] - for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): - img = self.p_sample(img, - torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised) + for i in tqdm(reversed(range(0, self.num_timesteps)), desc="Sampling t", total=self.num_timesteps): + img = self.p_sample( + img, torch.full((b,), i, device=device, dtype=torch.long), clip_denoised=self.clip_denoised + ) if i % self.log_every_t == 0 or i == self.num_timesteps - 1: intermediates.append(img) if return_intermediates: @@ -384,28 +402,33 @@ class DDPM(pl.LightningModule): def sample(self, batch_size=16, return_intermediates=False): image_size = self.image_size channels = self.channels - return self.p_sample_loop((batch_size, channels, image_size, image_size), - return_intermediates=return_intermediates) + return self.p_sample_loop( + (batch_size, channels, image_size, image_size), return_intermediates=return_intermediates + ) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) def get_v(self, x, noise, t): - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) def get_loss(self, pred, target, mean=True): - if self.loss_type == 'l1': + if self.loss_type == "l1": loss = (target - pred).abs() if mean: loss = loss.mean() - elif self.loss_type == 'l2': + elif self.loss_type == "l2": if mean: loss = torch.nn.functional.mse_loss(target, pred) else: - loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + loss = torch.nn.functional.mse_loss(target, pred, reduction="none") else: raise NotImplementedError("unknown loss type '{loss_type}'") @@ -428,17 +451,17 @@ class DDPM(pl.LightningModule): loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) - log_prefix = 'train' if self.training else 'val' + log_prefix = "train" if self.training else "val" - loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()}) loss_simple = loss.mean() * self.l_simple_weight loss_vlb = (self.lvlb_weights[t] * loss).mean() - loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb}) loss = loss_simple + self.original_elbo_weight * loss_vlb - loss_dict.update({f'{log_prefix}/loss': loss}) + loss_dict.update({f"{log_prefix}/loss": loss}) return loss, loss_dict @@ -452,7 +475,7 @@ class DDPM(pl.LightningModule): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') + x = rearrange(x, "b h w c -> b c h w") if self.use_fp16: x = x.to(memory_format=torch.contiguous_format).half() else: @@ -481,8 +504,8 @@ class DDPM(pl.LightningModule): self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) if self.use_scheduler: - lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + lr = self.optimizers().param_groups[0]["lr"] + self.log("lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) return loss @@ -491,7 +514,7 @@ class DDPM(pl.LightningModule): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema} self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) @@ -501,8 +524,8 @@ class DDPM(pl.LightningModule): def _get_rows_from_list(self, samples): n_imgs_per_row = len(samples) - denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = rearrange(samples, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @@ -521,7 +544,7 @@ class DDPM(pl.LightningModule): for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) @@ -556,29 +579,31 @@ class DDPM(pl.LightningModule): class LatentDiffusion(DDPM): """main class""" - def __init__(self, - first_stage_config, - cond_stage_config, - num_timesteps_cond=None, - cond_stage_key="image", - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - use_fp16=True, - force_null_conditioning=False, - *args, - **kwargs): + def __init__( + self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + use_fp16=True, + force_null_conditioning=False, + *args, + **kwargs, + ): self.force_null_conditioning = force_null_conditioning self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std - assert self.num_timesteps_cond <= kwargs['timesteps'] + assert self.num_timesteps_cond <= kwargs["timesteps"] # for backwards compatibility after implementation of DiffusionWrapper if conditioning_key is None: - conditioning_key = 'concat' if concat_mode else 'crossattn' - if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: + conditioning_key = "concat" if concat_mode else "crossattn" + if cond_stage_config == "__is_unconditional__" and not self.force_null_conditioning: conditioning_key = None super().__init__(conditioning_key=conditioning_key, *args, **kwargs) @@ -593,7 +618,7 @@ class LatentDiffusion(DDPM): if not scale_by_std: self.scale_factor = scale_factor else: - self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.register_buffer("scale_factor", torch.tensor(scale_factor)) self.first_stage_config = first_stage_config self.cond_stage_config = cond_stage_config self.instantiate_first_stage(first_stage_config) @@ -601,9 +626,9 @@ class LatentDiffusion(DDPM): self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.bbox_tokenizer = None - ''' + """ Uncomment if you Use DDP Strategy - ''' + """ # self.restarted_from_ckpt = False # if self.ckpt is not None: # self.init_from_ckpt(self.ckpt, self.ignore_keys) @@ -630,15 +655,18 @@ class LatentDiffusion(DDPM): if self.reset_ema: assert self.use_ema rank_zero_info( - f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." + ) self.model_ema = LitEma(self.model) - self.register_schedule(given_betas=self.given_betas, - beta_schedule=self.beta_schedule, - timesteps=self.timesteps, - linear_start=self.linear_start, - linear_end=self.linear_end, - cosine_s=self.cosine_s) + self.register_schedule( + given_betas=self.given_betas, + beta_schedule=self.beta_schedule, + timesteps=self.timesteps, + linear_start=self.linear_start, + linear_end=self.linear_end, + cosine_s=self.cosine_s, + ) self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,)) if self.learn_logvar: @@ -654,20 +682,29 @@ class LatentDiffusion(DDPM): if self.reset_ema: assert self.use_ema rank_zero_info( - f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." + ) self.model_ema = LitEma(self.model) - def make_cond_schedule(self,): + def make_cond_schedule( + self, + ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() - self.cond_ids[:self.num_timesteps_cond] = ids + self.cond_ids[: self.num_timesteps_cond] = ids @rank_zero_only @torch.no_grad() def on_train_batch_start(self, batch, batch_idx): # only for very first batch - if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: - assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + if ( + self.scale_by_std + and self.current_epoch == 0 + and self.global_step == 0 + and batch_idx == 0 + and not self.restarted_from_ckpt + ): + assert self.scale_factor == 1.0, "rather not use custom rescaling and std-rescaling simultaneously" # set rescale weight to 1./std of encodings rank_zero_info("### USING STD-RESCALING ###") x = super().get_input(batch, self.first_stage_key) @@ -675,17 +712,19 @@ class LatentDiffusion(DDPM): encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() del self.scale_factor - self.register_buffer('scale_factor', 1. / z.flatten().std()) + self.register_buffer("scale_factor", 1.0 / z.flatten().std()) rank_zero_info(f"setting self.scale_factor to {self.scale_factor}") rank_zero_info("### USING STD-RESCALING ###") - def register_schedule(self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3): + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) self.shorten_cond_schedule = self.num_timesteps_cond > 1 @@ -718,15 +757,16 @@ class LatentDiffusion(DDPM): model = FrozenOpenCLIPEmbedder(**config) self.cond_stage_model = model - def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + def _get_denoise_row_from_list(self, samples, desc="", force_no_decoder_quantization=False): denoise_row = [] for zd in tqdm(samples, desc=desc): denoise_row.append( - self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization)) + self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization) + ) n_imgs_per_row = len(denoise_row) - denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W - denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @@ -741,7 +781,7 @@ class LatentDiffusion(DDPM): def get_learned_conditioning(self, c): if self.cond_stage_forward is None: - if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + if hasattr(self.cond_stage_model, "encode") and callable(self.cond_stage_model.encode): c = self.cond_stage_model.encode(c) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() @@ -784,14 +824,17 @@ class LatentDiffusion(DDPM): if self.split_input_params["tie_braker"]: L_weighting = self.delta_border(Ly, Lx) - L_weighting = torch.clip(L_weighting, self.split_input_params["clip_min_tie_weight"], - self.split_input_params["clip_max_tie_weight"]) + L_weighting = torch.clip( + L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"], + ) L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) weighting = weighting * L_weighting return weighting - def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code """ :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) @@ -809,35 +852,39 @@ class LatentDiffusion(DDPM): fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) elif uf > 1 and df == 1: fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) unfold = torch.nn.Unfold(**fold_params) - fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), - dilation=1, - padding=0, - stride=(stride[0] * uf, stride[1] * uf)) + fold_params2 = dict( + kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf), + ) fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) elif df > 1 and uf == 1: fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) unfold = torch.nn.Unfold(**fold_params) - fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), - dilation=1, - padding=0, - stride=(stride[0] // df, stride[1] // df)) + fold_params2 = dict( + kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df), + ) fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) else: @@ -846,15 +893,17 @@ class LatentDiffusion(DDPM): return fold, unfold, normalization, weighting @torch.no_grad() - def get_input(self, - batch, - k, - return_first_stage_outputs=False, - force_c_encode=False, - cond_key=None, - return_original_cond=False, - bs=None, - return_x=False): + def get_input( + self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + return_x=False, + ): x = super().get_input(batch, k) if bs is not None: x = x[:bs] @@ -866,9 +915,9 @@ class LatentDiffusion(DDPM): if cond_key is None: cond_key = self.cond_stage_key if cond_key != self.first_stage_key: - if cond_key in ['caption', 'coordinates_bbox', "txt"]: + if cond_key in ["caption", "coordinates_bbox", "txt"]: xc = batch[cond_key] - elif cond_key in ['class_label', 'cls']: + elif cond_key in ["class_label", "cls"]: xc = batch else: xc = super().get_input(batch, cond_key).to(self.device) @@ -887,14 +936,14 @@ class LatentDiffusion(DDPM): if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) ckey = __conditioning_keys__[self.model.conditioning_key] - c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y} else: c = None xc = None if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) - c = {'pos_x': pos_x, 'pos_y': pos_y} + c = {"pos_x": pos_x, "pos_y": pos_y} out = [z, c] if return_first_stage_outputs: xrec = self.decode_first_stage(z) @@ -912,9 +961,9 @@ class LatentDiffusion(DDPM): if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, 'b h w c -> b c h w').contiguous() + z = rearrange(z, "b h w c -> b c h w").contiguous() - z = 1. / self.scale_factor * z + z = 1.0 / self.scale_factor * z return self.first_stage_model.decode(z) @torch.no_grad() @@ -932,7 +981,7 @@ class LatentDiffusion(DDPM): assert c is not None if self.cond_stage_trainable: c = self.get_learned_conditioning(c) - if self.shorten_cond_schedule: # TODO: drop this option + if self.shorten_cond_schedule: # TODO: drop this option tc = self.cond_ids[t].to(self.device) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) @@ -944,7 +993,7 @@ class LatentDiffusion(DDPM): else: if not isinstance(cond, list): cond = [cond] - key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + key = "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn" cond = {key: cond} x_recon = self.model(x_noisy, t, **cond) @@ -955,8 +1004,9 @@ class LatentDiffusion(DDPM): return x_recon def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _prior_bpd(self, x_start): """ @@ -978,7 +1028,7 @@ class LatentDiffusion(DDPM): model_output = self.apply_model(x_noisy, t, cond) loss_dict = {} - prefix = 'train' if self.training else 'val' + prefix = "train" if self.training else "val" if self.parameterization == "x0": target = x_start @@ -990,36 +1040,38 @@ class LatentDiffusion(DDPM): raise NotImplementedError() loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) - loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()}) logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: - loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) - loss_dict.update({'logvar': self.logvar.data.mean()}) + loss_dict.update({f"{prefix}/loss_gamma": loss.mean()}) + loss_dict.update({"logvar": self.logvar.data.mean()}) loss = self.l_simple_weight * loss.mean() loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) - loss += (self.original_elbo_weight * loss_vlb) - loss_dict.update({f'{prefix}/loss': loss}) + loss_dict.update({f"{prefix}/loss_vlb": loss_vlb}) + loss += self.original_elbo_weight * loss_vlb + loss_dict.update({f"{prefix}/loss": loss}) return loss, loss_dict - def p_mean_variance(self, - x, - c, - t, - clip_denoised: bool, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - score_corrector=None, - corrector_kwargs=None): + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): t_in = t model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) @@ -1038,7 +1090,7 @@ class LatentDiffusion(DDPM): raise NotImplementedError() if clip_denoised: - x_recon.clamp_(-1., 1.) + x_recon.clamp_(-1.0, 1.0) if quantize_denoised: x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) @@ -1050,29 +1102,33 @@ class LatentDiffusion(DDPM): return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample(self, - x, - c, - t, - clip_denoised=False, - repeat_noise=False, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None): + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance(x=x, - c=c, - t=t, - clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs) + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) if return_codebook_ids: raise DeprecationWarning("Support dropped.") model_mean, _, model_log_variance, logits = outputs @@ -1082,7 +1138,7 @@ class LatentDiffusion(DDPM): model_mean, _, model_log_variance = outputs noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) @@ -1095,23 +1151,25 @@ class LatentDiffusion(DDPM): return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() - def progressive_denoising(self, - cond, - shape, - verbose=True, - callback=None, - quantize_denoised=False, - img_callback=None, - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - batch_size=None, - x_T=None, - start_T=None, - log_every_t=None): + def progressive_denoising( + self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): if not log_every_t: log_every_t = self.log_every_t timesteps = self.num_timesteps @@ -1128,40 +1186,47 @@ class LatentDiffusion(DDPM): if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] if not isinstance(cond[key], list) else list( - map(lambda x: x[:batch_size], cond[key])) for key in cond + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond } else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] if start_T is not None: timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', - total=timesteps) if verbose else reversed(range(0, timesteps)) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Progressive Generation", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) if type(temperature) == float: temperature = [temperature] * timesteps for i in iterator: ts = torch.full((b,), i, device=self.device, dtype=torch.long) if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' + assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - img, x0_partial = self.p_sample(img, - cond, - ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, - return_x0=True, - temperature=temperature[i], - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs) + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) if mask is not None: assert x0 is not None img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img + img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) @@ -1172,21 +1237,22 @@ class LatentDiffusion(DDPM): return img, intermediates @torch.no_grad() - def p_sample_loop(self, - cond, - shape, - return_intermediates=False, - x_T=None, - verbose=True, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - start_T=None, - log_every_t=None): - + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): if not log_every_t: log_every_t = self.log_every_t device = self.betas.device @@ -1202,24 +1268,27 @@ class LatentDiffusion(DDPM): if start_T is not None: timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( - range(0, timesteps)) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) if mask is not None: assert x0 is not None - assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match for i in iterator: ts = torch.full((b,), i, device=device, dtype=torch.long) if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' + assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised) if mask is not None: img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img + img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) @@ -1233,37 +1302,43 @@ class LatentDiffusion(DDPM): return img @torch.no_grad() - def sample(self, - cond, - batch_size=16, - return_intermediates=False, - x_T=None, - verbose=True, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - shape=None, - **kwargs): + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): if shape is None: shape = (batch_size, self.channels, self.image_size, self.image_size) if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] if not isinstance(cond[key], list) else list( - map(lambda x: x[:batch_size], cond[key])) for key in cond + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond } else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] - return self.p_sample_loop(cond, - shape, - return_intermediates=return_intermediates, - x_T=x_T, - verbose=verbose, - timesteps=timesteps, - quantize_denoised=quantize_denoised, - mask=mask, - x0=x0) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + ) @torch.no_grad() def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): @@ -1295,41 +1370,45 @@ class LatentDiffusion(DDPM): return self.get_learned_conditioning(xc) else: raise NotImplementedError("todo") - if isinstance(c, list): # in case the encoder gives us a list + if isinstance(c, list): # in case the encoder gives us a list for i in range(len(c)): - c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) + c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device) else: - c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) return c @torch.no_grad() - def log_images(self, - batch, - N=8, - n_row=4, - sample=True, - ddim_steps=50, - ddim_eta=0., - return_keys=None, - quantize_denoised=True, - inpaint=True, - plot_denoise_rows=False, - plot_progressive_rows=True, - plot_diffusion_rows=True, - unconditional_guidance_scale=1., - unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=50, + ddim_eta=0.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() - z, c, x, xrec, xc = self.get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=N) + z, c, x, xrec, xc = self.get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N, + ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x @@ -1341,10 +1420,10 @@ class LatentDiffusion(DDPM): elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', "cls"]: + elif self.cond_stage_key in ["class_label", "cls"]: try: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc except KeyError: # probably no "human_label" in batch pass @@ -1359,26 +1438,24 @@ class LatentDiffusion(DDPM): z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1386,16 +1463,16 @@ class LatentDiffusion(DDPM): denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log["denoise_row"] = denoise_grid - if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( - self.first_stage_model, IdentityFirstStage): + if ( + quantize_denoised + and not isinstance(self.first_stage_model, AutoencoderKL) + and not isinstance(self.first_stage_model, IdentityFirstStage) + ): # also display when quantizing x0 while sampling with ema_scope("Plotting Quantized Denoised"): - samples, z_denoise_row = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - quantize_denoised=True) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, quantize_denoised=True + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, # quantize_denoised=True) x_samples = self.decode_first_stage(samples.to(self.device)) @@ -1423,38 +1500,30 @@ class LatentDiffusion(DDPM): b, h, w = z.shape[0], z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in - mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 mask = mask[:, None, ...] with ema_scope("Plotting Inpaint"): - samples, _ = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - eta=ddim_eta, - ddim_steps=ddim_steps, - x0=z[:N], - mask=mask) + samples, _ = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask + ) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_inpainting"] = x_samples log["mask"] = mask # outpaint - mask = 1. - mask + mask = 1.0 - mask with ema_scope("Plotting Outpaint"): - samples, _ = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - eta=ddim_eta, - ddim_steps=ddim_steps, - x0=z[:N], - mask=mask) + samples, _ = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask + ) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_outpainting"] = x_samples if plot_progressive_rows: with ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) + img, progressives = self.progressive_denoising( + c, shape=(self.channels, self.image_size, self.image_size), batch_size=N + ) prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") log["progressive_row"] = prog_row @@ -1472,10 +1541,11 @@ class LatentDiffusion(DDPM): rank_zero_info(f"{self.__class__.__name__}: Also optimizing conditioner params!") params = params + list(self.cond_stage_model.parameters()) if self.learn_logvar: - rank_zero_info('Diffusion model optimizing logvar') + rank_zero_info("Diffusion model optimizing logvar") params.append(self.logvar) from colossalai.nn.optimizer import HybridAdam + opt = HybridAdam(params, lr=lr) # opt = torch.optim.AdamW(params, lr=lr) @@ -1483,7 +1553,7 @@ class LatentDiffusion(DDPM): scheduler = LambdaLinearScheduler(**self.scheduler_config) rank_zero_info("Setting up LambdaLR scheduler...") - scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}] + scheduler = [{"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1}] return [opt], scheduler return opt @@ -1493,45 +1563,44 @@ class LatentDiffusion(DDPM): if not hasattr(self, "colorize"): self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) x = nn.functional.conv2d(x, weight=self.colorize) - x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x class DiffusionWrapper(pl.LightningModule): - def __init__(self, diff_model_config, conditioning_key): super().__init__() self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) self.diffusion_model = UNetModel(**diff_model_config) self.conditioning_key = conditioning_key - assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] + assert self.conditioning_key in [None, "concat", "crossattn", "hybrid", "adm", "hybrid-adm", "crossattn-adm"] def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): if self.conditioning_key is None: out = self.diffusion_model(x, t) - elif self.conditioning_key == 'concat': + elif self.conditioning_key == "concat": xc = torch.cat([x] + c_concat, dim=1) out = self.diffusion_model(xc, t) - elif self.conditioning_key == 'crossattn': + elif self.conditioning_key == "crossattn": if not self.sequential_cross_attn: cc = torch.cat(c_crossattn, 1) else: cc = c_crossattn out = self.diffusion_model(x, t, context=cc) - elif self.conditioning_key == 'hybrid': + elif self.conditioning_key == "hybrid": xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc) - elif self.conditioning_key == 'hybrid-adm': + elif self.conditioning_key == "hybrid-adm": assert c_adm is not None xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc, y=c_adm) - elif self.conditioning_key == 'crossattn-adm': + elif self.conditioning_key == "crossattn-adm": assert c_adm is not None cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(x, t, context=cc, y=c_adm) - elif self.conditioning_key == 'adm': + elif self.conditioning_key == "adm": cc = c_crossattn[0] out = self.diffusion_model(x, t, y=cc) else: @@ -1541,7 +1610,6 @@ class DiffusionWrapper(pl.LightningModule): class LatentUpscaleDiffusion(LatentDiffusion): - def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): super().__init__(*args, **kwargs) # assumes that neither the cond_stage nor the low_scale_model contain trainable params @@ -1562,14 +1630,16 @@ class LatentUpscaleDiffusion(LatentDiffusion): if not log_mode: z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) else: - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) x_low = batch[self.low_scale_key][:bs] - x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = rearrange(x_low, "b h w c -> b c h w") if self.use_fp16: x_low = x_low.to(memory_format=torch.contiguous_format).half() else: @@ -1577,7 +1647,7 @@ class LatentUpscaleDiffusion(LatentDiffusion): zx, noise_level = self.low_scale_model(x_low) if self.noise_level_key is not None: # get noise level from batch instead, e.g. when extracting a custom noise level for bsr - raise NotImplementedError('TODO') + raise NotImplementedError("TODO") all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} if log_mode: @@ -1587,29 +1657,30 @@ class LatentUpscaleDiffusion(LatentDiffusion): return z, all_conds @torch.no_grad() - def log_images(self, - batch, - N=8, - n_row=4, - sample=True, - ddim_steps=200, - ddim_eta=1., - return_keys=None, - plot_denoise_rows=False, - plot_progressive_rows=True, - plot_diffusion_rows=True, - unconditional_guidance_scale=1., - unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() - z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, - self.first_stage_key, - bs=N, - log_mode=True) + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input( + batch, self.first_stage_key, bs=N, log_mode=True + ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x @@ -1623,9 +1694,9 @@ class LatentUpscaleDiffusion(LatentDiffusion): elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', 'cls']: + elif self.cond_stage_key in ["class_label", "cls"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc elif isimage(xc): log["conditioning"] = xc if ismap(xc): @@ -1637,26 +1708,24 @@ class LatentUpscaleDiffusion(LatentDiffusion): z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1673,9 +1742,9 @@ class LatentUpscaleDiffusion(LatentDiffusion): if k == "c_crossattn": assert isinstance(c[k], list) and len(c[k]) == 1 uc[k] = [uc_tmp] - elif k == "c_adm": # todo: only run with text-based guidance? + elif k == "c_adm": # todo: only run with text-based guidance? assert isinstance(c[k], torch.Tensor) - #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level uc[k] = c[k] elif isinstance(c[k], list): uc[k] = [c[k][i] for i in range(len(c[k]))] @@ -1697,9 +1766,9 @@ class LatentUpscaleDiffusion(LatentDiffusion): if plot_progressive_rows: with ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) + img, progressives = self.progressive_denoising( + c, shape=(self.channels, self.image_size, self.image_size), batch_size=N + ) prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") log["progressive_row"] = prog_row @@ -1708,21 +1777,24 @@ class LatentUpscaleDiffusion(LatentDiffusion): class LatentFinetuneDiffusion(LatentDiffusion): """ - Basis for different finetunas, such as inpainting or depth2image - To disable finetuning mode, set finetune_keys to None + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None """ def __init__( - self, - concat_keys: tuple, - finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", - "model_ema.diffusion_modelinput_blocks00weight"), - keep_finetune_dims=4, - # if model was trained without concat mode before and we would like to keep these channels - c_concat_log_start=None, # to log reconstruction of c_concat codes - c_concat_log_end=None, - *args, - **kwargs): + self, + concat_keys: tuple, + finetune_keys=( + "model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight", + ), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, + **kwargs, + ): ckpt = kwargs.pop("ckpt", None) ignore_keys = kwargs.pop("ignore_keys", list()) super().__init__(*args, **kwargs) @@ -1732,7 +1804,7 @@ class LatentFinetuneDiffusion(LatentDiffusion): self.c_concat_log_start = c_concat_log_start self.c_concat_log_end = c_concat_log_end if exists(self.finetune_keys): - assert exists(ckpt), 'can only finetune from a given checkpoint' + assert exists(ckpt), "can only finetune from a given checkpoint" if exists(ckpt): self.init_from_ckpt(ckpt, ignore_keys) @@ -1755,13 +1827,14 @@ class LatentFinetuneDiffusion(LatentDiffusion): rank_zero_info( f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only" ) - new_entry = torch.zeros_like(param) # zero init - assert exists(new_entry), 'did not find matching parameter to modify' - new_entry[:, :self.keep_dims, ...] = sd[k] + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), "did not find matching parameter to modify" + new_entry[:, : self.keep_dims, ...] = sd[k] sd[k] = new_entry - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: rank_zero_info(f"Missing Keys: {missing}") @@ -1769,23 +1842,25 @@ class LatentFinetuneDiffusion(LatentDiffusion): rank_zero_info(f"Unexpected Keys: {unexpected}") @torch.no_grad() - def log_images(self, - batch, - N=8, - n_row=4, - sample=True, - ddim_steps=200, - ddim_eta=1., - return_keys=None, - quantize_denoised=True, - inpaint=True, - plot_denoise_rows=False, - plot_progressive_rows=True, - plot_diffusion_rows=True, - unconditional_guidance_scale=1., - unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None @@ -1803,16 +1878,16 @@ class LatentFinetuneDiffusion(LatentDiffusion): elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', 'cls']: + elif self.cond_stage_key in ["class_label", "cls"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc elif isimage(xc): log["conditioning"] = xc if ismap(xc): log["original_conditioning"] = self.to_rgb(xc) if not (self.c_concat_log_start is None and self.c_concat_log_end is None): - log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start : self.c_concat_log_end]) if plot_diffusion_rows: # get diffusion row @@ -1820,29 +1895,28 @@ class LatentFinetuneDiffusion(LatentDiffusion): z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond={ - "c_concat": [c_cat], - "c_crossattn": [c] - }, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1856,10 +1930,7 @@ class LatentFinetuneDiffusion(LatentDiffusion): uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} with ema_scope("Sampling with classifier-free guidance"): samples_cfg, _ = self.sample_log( - cond={ - "c_concat": [c_cat], - "c_crossattn": [c] - }, + cond={"c_concat": [c_cat], "c_crossattn": [c]}, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, @@ -1878,7 +1949,7 @@ class LatentInpaintDiffusion(LatentFinetuneDiffusion): can either run as pure inpainting model (only concat mode) or with mixed conditionings, e.g. mask as concat and text via cross-attn. To disable finetuning mode, set finetune_keys to None - """ + """ def __init__(self, concat_keys=("mask", "masked_image"), masked_image_key="masked_image", *args, **kwargs): super().__init__(concat_keys, *args, **kwargs) @@ -1888,21 +1959,23 @@ class LatentInpaintDiffusion(LatentFinetuneDiffusion): @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + assert not self.cond_stage_trainable, "trainable cond stages not yet supported for inpainting" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) c_cat = list() for ck in self.concat_keys: if self.use_fp16: - cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).half() + cc = rearrange(batch[ck], "b h w c -> b c h w").to(memory_format=torch.contiguous_format).half() else: - cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + cc = rearrange(batch[ck], "b h w c -> b c h w").to(memory_format=torch.contiguous_format).float() if bs is not None: cc = cc[:bs] cc = cc.to(self.device) @@ -1921,8 +1994,9 @@ class LatentInpaintDiffusion(LatentFinetuneDiffusion): @torch.no_grad() def log_images(self, *args, **kwargs): log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) - log["masked_image"] = rearrange(args[0]["masked_image"], - 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + log["masked_image"] = ( + rearrange(args[0]["masked_image"], "b h w c -> b c h w").to(memory_format=torch.contiguous_format).float() + ) return log @@ -1939,13 +2013,15 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + assert not self.cond_stage_trainable, "trainable cond stages not yet supported for depth2img" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 @@ -1963,10 +2039,10 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): align_corners=False, ) - depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, - dim=[1, 2, 3], - keepdim=True) - cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax( + cc, dim=[1, 2, 3], keepdim=True + ) + cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0 c_cat.append(cc) c_cat = torch.cat(c_cat, dim=1) all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} @@ -1978,24 +2054,21 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): def log_images(self, *args, **kwargs): log = super().log_images(*args, **kwargs) depth = self.depth_model(args[0][self.depth_stage_key]) - depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ - torch.amax(depth, dim=[1, 2, 3], keepdim=True) - log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. + depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), torch.amax( + depth, dim=[1, 2, 3], keepdim=True + ) + log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0 return log class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): """ - condition on low-res image (and optionally on some spatial noise augmentation) + condition on low-res image (and optionally on some spatial noise augmentation) """ - def __init__(self, - concat_keys=("lr",), - reshuffle_patch_size=None, - low_scale_config=None, - low_scale_key=None, - *args, - **kwargs): + def __init__( + self, concat_keys=("lr",), reshuffle_patch_size=None, low_scale_config=None, low_scale_key=None, *args, **kwargs + ): super().__init__(concat_keys=concat_keys, *args, **kwargs) self.reshuffle_patch_size = reshuffle_patch_size self.low_scale_model = None @@ -2015,13 +2088,15 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + assert not self.cond_stage_trainable, "trainable cond stages not yet supported for upscaling-ft" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 @@ -2030,13 +2105,15 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): noise_level = None for ck in self.concat_keys: cc = batch[ck] - cc = rearrange(cc, 'b h w c -> b c h w') + cc = rearrange(cc, "b h w c -> b c h w") if exists(self.reshuffle_patch_size): assert isinstance(self.reshuffle_patch_size, int) - cc = rearrange(cc, - 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', - p1=self.reshuffle_patch_size, - p2=self.reshuffle_patch_size) + cc = rearrange( + cc, + "b c (p1 h) (p2 w) -> b (p1 p2 c) h w", + p1=self.reshuffle_patch_size, + p2=self.reshuffle_patch_size, + ) if bs is not None: cc = cc[:bs] cc = cc.to(self.device) @@ -2055,5 +2132,5 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): @torch.no_grad() def log_images(self, *args, **kwargs): log = super().log_images(*args, **kwargs) - log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') + log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w") return log diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py index 7427f38c07530afbab79154ea8aaf88c4bf70a08..f56611cb5fb3682486f83329da3583a95800ca20 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py @@ -1 +1 @@ -from .sampler import DPMSolverSampler \ No newline at end of file +from .sampler import DPMSolverSampler diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py index 095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c..66063320ec7851cecccc2599dfe1702addc3db74 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -1,17 +1,17 @@ -import torch -import torch.nn.functional as F import math + +import torch from tqdm import tqdm class NoiseScheduleVP: def __init__( - self, - schedule='discrete', - betas=None, - alphas_cumprod=None, - continuous_beta_0=0.1, - continuous_beta_1=20., + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, ): """Create a wrapper class for the forward SDE (VP type). *** @@ -70,50 +70,63 @@ class NoiseScheduleVP: >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) """ - if schedule not in ['discrete', 'linear', 'cosine']: + if schedule not in ["discrete", "linear", "cosine"]: raise ValueError( "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( - schedule)) + schedule + ) + ) self.schedule = schedule - if schedule == 'discrete': + if schedule == "discrete": if betas is not None: log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) else: assert alphas_cumprod is not None log_alphas = 0.5 * torch.log(alphas_cumprod) self.total_N = len(log_alphas) - self.T = 1. - self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) + self.T = 1.0 + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape( + ( + 1, + -1, + ) + ) else: self.total_N = 1000 self.beta_0 = continuous_beta_0 self.beta_1 = continuous_beta_1 self.cosine_s = 0.008 - self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.cosine_beta_max = 999.0 + self.cosine_t_max = ( + math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) self.schedule = schedule - if schedule == 'cosine': + if schedule == "cosine": # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. self.T = 0.9946 else: - self.T = 1. + self.T = 1.0 def marginal_log_mean_coeff(self, t): """ Compute log(alpha_t) of a given continuous-time label t in [0, T]. """ - if self.schedule == 'discrete': - return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), - self.log_alpha_array.to(t.device)).reshape((-1)) - elif self.schedule == 'linear': - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == 'cosine': - log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device) + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == "cosine": + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 return log_alpha_t @@ -127,48 +140,56 @@ class NoiseScheduleVP: """ Compute sigma_t of a given continuous-time label t in [0, T]. """ - return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) def marginal_lambda(self, t): """ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. """ log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) return log_mean_coeff - log_std def inverse_lambda(self, lamb): """ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. """ - if self.schedule == 'linear': - tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0 ** 2 + tmp + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == 'discrete': - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) - t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), - torch.flip(self.t_array.to(lamb.device), [1])) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]), + ) return t.reshape((-1,)) else: - log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s + log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + t_fn = ( + lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) t = t_fn(log_alpha) return t def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs={}, + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, ): """Create a wrapper function for the noise prediction model. DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to @@ -249,8 +270,8 @@ def model_wrapper( For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. For continuous-time DPMs, we just use `t_continuous`. """ - if noise_schedule.schedule == 'discrete': - return (t_continuous - 1. / noise_schedule.total_N) * 1000. + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 else: return t_continuous @@ -302,7 +323,7 @@ def model_wrapper( noise = noise_pred_fn(x, t_continuous) return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad elif guidance_type == "classifier-free": - if guidance_scale == 1. or unconditional_condition is None: + if guidance_scale == 1.0 or unconditional_condition is None: return noise_pred_fn(x, t_continuous, cond=condition) else: x_in = torch.cat([x] * 2) @@ -317,7 +338,7 @@ def model_wrapper( class DPM_Solver: - def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.0): """Construct a DPM-Solver. We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). @@ -387,20 +408,21 @@ class DPM_Solver: Returns: A pytorch tensor of the time steps, with the shape (N + 1,). """ - if skip_type == 'logSNR': + if skip_type == "logSNR": lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) return self.noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == 'time_uniform': + elif skip_type == "time_uniform": return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == 'time_quadratic': + elif skip_type == "time_quadratic": t_order = 2 - t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) return t else: raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) + ) def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): """ @@ -435,29 +457,57 @@ class DPM_Solver: if order == 3: K = steps // 3 + 1 if steps % 3 == 0: - orders = [3, ] * (K - 2) + [2, 1] + orders = [ + 3, + ] * ( + K - 2 + ) + [2, 1] elif steps % 3 == 1: - orders = [3, ] * (K - 1) + [1] + orders = [ + 3, + ] * ( + K - 1 + ) + [1] else: - orders = [3, ] * (K - 1) + [2] + orders = [ + 3, + ] * ( + K - 1 + ) + [2] elif order == 2: if steps % 2 == 0: K = steps // 2 - orders = [2, ] * K + orders = [ + 2, + ] * K else: K = steps // 2 + 1 - orders = [2, ] * (K - 1) + [1] + orders = [ + 2, + ] * ( + K - 1 + ) + [1] elif order == 1: K = 1 - orders = [1, ] * steps + orders = [ + 1, + ] * steps else: raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == 'logSNR': + if skip_type == "logSNR": # To reproduce the results in DPM-Solver paper timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) else: timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ) + ).to(device) + ] return timesteps_outer, orders def denoise_to_zero_fn(self, x, s): @@ -491,12 +541,9 @@ class DPM_Solver: phi_1 = torch.expm1(-h) if model_s is None: model_s = self.model_fn(x, s) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - ) + x_t = expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s if return_intermediate: - return x_t, {'model_s': model_s} + return x_t, {"model_s": model_s} else: return x_t else: @@ -504,16 +551,17 @@ class DPM_Solver: if model_s is None: model_s = self.model_fn(x, s) x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s ) if return_intermediate: - return x_t, {'model_s': model_s} + return x_t, {"model_s": model_s} else: return x_t - def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, - solver_type='dpm_solver'): + def singlestep_dpm_solver_second_update( + self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpm_solver" + ): """ Singlestep solver DPM-Solver-2 from time `s` to time `t`. Args: @@ -529,7 +577,7 @@ class DPM_Solver: Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) if r1 is None: r1 = 0.5 @@ -539,8 +587,11 @@ class DPM_Solver: h = lambda_t - lambda_s lambda_s1 = lambda_s + r1 * h s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( - s1), ns.marginal_log_mean_coeff(t) + log_alpha_s, log_alpha_s1, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(t), + ) sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) @@ -550,23 +601,19 @@ class DPM_Solver: if model_s is None: model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) + x_s1 = expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( - model_s1 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1.0 / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * (model_s1 - model_s) ) else: phi_11 = torch.expm1(r1 * h) @@ -575,29 +622,39 @@ class DPM_Solver: if model_s is None: model_s = self.model_fn(x, s) x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s ) model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1.0 / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * (model_s1 - model_s) ) if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1} + return x_t, {"model_s": model_s, "model_s1": model_s1} else: return x_t - def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, - return_intermediate=False, solver_type='dpm_solver'): + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1.0 / 3.0, + r2=2.0 / 3.0, + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type="dpm_solver", + ): """ Singlestep solver DPM-Solver-3 from time `s` to time `t`. Args: @@ -616,12 +673,12 @@ class DPM_Solver: Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) if r1 is None: - r1 = 1. / 3. + r1 = 1.0 / 3.0 if r2 is None: - r2 = 2. / 3. + r2 = 2.0 / 3.0 ns = self.noise_schedule dims = x.dim() lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) @@ -630,93 +687,98 @@ class DPM_Solver: lambda_s2 = lambda_s + r2 * h s1 = ns.inverse_lambda(lambda_s1) s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( - s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( - s2), ns.marginal_std(t) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(s2), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_s2, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(s2), + ns.marginal_std(t), + ) alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) if self.predict_x0: phi_11 = torch.expm1(-r1 * h) phi_12 = torch.expm1(-r2 * h) phi_1 = torch.expm1(-h) - phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. - phi_2 = phi_1 / h + 1. + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 phi_3 = phi_2 / h - 0.5 if model_s is None: model_s = self.model_fn(x, s) if model_s1 is None: - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) + x_s1 = expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s model_s1 = self.model_fn(x_s1, s1) x_s2 = ( - expand_dims(sigma_s2 / sigma_s, dims) * x - - expand_dims(alpha_s2 * phi_12, dims) * model_s - + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) ) model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1.0 / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + expand_dims(alpha_t * phi_2, dims) * D1 - - expand_dims(alpha_t * phi_3, dims) * D2 + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 ) else: phi_11 = torch.expm1(r1 * h) phi_12 = torch.expm1(r2 * h) phi_1 = torch.expm1(h) - phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. - phi_2 = phi_1 / h - 1. + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 phi_3 = phi_2 / h - 0.5 if model_s is None: model_s = self.model_fn(x, s) if model_s1 is None: x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s ) model_s1 = self.model_fn(x_s1, s1) x_s2 = ( - expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - - expand_dims(sigma_s2 * phi_12, dims) * model_s - - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) ) model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1.0 / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - expand_dims(sigma_t * phi_2, dims) * D1 - - expand_dims(sigma_t * phi_3, dims) * D2 + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 ) if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} else: return x_t @@ -733,14 +795,17 @@ class DPM_Solver: Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) ns = self.noise_schedule dims = x.dim() model_prev_1, model_prev_0 = model_prev_list t_prev_1, t_prev_0 = t_prev_list - lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( - t_prev_0), ns.marginal_lambda(t) + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -748,36 +813,36 @@ class DPM_Solver: h_0 = lambda_prev_0 - lambda_prev_1 h = lambda_t - lambda_prev_0 r0 = h_0 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1) if self.predict_x0: - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0 ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1_0 ) else: - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0 ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1_0 ) return x_t - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): """ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. Args: @@ -794,8 +859,12 @@ class DPM_Solver: dims = x.dim() model_prev_2, model_prev_1, model_prev_0 = model_prev_list t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( - t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -804,28 +873,29 @@ class DPM_Solver: h_0 = lambda_prev_0 - lambda_prev_1 h = lambda_t - lambda_prev_0 r0, r1 = h_0 / h, h_1 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2) D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) - D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1) if self.predict_x0: x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims) * D2 ) else: x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims) * D2 ) return x_t - def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, - r2=None): + def singlestep_dpm_solver_update( + self, x, s, t, order, return_intermediate=False, solver_type="dpm_solver", r1=None, r2=None + ): """ Singlestep DPM-Solver with the order `order` from time `s` to time `t`. Args: @@ -844,15 +914,17 @@ class DPM_Solver: if order == 1: return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) elif order == 2: - return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1) + return self.singlestep_dpm_solver_second_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1 + ) elif order == 3: - return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1, r2=r2) + return self.singlestep_dpm_solver_third_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2 + ) else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpm_solver"): """ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. Args: @@ -875,8 +947,9 @@ class DPM_Solver: else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, - solver_type='dpm_solver'): + def dpm_solver_adaptive( + self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpm_solver" + ): """ The adaptive step size solver based on singlestep DPM-Solver. Args: @@ -906,17 +979,17 @@ class DPM_Solver: if order == 2: r1 = 0.5 lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - solver_type=solver_type, - **kwargs) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, solver_type=solver_type, **kwargs + ) elif order == 3: - r1, r2 = 1. / 3., 2. / 3. - lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - return_intermediate=True, - solver_type=solver_type) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, - solver_type=solver_type, - **kwargs) + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update( + x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs + ) else: raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) while torch.abs((s - t_0)).mean() > t_err: @@ -926,20 +999,31 @@ class DPM_Solver: delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) E = norm_fn((x_higher - x_lower) / delta).max() - if torch.all(E <= 1.): + if torch.all(E <= 1.0): x = x_higher s = t x_prev = x_lower lambda_s = ns.marginal_lambda(s) - h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s) nfe += order - print('adaptive solver nfe', nfe) + print("adaptive solver nfe", nfe) return x - def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, - ): + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=3, + skip_type="time_uniform", + method="singlestep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpm_solver", + atol=0.0078, + rtol=0.05, + ): """ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. ===================================================== @@ -1034,14 +1118,15 @@ class DPM_Solver: Returns: x_end: A pytorch tensor. The approximated solution at time `t_end`. """ - t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end t_T = self.noise_schedule.T if t_start is None else t_start device = x.device - if method == 'adaptive': + if method == "adaptive": with torch.no_grad(): - x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, - solver_type=solver_type) - elif method == 'multistep': + x = self.dpm_solver_adaptive( + x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type + ) + elif method == "multistep": assert steps >= order timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) assert timesteps.shape[0] - 1 == steps @@ -1052,8 +1137,9 @@ class DPM_Solver: # Init the first `order` values by lower order multistep DPM-Solver. for init_order in tqdm(range(1, order), desc="DPM init order"): vec_t = timesteps[init_order].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, - solver_type=solver_type) + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type + ) model_prev_list.append(self.model_fn(x, vec_t)) t_prev_list.append(vec_t) # Compute the remaining values by `order`-th order multistep DPM-Solver. @@ -1063,8 +1149,9 @@ class DPM_Solver: step_order = min(order, steps + 1 - step) else: step_order = order - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, - solver_type=solver_type) + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type + ) for i in range(order - 1): t_prev_list[i] = t_prev_list[i + 1] model_prev_list[i] = model_prev_list[i + 1] @@ -1072,20 +1159,22 @@ class DPM_Solver: # We do not need to evaluate the final model value. if step < steps: model_prev_list[-1] = self.model_fn(x, vec_t) - elif method in ['singlestep', 'singlestep_fixed']: - if method == 'singlestep': - timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, - skip_type=skip_type, - t_T=t_T, t_0=t_0, - device=device) - elif method == 'singlestep_fixed': + elif method in ["singlestep", "singlestep_fixed"]: + if method == "singlestep": + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver( + steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device + ) + elif method == "singlestep_fixed": K = steps // order - orders = [order, ] * K + orders = [ + order, + ] * K timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) for i, order in enumerate(orders): t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] - timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), - N=order, device=device) + timesteps_inner = self.get_time_steps( + skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device + ) lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) h = lambda_inner[-1] - lambda_inner[0] @@ -1101,6 +1190,7 @@ class DPM_Solver: # other utility functions ############################################################# + def interpolate_fn(x, xp, yp): """ A piecewise linear function y = f(x), using xp and yp as keypoints. @@ -1122,7 +1212,9 @@ def interpolate_fn(x, xp, yp): torch.eq(x_idx, 0), torch.tensor(1, device=x.device), torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, ), ) end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) @@ -1132,7 +1224,9 @@ def interpolate_fn(x, xp, yp): torch.eq(x_idx, 0), torch.tensor(0, device=x.device), torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, ), ) y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) @@ -1151,4 +1245,4 @@ def expand_dims(v, dims): Returns: a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. """ - return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file + return v[(...,) + (None,) * (dims - 1)] diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8cf36718c1c58faa09f9dd919e5fb2977b..55dac8555e5fcf060be1e57fde1cb1e9c634aa28 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py @@ -1,13 +1,9 @@ """SAMPLING ONLY.""" import torch -from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver +from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper - -MODEL_TYPES = { - "eps": "noise", - "v": "v" -} +MODEL_TYPES = {"eps": "noise", "v": "v"} class DPMSolverSampler(object): @@ -15,7 +11,7 @@ class DPMSolverSampler(object): super().__init__() self.model = model to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) - self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod)) def register_buffer(self, name, attr): if type(attr) == torch.Tensor: @@ -24,30 +20,31 @@ class DPMSolverSampler(object): setattr(self, name, attr) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] @@ -61,7 +58,7 @@ class DPMSolverSampler(object): C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + print(f"Data shape for DPM-Solver sampling is {size}, sampling steps {S}") device = self.model.betas.device if x_T is None: @@ -69,7 +66,7 @@ class DPMSolverSampler(object): else: img = x_T - ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod) model_fn = model_wrapper( lambda x, t, c: self.model.apply_model(x, t, c), @@ -82,6 +79,8 @@ class DPMSolverSampler(object): ) dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + x = dpm_solver.sample( + img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True + ) - return x.to(device), None \ No newline at end of file + return x.to(device), None diff --git a/examples/images/diffusion/ldm/models/diffusion/plms.py b/examples/images/diffusion/ldm/models/diffusion/plms.py index 7002a365d27168ced0a04e9a4d83e088f8284eae..b2b3f032e4914dac002a3cca2df89e1efe8b8990 100644 --- a/examples/images/diffusion/ldm/models/diffusion/plms.py +++ b/examples/images/diffusion/ldm/models/diffusion/plms.py @@ -1,12 +1,10 @@ """SAMPLING ONLY.""" -import torch import numpy as np -from tqdm import tqdm -from functools import partial - -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +import torch from ldm.models.diffusion.sampling_util import norm_thresholding +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from tqdm import tqdm class PLMSSampler(object): @@ -22,65 +20,72 @@ class PLMSSampler(object): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True): if ddim_eta != 0: - raise ValueError('ddim_eta must be 0 for PLMS') - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + raise ValueError("ddim_eta must be 0 for PLMS") + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep" to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1))) # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] @@ -94,34 +99,51 @@ class PLMSSampler(object): # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') - - samples, intermediates = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ) + print(f"Data shape for PLMS sampling is {size}") + + samples, intermediates = self.plms_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) return samples, intermediates @torch.no_grad() - def plms_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): + def plms_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): device = self.model.betas.device b = shape[0] if x_T is None: @@ -135,12 +157,12 @@ class PLMSSampler(object): subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running PLMS Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps) old_eps = [] for i, step in enumerate(iterator): @@ -151,38 +173,64 @@ class PLMSSampler(object): if mask is not None: assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img - - outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, t_next=ts_next, - dynamic_threshold=dynamic_threshold) + img = img_orig * mask + (1.0 - mask) * img + + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + dynamic_threshold=dynamic_threshold, + ) img, pred_x0, e_t = outs old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) return img, intermediates @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, - dynamic_threshold=None): + def p_sample_plms( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + dynamic_threshold=None, + ): b, *_, device = *x.shape, x.device def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: e_t = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) @@ -199,7 +247,9 @@ class PLMSSampler(object): alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + ) sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas def get_x_prev_and_pred_x0(e_t, index): @@ -207,7 +257,7 @@ class PLMSSampler(object): a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) # current prediction for x_0 pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() @@ -216,9 +266,9 @@ class PLMSSampler(object): if dynamic_threshold is not None: pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 diff --git a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py index 7eff02be6d7c54d43ee6680636ac0698dd3b3f33..a4681368112bc6bb6963bc05c9c64011b9522ab1 100644 --- a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py +++ b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py @@ -1,13 +1,9 @@ -import torch -import numpy as np - - def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions. From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") return x[(...,) + (None,) * dims_to_append] @@ -19,4 +15,4 @@ def norm_thresholding(x0, value): def spatial_norm_thresholding(x0, value): # b c h w s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) - return x0 * (value / s) \ No newline at end of file + return x0 * (value / s) diff --git a/examples/images/diffusion/ldm/modules/attention.py b/examples/images/diffusion/ldm/modules/attention.py index d504d939f6a02cf45f028799d7d73b84500cee06..f3c385e5138faad6feea562178589900cfdbd5f5 100644 --- a/examples/images/diffusion/ldm/modules/attention.py +++ b/examples/images/diffusion/ldm/modules/attention.py @@ -1,17 +1,17 @@ -from inspect import isfunction import math +from inspect import isfunction +from typing import Any, Optional + import torch import torch.nn.functional as F -from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional, Any - from ldm.modules.diffusionmodules.util import checkpoint - +from torch import einsum, nn try: import xformers import xformers.ops + XFORMERS_IS_AVAILBLE = True except: XFORMERS_IS_AVAILBLE = False @@ -22,7 +22,7 @@ def exists(val): def uniq(arr): - return{el: True for el in arr}.keys() + return {el: True for el in arr}.keys() def default(val, d): @@ -54,20 +54,13 @@ class GEGLU(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -92,26 +85,10 @@ class SpatialSelfAttention(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -121,41 +98,38 @@ class SpatialSelfAttention(nn.Module): v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) - w_ = w_ * (int(c)**(-0.5)) + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) h_ = self.proj_out(h_) - return x+h_ + return x + h_ class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), - nn.Dropout(dropout) - ) + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def forward(self, x, context=None, mask=None): h = self.heads @@ -165,22 +139,22 @@ class CrossAttention(nn.Module): k = self.to_k(context) v = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale del q, k if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') + mask = rearrange(mask, "b ... -> b (...)") max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) + mask = repeat(mask, "b j -> (b h) () j", h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) - out = einsum('b i j, b j d -> b i d', sim, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = einsum("b i j, b j d -> b i d", sim, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) return self.to_out(out) @@ -188,8 +162,10 @@ class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() - print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " - f"{heads} heads.") + print( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads." + ) inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -236,20 +212,36 @@ class MemoryEfficientCrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): ATTENTION_MODES = { "softmax": CrossAttention, # vanilla attention - "softmax-xformers": MemoryEfficientCrossAttention + "softmax-xformers": MemoryEfficientCrossAttention, } - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False): + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + ): super().__init__() attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" assert attn_mode in self.ATTENTION_MODES attn_cls = self.ATTENTION_MODES[attn_mode] self.disable_self_attn = disable_self_attn - self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + ) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.attn2 = attn_cls( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) @@ -274,10 +266,19 @@ class SpatialTransformer(nn.Module): Finally, reshape to image NEW: use_linear for more efficiency instead of the 1x1 convs """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None, - disable_self_attn=False, use_linear=False, - use_checkpoint=True): + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + ): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] @@ -285,25 +286,26 @@ class SpatialTransformer(nn.Module): inner_dim = n_heads * d_head self.norm = Normalize(in_channels) if not use_linear: - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) else: self.proj_in = nn.Linear(in_channels, inner_dim) self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) - for d in range(depth)] + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + ) + for d in range(depth) + ] ) if not use_linear: - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) else: self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.use_linear = use_linear @@ -317,15 +319,14 @@ class SpatialTransformer(nn.Module): x = self.norm(x) if not self.use_linear: x = self.proj_in(x) - x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + x = rearrange(x, "b c h w -> b (h w) c").contiguous() if self.use_linear: x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): x = block(x, context=context[i]) if self.use_linear: x = self.proj_out(x) - x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() if not self.use_linear: x = self.proj_out(x) return x + x_in - diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py index fb088db58919dd3ab79b2d6c7ab4d0e6a40f7454..7ed8d98a44ad5906f163b3763f10aed5bdb6e502 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py @@ -17,6 +17,7 @@ from ldm.modules.attention import MemoryEfficientCrossAttention try: import xformers import xformers.ops + XFORMERS_IS_AVAILBLE = True except: XFORMERS_IS_AVAILBLE = False @@ -39,7 +40,7 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad + if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb @@ -54,7 +55,6 @@ def Normalize(in_channels, num_groups=32): class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv @@ -69,7 +69,6 @@ class Upsample(nn.Module): class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv @@ -88,7 +87,6 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): super().__init__() self.in_channels = in_channels @@ -133,7 +131,6 @@ class ResnetBlock(nn.Module): class AttnBlock(nn.Module): - def __init__(self, in_channels): super().__init__() self.in_channels = in_channels @@ -154,16 +151,16 @@ class AttnBlock(nn.Module): # compute attention b, c, h, w = q.shape q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) @@ -173,9 +170,9 @@ class AttnBlock(nn.Module): class MemoryEfficientAttnBlock(nn.Module): """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation """ # @@ -199,34 +196,41 @@ class MemoryEfficientAttnBlock(nn.Module): # compute attention B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) q, k, v = map( - lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C). - contiguous(), + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), (q, k, v), ) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - out = (out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C) + out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) out = self.proj_out(out) return x + out class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): - def forward(self, x, context=None, mask=None): b, c, h, w = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') + x = rearrange(x, "b c h w -> b (h w) c") out = super().forward(x, context=context, mask=mask) - out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) return x + out def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", - "none"], f'attn_type {attn_type} unknown' + assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": attn_type = "vanilla-xformers" if attn_type == "vanilla": @@ -245,21 +249,22 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): class Model(nn.Module): - - def __init__(self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type="vanilla"): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): super().__init__() if use_linear_attn: attn_type = "linear" @@ -274,10 +279,12 @@ class Model(nn.Module): if self.use_timestep: # timestep embedding self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ]) + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) # downsampling self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) @@ -292,10 +299,10 @@ class Model(nn.Module): block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -309,15 +316,13 @@ class Model(nn.Module): # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # upsampling self.up = nn.ModuleList() @@ -330,10 +335,13 @@ class Model(nn.Module): if i_block == self.num_res_blocks: skip_in = ch * in_ch_mult[i_level] block.append( - ResnetBlock(in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -343,14 +351,14 @@ class Model(nn.Module): if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, x, t=None, context=None): - #assert x.shape[2] == x.shape[3] == self.resolution + # assert x.shape[2] == x.shape[3] == self.resolution if context is not None: # assume aligned context, cat along channel axis x = torch.cat((x, context), dim=1) @@ -401,23 +409,24 @@ class Model(nn.Module): class Encoder(nn.Module): - - def __init__(self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - **ignore_kwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): super().__init__() if use_linear_attn: attn_type = "linear" @@ -442,10 +451,10 @@ class Encoder(nn.Module): block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -459,23 +468,19 @@ class Encoder(nn.Module): # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 + ) def forward(self, x): # timestep embedding @@ -506,24 +511,25 @@ class Encoder(nn.Module): class Decoder(nn.Module): - - def __init__(self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - attn_type="vanilla", - **ignorekwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): super().__init__() if use_linear_attn: attn_type = "linear" @@ -537,9 +543,9 @@ class Decoder(nn.Module): self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) + (1,) + tuple(ch_mult) block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2**(self.num_resolutions - 1) + curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) rank_zero_info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) @@ -548,15 +554,13 @@ class Decoder(nn.Module): # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # upsampling self.up = nn.ModuleList() @@ -566,10 +570,10 @@ class Decoder(nn.Module): block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -579,14 +583,14 @@ class Decoder(nn.Module): if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z): - #assert z.shape[1:] == self.z_shape[1:] + # assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding @@ -622,17 +626,18 @@ class Decoder(nn.Module): class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): super().__init__() - self.model = nn.ModuleList([ - nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), - nn.Conv2d(2 * in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True) - ]) + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) # end self.norm_out = Normalize(in_channels) self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) @@ -651,7 +656,6 @@ class SimpleDecoder(nn.Module): class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0): super().__init__() # upsampling @@ -659,7 +663,7 @@ class UpsampleDecoder(nn.Module): self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks block_in = in_channels - curr_res = resolution // 2**(self.num_resolutions - 1) + curr_res = resolution // 2 ** (self.num_resolutions - 1) self.res_blocks = nn.ModuleList() self.upsample_blocks = nn.ModuleList() for i_level in range(self.num_resolutions): @@ -667,10 +671,10 @@ class UpsampleDecoder(nn.Module): block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): res_block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out self.res_blocks.append(nn.ModuleList(res_block)) if i_level != self.num_resolutions - 1: @@ -696,21 +700,24 @@ class UpsampleDecoder(nn.Module): class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): super().__init__() # residual block, interpolate, residual block self.factor = factor self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1) - self.res_block1 = nn.ModuleList([ - ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) - for _ in range(depth) - ]) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ] + ) self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ - ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) - for _ in range(depth) - ]) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ] + ) self.conv_out = nn.Conv2d( mid_channels, @@ -722,9 +729,9 @@ class LatentRescaler(nn.Module): x = self.conv_in(x) for block in self.res_block1: x = block(x, None) - x = torch.nn.functional.interpolate(x, - size=(int(round(x.shape[2] * self.factor)), - int(round(x.shape[3] * self.factor)))) + x = torch.nn.functional.interpolate( + x, size=(int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor))) + ) x = self.attn(x) for block in self.res_block2: x = block(x, None) @@ -733,37 +740,42 @@ class LatentRescaler(nn.Module): class MergedRescaleEncoder(nn.Module): - - def __init__(self, - in_channels, - ch, - resolution, - out_ch, - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - ch_mult=(1, 2, 4, 8), - rescale_factor=1.0, - rescale_module_depth=1): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): super().__init__() intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, - num_res_blocks=num_res_blocks, - ch=ch, - ch_mult=ch_mult, - z_channels=intermediate_chn, - double_z=False, - resolution=resolution, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, - in_channels=intermediate_chn, - mid_channels=intermediate_chn, - out_channels=out_ch, - depth=rescale_module_depth) + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) def forward(self, x): x = self.encoder(x) @@ -772,36 +784,41 @@ class MergedRescaleEncoder(nn.Module): class MergedRescaleDecoder(nn.Module): - - def __init__(self, - z_channels, - out_ch, - resolution, - num_res_blocks, - attn_resolutions, - ch, - ch_mult=(1, 2, 4, 8), - dropout=0.0, - resamp_with_conv=True, - rescale_factor=1.0, - rescale_module_depth=1): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): super().__init__() tmp_chn = z_channels * ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, - z_channels=tmp_chn, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - in_channels=None, - num_res_blocks=num_res_blocks, - ch_mult=ch_mult, - resolution=resolution, - ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, - in_channels=z_channels, - mid_channels=tmp_chn, - out_channels=tmp_chn, - depth=rescale_module_depth) + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) def forward(self, x): x = self.rescaler(x) @@ -810,27 +827,27 @@ class MergedRescaleDecoder(nn.Module): class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): super().__init__() assert out_size >= in_size num_blocks = int(np.log2(out_size // in_size)) + 1 - factor_up = 1. + (out_size % in_size) + factor_up = 1.0 + (out_size % in_size) rank_zero_info( f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" ) - self.rescaler = LatentRescaler(factor=factor_up, - in_channels=in_channels, - mid_channels=2 * in_channels, - out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, - resolution=out_size, - z_channels=in_channels, - num_res_blocks=2, - attn_resolutions=[], - in_channels=None, - ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)]) + self.rescaler = LatentRescaler( + factor=factor_up, in_channels=in_channels, mid_channels=2 * in_channels, out_channels=in_channels + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) def forward(self, x): x = self.rescaler(x) @@ -839,14 +856,14 @@ class Upsampler(nn.Module): class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): super().__init__() self.with_conv = learned self.mode = mode if self.with_conv: rank_zero_info( - f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py index cd639d9360466c72b92db403e52514a149997ed8..614fe510f20e6b5cf99b8c1bb31a62cd8682452f 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -1,21 +1,20 @@ -from abc import abstractmethod import math +from abc import abstractmethod import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F - +from ldm.modules.attention import SpatialTransformer from ldm.modules.diffusionmodules.util import ( + avg_pool_nd, checkpoint, conv_nd, linear, - avg_pool_nd, - zero_module, normalization, timestep_embedding, + zero_module, ) -from ldm.modules.attention import SpatialTransformer from ldm.util import exists @@ -23,6 +22,7 @@ from ldm.util import exists def convert_module_to_f16(x): pass + def convert_module_to_f32(x): pass @@ -41,7 +41,7 @@ class AttentionPool2d(nn.Module): output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -108,25 +108,25 @@ class Upsample(nn.Module): def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") else: x = F.interpolate(x, scale_factor=2, mode="nearest") if self.use_conv: x = self.conv(x) return x + class TransposedUpsample(nn.Module): - 'Learned 2x upsampling without padding' + "Learned 2x upsampling without padding" + def __init__(self, channels, out_channels=None, ks=5): super().__init__() self.channels = channels self.out_channels = out_channels or channels - self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2) - def forward(self,x): + def forward(self, x): return self.up(x) @@ -139,7 +139,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -147,9 +147,7 @@ class Downsample(nn.Module): self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: - self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding - ) + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) @@ -225,17 +223,13 @@ class ResBlock(TimestepBlock): normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) @@ -246,10 +240,7 @@ class ResBlock(TimestepBlock): :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) def _forward(self, x, emb): if self.updown: @@ -311,8 +302,10 @@ class AttentionBlock(nn.Module): self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - #return pt_checkpoint(self._forward, x) # pytorch + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch def _forward(self, x): b, c, *spatial = x.shape @@ -339,7 +332,7 @@ def count_flops_attn(model, _x, y): # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial ** 2) * c + matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += th.DoubleTensor([matmul_ops]) @@ -363,9 +356,7 @@ class QKVAttentionLegacy(nn.Module): ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) a = th.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) @@ -460,10 +451,10 @@ class UNetModel(nn.Module): use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, disable_self_attentions=None, num_attention_blocks=None, @@ -472,11 +463,16 @@ class UNetModel(nn.Module): ): super().__init__() if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." if context_dim is not None: - assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: context_dim = list(context_dim) @@ -484,10 +480,10 @@ class UNetModel(nn.Module): num_heads_upsample = num_heads if num_heads == -1: - assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set" if num_head_channels == -1: - assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + assert num_heads != -1, "Either num_heads or num_head_channels has to be set" self.image_size = image_size self.in_channels = in_channels @@ -497,19 +493,25 @@ class UNetModel(nn.Module): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: if len(num_res_blocks) != len(channel_mult): - raise ValueError("provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult") + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) self.num_res_blocks = num_res_blocks if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") + assert all( + map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." + ) self.attention_resolutions = attention_resolutions self.dropout = dropout @@ -540,11 +542,7 @@ class UNetModel(nn.Module): raise ValueError() self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] ) self._feature_size = model_channels input_block_chans = [model_channels] @@ -571,7 +569,7 @@ class UNetModel(nn.Module): num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] @@ -586,10 +584,17 @@ class UNetModel(nn.Module): num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -610,9 +615,7 @@ class UNetModel(nn.Module): down=True, ) if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ) ch = out_ch @@ -626,7 +629,7 @@ class UNetModel(nn.Module): num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels self.middle_block = TimestepEmbedSequential( ResBlock( @@ -643,11 +646,18 @@ class UNetModel(nn.Module): num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint - ), + ) + if not use_spatial_transformer + else SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + ), ResBlock( ch, time_embed_dim, @@ -682,7 +692,7 @@ class UNetModel(nn.Module): num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] @@ -697,10 +707,17 @@ class UNetModel(nn.Module): num_heads=num_heads_upsample, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, ) ) if level and i == self.num_res_blocks[level]: @@ -730,10 +747,10 @@ class UNetModel(nn.Module): ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) def convert_to_fp16(self): """ @@ -751,7 +768,7 @@ class UNetModel(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py index 03816662098ce1ffac79bd939b892e867ab91988..82cc2157ca68de95e5f6b86ab88daed7fe430f9c 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py @@ -1,8 +1,8 @@ -import torch -import torch.nn as nn -import numpy as np from functools import partial +import numpy as np +import torch +import torch.nn as nn from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule from ldm.util import default @@ -14,37 +14,41 @@ class AbstractLowScaleModel(nn.Module): if noise_schedule_config is not None: self.register_schedule(**noise_schedule_config) - def register_schedule(self, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. - betas + def register_schedule( + self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 + ): + betas = make_beta_schedule( + beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s + ) + alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) def forward(self, x): return x, None @@ -76,6 +80,3 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): assert isinstance(noise_level, torch.Tensor) z = self.q_sample(x, noise_level) return z, noise_level - - - diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py index 36b4a171b6c2382206c4754a26889fd685276db1..aed1b061323a4b75b505ffdb186199c6f51cf66c 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py @@ -8,7 +8,6 @@ # thanks! import math -import os import numpy as np import torch @@ -19,10 +18,10 @@ from ldm.util import instantiate_from_config def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if schedule == "linear": - betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2) + betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2 elif schedule == "cosine": - timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s) + timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = torch.cos(alphas).pow(2) alphas = alphas / alphas[0] @@ -32,18 +31,18 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, elif schedule == "sqrt_linear": betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) elif schedule == "sqrt": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5 + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 else: raise ValueError(f"schedule '{schedule}' unknown.") return betas.numpy() def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): - if ddim_discr_method == 'uniform': + if ddim_discr_method == "uniform": c = num_ddpm_timesteps // num_ddim_timesteps ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) - elif ddim_discr_method == 'quad': - ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int) + elif ddim_discr_method == "quad": + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int) else: raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') @@ -51,7 +50,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep # add one to get the final alpha values right (the ones from first scale to data during sampling) steps_out = ddim_timesteps + 1 if verbose: - print(f'Selected timesteps for ddim sampler: {steps_out}') + print(f"Selected timesteps for ddim sampler: {steps_out}") return steps_out @@ -63,9 +62,11 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): # according the the formula provided in https://arxiv.org/abs/2010.02502 sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) if verbose: - print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') - print(f'For the chosen value of eta, which is {eta}, ' - f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}") + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) return sigmas, alphas, alphas_prev @@ -106,6 +107,7 @@ def checkpoint(func, inputs, params, flag): """ if flag: from torch.utils.checkpoint import checkpoint as torch_checkpoint + return torch_checkpoint(func, *inputs) # args = tuple(inputs) + tuple(params) # return CheckpointFunction.apply(func, len(inputs), *args) @@ -114,7 +116,6 @@ def checkpoint(func, inputs, params, flag): class CheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, length, *args): ctx.run_function = run_function @@ -123,7 +124,7 @@ class CheckpointFunction(torch.autograd.Function): ctx.gpu_autocast_kwargs = { "enabled": torch.is_autocast_enabled(), "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled() + "cache_enabled": torch.is_autocast_cache_enabled(), } with torch.no_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) @@ -132,8 +133,7 @@ class CheckpointFunction(torch.autograd.Function): @staticmethod def backward(ctx, *output_grads): ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(), \ - torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d # Tensors. @@ -162,14 +162,15 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ if not repeat_only: half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / - half).to(device=timesteps.device) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) else: - embedding = repeat(timesteps, 'b -> b d', d=dim) + embedding = repeat(timesteps, "b -> b d", d=dim) return embedding @@ -210,13 +211,11 @@ def normalization(channels): # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. class SiLU(nn.Module): - def forward(self, x): return x * torch.sigmoid(x) class GroupNorm32(nn.GroupNorm): - def forward(self, x): return super().forward(x.float()).type(x.dtype) @@ -255,7 +254,6 @@ def avg_pool_nd(dims, *args, **kwargs): class HybridConditioner(nn.Module): - def __init__(self, c_concat_config, c_crossattn_config): super().__init__() self.concat_conditioner = instantiate_from_config(c_concat_config) @@ -264,7 +262,7 @@ class HybridConditioner(nn.Module): def forward(self, c_concat, c_crossattn): c_concat = self.concat_conditioner(c_concat) c_crossattn = self.crossattn_conditioner(c_crossattn) - return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} def noise_like(shape, device, repeat=False): diff --git a/examples/images/diffusion/ldm/modules/distributions/distributions.py b/examples/images/diffusion/ldm/modules/distributions/distributions.py index f2b8ef901130efc171aa69742ca0244d94d3f2e9..b5f3b1ad48daff8c6e010c9d778ce5614bf0edef 100644 --- a/examples/images/diffusion/ldm/modules/distributions/distributions.py +++ b/examples/images/diffusion/ldm/modules/distributions/distributions.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch class AbstractDistribution: @@ -38,25 +38,25 @@ class DiagonalGaussianDistribution(object): def kl(self, other=None): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) else: if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) - - def nll(self, sample, dims=[1,2,3]): + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) def mode(self): return self.mean @@ -78,15 +78,8 @@ def normal_kl(mean1, logvar1, mean2, logvar2): # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). - logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] + logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) ) diff --git a/examples/images/diffusion/ldm/modules/ema.py b/examples/images/diffusion/ldm/modules/ema.py index bded25019b9bcbcd0260f0b8185f8c7859ca58c4..c3863269675e466f4571cf2d256c0f92fe1984cd 100644 --- a/examples/images/diffusion/ldm/modules/ema.py +++ b/examples/images/diffusion/ldm/modules/ema.py @@ -6,17 +6,18 @@ class LitEma(nn.Module): def __init__(self, model, decay=0.9999, use_num_upates=True): super().__init__() if decay < 0.0 or decay > 1.0: - raise ValueError('Decay must be between 0 and 1') + raise ValueError("Decay must be between 0 and 1") self.m_name2s_name = {} - self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates - else torch.tensor(-1, dtype=torch.int)) + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int) + ) for name, p in model.named_parameters(): if p.requires_grad: # remove as '.'-character is not allowed in buffers - s_name = name.replace('.', '') + s_name = name.replace(".", "") self.m_name2s_name.update({name: s_name}) self.register_buffer(s_name, p.clone().detach().data) @@ -24,7 +25,7 @@ class LitEma(nn.Module): def reset_num_updates(self): del self.num_updates - self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) def forward(self, model): decay = self.decay diff --git a/examples/images/diffusion/ldm/modules/encoders/modules.py b/examples/images/diffusion/ldm/modules/encoders/modules.py index 4edd5496b9e668ea72a5be39db9cca94b6a42f9b..58bff2382c4743717a4cbe4225a66110a9078363 100644 --- a/examples/images/diffusion/ldm/modules/encoders/modules.py +++ b/examples/images/diffusion/ldm/modules/encoders/modules.py @@ -1,11 +1,9 @@ +import open_clip import torch import torch.nn as nn +from ldm.util import count_params from torch.utils.checkpoint import checkpoint - -from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel - -import open_clip -from ldm.util import default, count_params +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer class AbstractEncoder(nn.Module): @@ -17,13 +15,12 @@ class AbstractEncoder(nn.Module): class IdentityEncoder(AbstractEncoder): - def encode(self, x): return x class ClassEmbedder(nn.Module): - def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): super().__init__() self.key = key self.embedding = nn.Embedding(n_classes, embed_dim) @@ -35,9 +32,9 @@ class ClassEmbedder(nn.Module): key = self.key # this is for use in crossattn c = batch[key][:, None] - if self.ucg_rate > 0. and not disable_dropout: - mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) - c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + if self.ucg_rate > 0.0 and not disable_dropout: + mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) c = c.long() c = self.embedding(c) return c @@ -57,24 +54,34 @@ def disabled_train(self, mode=True): class FrozenT5Embedder(AbstractEncoder): """Uses the T5 transformer encoder for text""" - def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + + def __init__( + self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device - self.max_length = max_length # TODO: typical value? + self.max_length = max_length # TODO: typical value? if freeze: self.freeze() def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) tokens = batch_encoding["input_ids"].to(self.device) outputs = self.transformer(input_ids=tokens) @@ -87,13 +94,18 @@ class FrozenT5Embedder(AbstractEncoder): class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" - LAYERS = [ - "last", - "pooled", - "hidden" - ] - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, - freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + ): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS self.tokenizer = CLIPTokenizer.from_pretrained(version) @@ -110,15 +122,22 @@ class FrozenCLIPEmbedder(AbstractEncoder): def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") if self.layer == "last": z = outputs.last_hidden_state elif self.layer == "pooled": @@ -135,16 +154,19 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder): """ Uses the OpenCLIP transformer encoder for text """ + LAYERS = [ - #"pooled", + # "pooled", "last", - "penultimate" + "penultimate", ] - def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, - freeze=True, layer="last"): + + def __init__( + self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, freeze=True, layer="last" + ): super().__init__() assert layer in self.LAYERS - model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device("cpu"), pretrained=version) del model.visual self.model = model @@ -179,7 +201,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder): x = self.model.ln_final(x) return x - def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): for i, r in enumerate(self.model.transformer.resblocks): if i == len(self.model.transformer.resblocks) - self.layer_idx: break @@ -194,13 +216,21 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder): class FrozenCLIPT5Encoder(AbstractEncoder): - def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", - clip_max_length=77, t5_max_length=77): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): super().__init__() self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) - print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " - f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params." + ) def encode(self, text): return self(text) @@ -209,5 +239,3 @@ class FrozenCLIPT5Encoder(AbstractEncoder): clip_z = self.clip_encoder.encode(text) t5_z = self.t5_encoder.encode(text) return [clip_z, t5_z] - - diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py index 32ef56169978e550090261cddbcf5eb611a6173b..879b2aa099b6d55fc91bc8481f260a7ffb0378f5 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py @@ -10,33 +10,32 @@ # -------------------------------------------- """ -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +from functools import partial + +import albumentations +import cv2 +import ldm.modules.image_degradation.utils_image as util +import numpy as np import scipy import scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util def modcrop_np(img, sf): - ''' + """ Args: img: numpy image, WxH or WxHxC sf: scale factor Return: cropped image - ''' + """ w, h = img.shape[:2] im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] + return im[: w - w % sf, : h - h % sf, ...] """ @@ -54,7 +53,7 @@ def analytic_kernel(k): # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] @@ -63,7 +62,7 @@ def analytic_kernel(k): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel + """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range @@ -74,7 +73,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): k : kernel """ - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1.0, 0.0])) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) @@ -126,13 +125,13 @@ def shift_pixel(x, sf, upper_left=True): def blur(x, k): - ''' + """ x: image, NxcxHxW k: kernel, Nx1xhxw - ''' + """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate") k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) @@ -142,8 +141,8 @@ def blur(x, k): return x -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10.0, noise_level=0): + """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var @@ -157,8 +156,7 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] @@ -208,13 +206,13 @@ def fspecial_laplacian(alpha): def fspecial(filter_type, *args, **kwargs): - ''' + """ python code from: https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': + """ + if filter_type == "gaussian": return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': + if filter_type == "laplacian": return fspecial_laplacian(*args, **kwargs) @@ -226,19 +224,19 @@ def fspecial(filter_type, *args, **kwargs): def bicubic_degradation(x, sf=3): - ''' + """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image - ''' + """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling + """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double @@ -253,14 +251,14 @@ def srmd_degradation(x, k, sf=3): pages={3262--3271}, year={2018} } - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + """ + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur + """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double @@ -275,22 +273,22 @@ def dpsr_degradation(x, k, sf=3): pages={1671--1681}, year={2019} } - ''' + """ x = bicubic_degradation(x, sf=sf) - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") return x def classical_degradation(x, k, sf=3): - ''' blur + downsampling + """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + """ + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 return x[st::sf, st::sf, ...] @@ -314,7 +312,7 @@ def add_sharpening(img, weight=0.5, radius=50, threshold=10): blur = cv2.GaussianBlur(img, (radius, radius), 0) residual = img - blur mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') + mask = mask.astype("float32") soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) K = img + weight * residual @@ -330,8 +328,8 @@ def add_blur(img, sf=4): l2 = wd2 * random.random() k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) else: - k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) - img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + k = fspecial("gaussian", 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror") return img @@ -366,6 +364,7 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() @@ -374,11 +373,11 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: # add noise - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -392,23 +391,23 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) @@ -418,7 +417,7 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(30, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -428,10 +427,10 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + hq = hq[rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :] return lq, hq @@ -452,18 +451,19 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") hq = img.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]) + ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) @@ -475,7 +475,6 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: img = add_blur(img, sf=sf) @@ -487,13 +486,16 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror") img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) @@ -541,18 +543,20 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): """ image = util.uint2single(image) isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = image.shape[:2] - hq = image.copy() + image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) @@ -564,7 +568,6 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: image = add_blur(image, sf=sf) @@ -576,13 +579,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror") image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -609,7 +615,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) - example = {"image":image} + example = {"image": image} return example @@ -630,11 +636,11 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc """ h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") if use_sharp: img = add_sharpening(img) @@ -686,11 +692,12 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc with torch.no_grad(): img, hq = isp_model.forward(img.copy(), hq) else: - print('check the shuffle!') + print("check the shuffle!") # resize to desired size - img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), interpolation=random.choice([1, 2, 3]) + ) # add final JPEG compression noise img = add_JPEG_noise(img) @@ -701,30 +708,30 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc return img, hq -if __name__ == '__main__': - print("hey") - img = util.imread_uint('utils/test.png', 3) - print(img) - img = util.uint2single(img) - print(img) - img = img[:448, :448] - h = img.shape[0] // 4 - print("resizing to", h) - sf = 4 - deg_fn = partial(degradation_bsrgan_variant, sf=sf) - for i in range(20): - print(i) - img_lq = deg_fn(img) - print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] - print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) - print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') - - +if __name__ == "__main__": + print("hey") + img = util.imread_uint("utils/test.png", 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize( + util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + ".png") diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py index 808c7f882cb75e2ba2340d5b55881d11927351f0..cf3f83f0c011d32187e2dc3a00ff8992201f320c 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py @@ -1,18 +1,17 @@ # -*- coding: utf-8 -*- -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +from functools import partial + +import albumentations +import cv2 +import ldm.modules.image_degradation.utils_image as util +import numpy as np import scipy import scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util """ # -------------------------------------------- @@ -25,17 +24,18 @@ import ldm.modules.image_degradation.utils_image as util # -------------------------------------------- """ + def modcrop_np(img, sf): - ''' + """ Args: img: numpy image, WxH or WxHxC sf: scale factor Return: cropped image - ''' + """ w, h = img.shape[:2] im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] + return im[: w - w % sf, : h - h % sf, ...] """ @@ -53,7 +53,7 @@ def analytic_kernel(k): # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] @@ -62,7 +62,7 @@ def analytic_kernel(k): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel + """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range @@ -73,7 +73,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): k : kernel """ - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1.0, 0.0])) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) @@ -125,13 +125,13 @@ def shift_pixel(x, sf, upper_left=True): def blur(x, k): - ''' + """ x: image, NxcxHxW k: kernel, Nx1xhxw - ''' + """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate") k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) @@ -141,8 +141,8 @@ def blur(x, k): return x -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10.0, noise_level=0): + """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var @@ -156,8 +156,7 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] @@ -207,13 +206,13 @@ def fspecial_laplacian(alpha): def fspecial(filter_type, *args, **kwargs): - ''' + """ python code from: https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': + """ + if filter_type == "gaussian": return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': + if filter_type == "laplacian": return fspecial_laplacian(*args, **kwargs) @@ -225,19 +224,19 @@ def fspecial(filter_type, *args, **kwargs): def bicubic_degradation(x, sf=3): - ''' + """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image - ''' + """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling + """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double @@ -252,14 +251,14 @@ def srmd_degradation(x, k, sf=3): pages={3262--3271}, year={2018} } - ''' - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + """ + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur + """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double @@ -274,22 +273,22 @@ def dpsr_degradation(x, k, sf=3): pages={1671--1681}, year={2019} } - ''' + """ x = bicubic_degradation(x, sf=sf) - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") return x def classical_degradation(x, k, sf=3): - ''' blur + downsampling + """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image - ''' - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + """ + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 return x[st::sf, st::sf, ...] @@ -313,7 +312,7 @@ def add_sharpening(img, weight=0.5, radius=50, threshold=10): blur = cv2.GaussianBlur(img, (radius, radius), 0) residual = img - blur mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') + mask = mask.astype("float32") soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) K = img + weight * residual @@ -325,16 +324,16 @@ def add_blur(img, sf=4): wd2 = 4.0 + sf wd = 2.0 + 0.2 * sf - wd2 = wd2/4 - wd = wd/4 + wd2 = wd2 / 4 + wd = wd / 4 if random.random() < 0.5: l1 = wd2 * random.random() l2 = wd2 * random.random() k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) else: - k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) - img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + k = fspecial("gaussian", random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode="mirror") return img @@ -369,6 +368,7 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() @@ -377,11 +377,11 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: # add noise - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -395,23 +395,23 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) @@ -421,7 +421,7 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(80, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -431,10 +431,10 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + hq = hq[rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :] return lq, hq @@ -455,18 +455,19 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") hq = img.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]) + ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) @@ -478,7 +479,6 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: img = add_blur(img, sf=sf) @@ -490,13 +490,16 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror") img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) @@ -544,18 +547,20 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): """ image = util.uint2single(image) isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = image.shape[:2] - hq = image.copy() + image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) @@ -567,7 +572,6 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: image = add_blur(image, sf=sf) @@ -582,13 +586,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): # downsample2 if random.random() < 0.8: sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror") image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -617,16 +624,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): image = add_JPEG_noise(image) image = util.single2uint(image) if up: - image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then + image = cv2.resize( + image, (w1, h1), interpolation=cv2.INTER_CUBIC + ) # todo: random, as above? want to condition on it then example = {"image": image} return example - - -if __name__ == '__main__': +if __name__ == "__main__": print("hey") - img = util.imread_uint('utils/test.png', 3) + img = util.imread_uint("utils/test.png", 3) img = img[:448, :448] h = img.shape[0] // 4 print("resizing to", h) @@ -638,14 +645,17 @@ if __name__ == '__main__': img_lq = deg_fn(img)["image"] img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[ + "image" + ] print(img_lq.shape) print("bicubic", img_lq_bicubic.shape) print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), - (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) + lq_nearest = cv2.resize( + util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') + util.imsave(img_concat, str(i) + ".png") diff --git a/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py index 0175f155ad900ae33c3c46ed87f49b352e3faf98..71fae1084b61446d3bf6144cb49e511c4d8f19fe 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py @@ -1,18 +1,20 @@ -import os import math +import os import random +from datetime import datetime + +import cv2 import numpy as np import torch -import cv2 from torchvision.utils import make_grid -from datetime import datetime -#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + +# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py -os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" -''' +""" # -------------------------------------------- # Kai Zhang (github: https://github.com/cszn) # 03/Mar/2019 @@ -20,10 +22,10 @@ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" # https://github.com/twhui/SRGAN-pyTorch # https://github.com/xinntao/BasicSR # -------------------------------------------- -''' +""" -IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] +IMG_EXTENSIONS = [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP", ".tif"] def is_image_file(filename): @@ -31,12 +33,12 @@ def is_image_file(filename): def get_timestamp(): - return datetime.now().strftime('%y%m%d-%H%M%S') + return datetime.now().strftime("%y%m%d-%H%M%S") def imshow(x, title=None, cbar=False, figsize=None): plt.figure(figsize=figsize) - plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray") if title: plt.title(title) if cbar: @@ -44,24 +46,24 @@ def imshow(x, title=None, cbar=False, figsize=None): plt.show() -def surf(Z, cmap='rainbow', figsize=None): +def surf(Z, cmap="rainbow", figsize=None): plt.figure(figsize=figsize) - ax3 = plt.axes(projection='3d') + ax3 = plt.axes(projection="3d") w, h = Z.shape[:2] - xx = np.arange(0,w,1) - yy = np.arange(0,h,1) + xx = np.arange(0, w, 1) + yy = np.arange(0, h, 1) X, Y = np.meshgrid(xx, yy) - ax3.plot_surface(X,Y,Z,cmap=cmap) - #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + ax3.plot_surface(X, Y, Z, cmap=cmap) + # ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) plt.show() -''' +""" # -------------------------------------------- # get image pathes # -------------------------------------------- -''' +""" def get_image_paths(dataroot): @@ -72,37 +74,37 @@ def get_image_paths(dataroot): def _get_paths_from_images(path): - assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + assert os.path.isdir(path), "{:s} is not a valid directory".format(path) images = [] for dirpath, _, fnames in sorted(os.walk(path)): for fname in sorted(fnames): if is_image_file(fname): img_path = os.path.join(dirpath, fname) images.append(img_path) - assert images, '{:s} has no valid image file'.format(path) + assert images, "{:s} has no valid image file".format(path) return images -''' +""" # -------------------------------------------- -# split large images into small images +# split large images into small images # -------------------------------------------- -''' +""" def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): w, h = img.shape[:2] patches = [] if w > p_max and h > p_max: - w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) - h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) - w1.append(w-p_size) - h1.append(h-p_size) -# print(w1) -# print(h1) + w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int)) + w1.append(w - p_size) + h1.append(h - p_size) + # print(w1) + # print(h1) for i in w1: for j in h1: - patches.append(img[i:i+p_size, j:j+p_size,:]) + patches.append(img[i : i + p_size, j : j + p_size, :]) else: patches.append(img) @@ -118,7 +120,7 @@ def imssave(imgs, img_path): for i, img in enumerate(imgs): if img.ndim == 3: img = img[:, :, [2, 1, 0]] - new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + new_path = os.path.join(os.path.dirname(img_path), img_name + str("_s{:04d}".format(i)) + ".png") cv2.imwrite(new_path, img) @@ -139,15 +141,16 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, # img_name, ext = os.path.splitext(os.path.basename(img_path)) img = imread_uint(img_path, n_channels=n_channels) patches = patches_from_image(img, p_size, p_overlap, p_max) - imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) - #if original_dataroot == taget_dataroot: - #del img_path + imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path))) + # if original_dataroot == taget_dataroot: + # del img_path + -''' +""" # -------------------------------------------- # makedir # -------------------------------------------- -''' +""" def mkdir(path): @@ -165,18 +168,18 @@ def mkdirs(paths): def mkdir_and_rename(path): if os.path.exists(path): - new_name = path + '_archived_' + get_timestamp() - print('Path already exists. Rename it to [{:s}]'.format(new_name)) + new_name = path + "_archived_" + get_timestamp() + print("Path already exists. Rename it to [{:s}]".format(new_name)) os.rename(path, new_name) os.makedirs(path) -''' +""" # -------------------------------------------- # read image from path # opencv is fast, but read BGR numpy image # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -206,6 +209,7 @@ def imsave(img, img_path): img = img[:, :, [2, 1, 0]] cv2.imwrite(img_path, img) + def imwrite(img, img_path): img = np.squeeze(img) if img.ndim == 3: @@ -213,7 +217,6 @@ def imwrite(img, img_path): cv2.imwrite(img_path, img) - # -------------------------------------------- # get single image of size HxWxn_channles (BGR) # -------------------------------------------- @@ -221,7 +224,7 @@ def read_img(path): # read image by cv2 # return: Numpy float32, HWC, BGR, [0,1] img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE - img = img.astype(np.float32) / 255. + img = img.astype(np.float32) / 255.0 if img.ndim == 2: img = np.expand_dims(img, axis=2) # some images have 4 channels @@ -230,7 +233,7 @@ def read_img(path): return img -''' +""" # -------------------------------------------- # image format conversion # -------------------------------------------- @@ -238,7 +241,7 @@ def read_img(path): # numpy(single) <---> tensor # numpy(unit) <---> tensor # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -247,23 +250,19 @@ def read_img(path): def uint2single(img): - - return np.float32(img/255.) + return np.float32(img / 255.0) def single2uint(img): - - return np.uint8((img.clip(0, 1)*255.).round()) + return np.uint8((img.clip(0, 1) * 255.0).round()) def uint162single(img): - - return np.float32(img/65535.) + return np.float32(img / 65535.0) def single2uint16(img): - - return np.uint16((img.clip(0, 1)*65535.).round()) + return np.uint16((img.clip(0, 1) * 65535.0).round()) # -------------------------------------------- @@ -275,14 +274,14 @@ def single2uint16(img): def uint2tensor4(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0).unsqueeze(0) # convert uint to 3-dimensional torch tensor def uint2tensor3(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0) # convert 2/3/4-dimensional torch tensor to uint @@ -290,7 +289,7 @@ def tensor2uint(img): img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() if img.ndim == 3: img = np.transpose(img, (1, 2, 0)) - return np.uint8((img*255.0).round()) + return np.uint8((img * 255.0).round()) # -------------------------------------------- @@ -316,6 +315,7 @@ def tensor2single(img): return img + # convert torch tensor to single def tensor2single3(img): img = img.data.squeeze().float().cpu().numpy() @@ -340,11 +340,11 @@ def single42tensor4(img): # from skimage.io import imread, imsave def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): - ''' + """ Converts a torch Tensor into an image Numpy array of BGR channel order Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) - ''' + """ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] n_dim = tensor.dim() @@ -358,15 +358,14 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): elif n_dim == 2: img_np = tensor.numpy() else: - raise TypeError( - 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + raise TypeError("Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(n_dim)) if out_type == np.uint8: img_np = (img_np * 255.0).round() # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. return img_np.astype(out_type) -''' +""" # -------------------------------------------- # Augmentation, flipe and/or rotate # -------------------------------------------- @@ -374,12 +373,11 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): # (1) augmet_img: numpy image of WxHxC or WxH # (2) augment_img_tensor4: tensor image 1xCxWxH # -------------------------------------------- -''' +""" def augment_img(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" if mode == 0: return img elif mode == 1: @@ -399,8 +397,7 @@ def augment_img(img, mode=0): def augment_img_tensor4(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" if mode == 0: return img elif mode == 1: @@ -420,8 +417,7 @@ def augment_img_tensor4(img, mode=0): def augment_img_tensor(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" img_size = img.size() img_np = img.data.cpu().numpy() if len(img_size) == 3: @@ -484,11 +480,11 @@ def augment_imgs(img_list, hflip=True, rot=True): return [_augment(img) for img in img_list] -''' +""" # -------------------------------------------- # modcrop and shave # -------------------------------------------- -''' +""" def modcrop(img_in, scale): @@ -497,13 +493,13 @@ def modcrop(img_in, scale): if img.ndim == 2: H, W = img.shape H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r] + img = img[: H - H_r, : W - W_r] elif img.ndim == 3: H, W, C = img.shape H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r, :] + img = img[: H - H_r, : W - W_r, :] else: - raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim)) return img @@ -511,11 +507,11 @@ def shave(img_in, border=0): # img_in: Numpy, HWC or HW img = np.copy(img_in) h, w = img.shape[:2] - img = img[border:h-border, border:w-border] + img = img[border : h - border, border : w - border] return img -''' +""" # -------------------------------------------- # image processing process on numpy image # channel_convert(in_c, tar_type, img_list): @@ -523,96 +519,99 @@ def shave(img_in, border=0): # bgr2ycbcr(img, only_y=True): # ycbcr2rgb(img): # -------------------------------------------- -''' +""" def rgb2ycbcr(img, only_y=True): - '''same as matlab rgb2ycbcr + """same as matlab rgb2ycbcr only_y: only return Y channel Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert if only_y: rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], - [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + rlt = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]] + ) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def ycbcr2rgb(img): - '''same as matlab ycbcr2rgb + """same as matlab ycbcr2rgb Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert - rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], - [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + rlt = np.matmul( + img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], [0.00625893, -0.00318811, 0]] + ) * 255.0 + [-222.921, 135.576, -276.836] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def bgr2ycbcr(img, only_y=True): - '''bgr version of rgb2ycbcr + """bgr version of rgb2ycbcr only_y: only return Y channel Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert if only_y: rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], - [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + rlt = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]] + ) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def channel_convert(in_c, tar_type, img_list): # conversion among BGR, gray and y - if in_c == 3 and tar_type == 'gray': # BGR to gray + if in_c == 3 and tar_type == "gray": # BGR to gray gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] return [np.expand_dims(img, axis=2) for img in gray_list] - elif in_c == 3 and tar_type == 'y': # BGR to y + elif in_c == 3 and tar_type == "y": # BGR to y y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] return [np.expand_dims(img, axis=2) for img in y_list] - elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + elif in_c == 1 and tar_type == "RGB": # gray/y to BGR return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] else: return img_list -''' +""" # -------------------------------------------- # metric, PSNR and SSIM # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -620,19 +619,19 @@ def channel_convert(in_c, tar_type, img_list): # -------------------------------------------- def calculate_psnr(img1, img2, border=0): # img1 and img2 have range [0, 255] - #img1 = img1.squeeze() - #img2 = img2.squeeze() + # img1 = img1.squeeze() + # img2 = img2.squeeze() if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') + raise ValueError("Input images must have the same dimensions.") h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) - mse = np.mean((img1 - img2)**2) + mse = np.mean((img1 - img2) ** 2) if mse == 0: - return float('inf') + return float("inf") return 20 * math.log10(255.0 / math.sqrt(mse)) @@ -640,17 +639,17 @@ def calculate_psnr(img1, img2, border=0): # SSIM # -------------------------------------------- def calculate_ssim(img1, img2, border=0): - '''calculate SSIM + """calculate SSIM the same outputs as MATLAB's img1, img2: [0, 255] - ''' - #img1 = img1.squeeze() - #img2 = img2.squeeze() + """ + # img1 = img1.squeeze() + # img2 = img2.squeeze() if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') + raise ValueError("Input images must have the same dimensions.") h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] if img1.ndim == 2: return ssim(img1, img2) @@ -658,17 +657,17 @@ def calculate_ssim(img1, img2, border=0): if img1.shape[2] == 3: ssims = [] for i in range(3): - ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + ssims.append(ssim(img1[:, :, i], img2[:, :, i])) return np.array(ssims).mean() elif img1.shape[2] == 1: return ssim(np.squeeze(img1), np.squeeze(img2)) else: - raise ValueError('Wrong input image dimensions.') + raise ValueError("Wrong input image dimensions.") def ssim(img1, img2): - C1 = (0.01 * 255)**2 - C2 = (0.03 * 255)**2 + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) @@ -684,16 +683,15 @@ def ssim(img1, img2): sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * - (sigma1_sq + sigma2_sq + C2)) + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) return ssim_map.mean() -''' +""" # -------------------------------------------- # matlab's bicubic imresize (numpy and torch) [0, 1] # -------------------------------------------- -''' +""" # matlab 'imresize' function, now only support 'bicubic' @@ -701,8 +699,9 @@ def cubic(x): absx = torch.abs(x) absx2 = absx**2 absx3 = absx**3 - return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ - (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + ( + -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2 + ) * (((absx > 1) * (absx <= 2)).type_as(absx)) def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): @@ -729,8 +728,9 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width # The indices of the input pixels involved in computing the k-th output # pixel are in row k of the indices matrix. - indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( - 1, P).expand(out_length, P) + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand( + out_length, P + ) # The weights used to compute the k-th output pixel are in row k of the # weights matrix. @@ -773,7 +773,7 @@ def imresize(img, scale, antialiasing=True): in_C, in_H, in_W = img.size() out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) kernel_width = 4 - kernel = 'cubic' + kernel = "cubic" # Return the desired dimension order for performing the resize. The # strategy is to perform the resize first along the dimension with the @@ -782,9 +782,11 @@ def imresize(img, scale, antialiasing=True): # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) @@ -805,7 +807,7 @@ def imresize(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying @@ -827,7 +829,7 @@ def imresize(img, scale, antialiasing=True): for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): - out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_W[i]) if need_squeeze: out_2.squeeze_() return out_2 @@ -848,7 +850,7 @@ def imresize_np(img, scale, antialiasing=True): in_H, in_W, in_C = img.size() out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) kernel_width = 4 - kernel = 'cubic' + kernel = "cubic" # Return the desired dimension order for performing the resize. The # strategy is to perform the resize first along the dimension with the @@ -857,9 +859,11 @@ def imresize_np(img, scale, antialiasing=True): # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) @@ -880,7 +884,7 @@ def imresize_np(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, j] = img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying @@ -902,15 +906,15 @@ def imresize_np(img, scale, antialiasing=True): for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): - out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(weights_W[i]) if need_squeeze: out_2.squeeze_() return out_2.numpy() -if __name__ == '__main__': - print('---') +if __name__ == "__main__": + print("---") # img = imread_uint('test.bmp', 3) # img = uint2single(img) -# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file +# img_bicubic = imresize_np(img, 1/4) diff --git a/examples/images/diffusion/ldm/modules/midas/api.py b/examples/images/diffusion/ldm/modules/midas/api.py index b58ebbffd942a2fc22264f0ab47e400c26b9f41c..6619f515fa0e26572e7513662b0127647974b2de 100644 --- a/examples/images/diffusion/ldm/modules/midas/api.py +++ b/examples/images/diffusion/ldm/modules/midas/api.py @@ -3,13 +3,11 @@ import cv2 import torch import torch.nn as nn -from torchvision.transforms import Compose - from ldm.modules.midas.midas.dpt_depth import DPTDepthModel from ldm.modules.midas.midas.midas_net import MidasNet from ldm.modules.midas.midas.midas_net_custom import MidasNet_small -from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet - +from ldm.modules.midas.midas.transforms import NormalizeImage, PrepareForNet, Resize +from torchvision.transforms import Compose ISL_PATHS = { "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", @@ -98,18 +96,20 @@ def load_model(model_type): model = MidasNet(model_path, non_negative=True) net_w, net_h = 384, 384 resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) elif model_type == "midas_v21_small": - model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, - non_negative=True, blocks={'expand': True}) + model = MidasNet_small( + model_path, + features=64, + backbone="efficientnet_lite3", + exportable=True, + non_negative=True, + blocks={"expand": True}, + ) net_w, net_h = 256, 256 resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) else: print(f"model_type '{model_type}' not implemented, use: --model_type large") @@ -135,11 +135,7 @@ def load_model(model_type): class MiDaSInference(nn.Module): - MODEL_TYPES_TORCH_HUB = [ - "DPT_Large", - "DPT_Hybrid", - "MiDaS_small" - ] + MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"] MODEL_TYPES_ISL = [ "dpt_large", "dpt_hybrid", @@ -149,7 +145,7 @@ class MiDaSInference(nn.Module): def __init__(self, model_type): super().__init__() - assert (model_type in self.MODEL_TYPES_ISL) + assert model_type in self.MODEL_TYPES_ISL model, _ = load_model(model_type) self.model = model self.model.train = disabled_train @@ -167,4 +163,3 @@ class MiDaSInference(nn.Module): ) assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) return prediction - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py index 5cf430239b47ec5ec07531263f26f5c24a2311cd..5c2e0e93b0495f48a3405546b6fe1969be3480a2 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py @@ -8,7 +8,7 @@ class BaseModel(torch.nn.Module): Args: path (str): file path """ - parameters = torch.load(path, map_location=torch.device('cpu')) + parameters = torch.load(path, map_location=torch.device("cpu")) if "optimizer" in parameters: parameters = parameters["model"] diff --git a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py index 2145d18fa98060a618536d9a64fe6589e9be4f78..154de57cd2e81d5f54d378b088e2661f2b02bd6d 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py @@ -1,18 +1,22 @@ import torch import torch.nn as nn -from .vit import ( - _make_pretrained_vitb_rn50_384, - _make_pretrained_vitl16_384, - _make_pretrained_vitb16_384, - forward_vit, -) - -def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): +from .vit import _make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384, _make_pretrained_vitl16_384 + + +def _make_encoder( + backbone, + features, + use_pretrained, + groups=1, + expand=False, + exportable=True, + hooks=None, + use_vit_only=False, + use_readout="ignore", +): if backbone == "vitl16_384": - pretrained = _make_pretrained_vitl16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) + pretrained = _make_pretrained_vitl16_384(use_pretrained, hooks=hooks, use_readout=use_readout) scratch = _make_scratch( [256, 512, 1024, 1024], features, groups=groups, expand=expand ) # ViT-L/16 - 85.0% Top1 (backbone) @@ -27,22 +31,20 @@ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, ex [256, 512, 768, 768], features, groups=groups, expand=expand ) # ViT-H/16 - 85.0% Top1 (backbone) elif backbone == "vitb16_384": - pretrained = _make_pretrained_vitb16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) + pretrained = _make_pretrained_vitb16_384(use_pretrained, hooks=hooks, use_readout=use_readout) scratch = _make_scratch( [96, 192, 384, 768], features, groups=groups, expand=expand ) # ViT-B/16 - 84.6% Top1 (backbone) elif backbone == "resnext101_wsl": pretrained = _make_pretrained_resnext101_wsl(use_pretrained) - scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 elif backbone == "efficientnet_lite3": pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) - scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 else: print(f"Backbone '{backbone}' not implemented") assert False - + return pretrained, scratch @@ -53,11 +55,11 @@ def _make_scratch(in_shape, out_shape, groups=1, expand=False): out_shape2 = out_shape out_shape3 = out_shape out_shape4 = out_shape - if expand==True: + if expand == True: out_shape1 = out_shape - out_shape2 = out_shape*2 - out_shape3 = out_shape*4 - out_shape4 = out_shape*8 + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 scratch.layer1_rn = nn.Conv2d( in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups @@ -77,10 +79,7 @@ def _make_scratch(in_shape, out_shape, groups=1, expand=False): def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): efficientnet = torch.hub.load( - "rwightman/gen-efficientnet-pytorch", - "tf_efficientnet_lite3", - pretrained=use_pretrained, - exportable=exportable + "rwightman/gen-efficientnet-pytorch", "tf_efficientnet_lite3", pretrained=use_pretrained, exportable=exportable ) return _make_efficientnet_backbone(efficientnet) @@ -88,21 +87,17 @@ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): def _make_efficientnet_backbone(effnet): pretrained = nn.Module() - pretrained.layer1 = nn.Sequential( - effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] - ) + pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]) pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) return pretrained - + def _make_resnet_backbone(resnet): pretrained = nn.Module() - pretrained.layer1 = nn.Sequential( - resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 - ) + pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1) pretrained.layer2 = resnet.layer2 pretrained.layer3 = resnet.layer3 @@ -116,10 +111,8 @@ def _make_pretrained_resnext101_wsl(use_pretrained): return _make_resnet_backbone(resnet) - class Interpolate(nn.Module): - """Interpolation module. - """ + """Interpolation module.""" def __init__(self, scale_factor, mode, align_corners=False): """Init. @@ -145,16 +138,13 @@ class Interpolate(nn.Module): tensor: interpolated data """ - x = self.interp( - x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners - ) + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) return x class ResidualConvUnit(nn.Module): - """Residual convolution module. - """ + """Residual convolution module.""" def __init__(self, features): """Init. @@ -164,13 +154,9 @@ class ResidualConvUnit(nn.Module): """ super().__init__() - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) self.relu = nn.ReLU(inplace=True) @@ -192,8 +178,7 @@ class ResidualConvUnit(nn.Module): class FeatureFusionBlock(nn.Module): - """Feature fusion block. - """ + """Feature fusion block.""" def __init__(self, features): """Init. @@ -219,18 +204,13 @@ class FeatureFusionBlock(nn.Module): output = self.resConfUnit2(output) - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=True - ) + output = nn.functional.interpolate(output, scale_factor=2, mode="bilinear", align_corners=True) return output - - class ResidualConvUnit_custom(nn.Module): - """Residual convolution module. - """ + """Residual convolution module.""" def __init__(self, features, activation, bn): """Init. @@ -242,17 +222,13 @@ class ResidualConvUnit_custom(nn.Module): self.bn = bn - self.groups=1 + self.groups = 1 - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) - - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) - if self.bn==True: + if self.bn == True: self.bn1 = nn.BatchNorm2d(features) self.bn2 = nn.BatchNorm2d(features) @@ -269,15 +245,15 @@ class ResidualConvUnit_custom(nn.Module): Returns: tensor: output """ - + out = self.activation(x) out = self.conv1(out) - if self.bn==True: + if self.bn == True: out = self.bn1(out) - + out = self.activation(out) out = self.conv2(out) - if self.bn==True: + if self.bn == True: out = self.bn2(out) if self.groups > 1: @@ -289,8 +265,7 @@ class ResidualConvUnit_custom(nn.Module): class FeatureFusionBlock_custom(nn.Module): - """Feature fusion block. - """ + """Feature fusion block.""" def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): """Init. @@ -303,18 +278,18 @@ class FeatureFusionBlock_custom(nn.Module): self.deconv = deconv self.align_corners = align_corners - self.groups=1 + self.groups = 1 self.expand = expand out_features = features - if self.expand==True: - out_features = features//2 - + if self.expand == True: + out_features = features // 2 + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) - + self.skip_add = nn.quantized.FloatFunctional() def forward(self, *xs): @@ -332,11 +307,8 @@ class FeatureFusionBlock_custom(nn.Module): output = self.resConfUnit2(output) - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=self.align_corners - ) + output = nn.functional.interpolate(output, scale_factor=2, mode="bilinear", align_corners=self.align_corners) output = self.out_conv(output) return output - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py index 4e9aab5d2767dffea39da5b3f30e2798688216f1..74871e8b1fcef4173718905a5c2eabe28c00ad2c 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py @@ -1,15 +1,8 @@ import torch import torch.nn as nn -import torch.nn.functional as F from .base_model import BaseModel -from .blocks import ( - FeatureFusionBlock, - FeatureFusionBlock_custom, - Interpolate, - _make_encoder, - forward_vit, -) +from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder, forward_vit def _make_fusion_block(features, use_bn): @@ -33,7 +26,6 @@ class DPT(BaseModel): channels_last=False, use_bn=False, ): - super(DPT, self).__init__() self.channels_last = channels_last @@ -48,7 +40,7 @@ class DPT(BaseModel): self.pretrained, self.scratch = _make_encoder( backbone, features, - False, # Set to true of you want to train from scratch, uses ImageNet weights + False, # Set to true of you want to train from scratch, uses ImageNet weights groups=1, expand=False, exportable=False, @@ -63,7 +55,6 @@ class DPT(BaseModel): self.scratch.output_conv = head - def forward(self, x): if self.channels_last == True: x.contiguous(memory_format=torch.channels_last) @@ -102,8 +93,7 @@ class DPTDepthModel(DPT): super().__init__(head, **kwargs) if path is not None: - self.load(path) + self.load(path) def forward(self, x): return super().forward(x).squeeze(dim=1) - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py index 8a954977800b0a0f48807e80fa63041910e33c1f..0dd87b59619cd74f93c232f79d09efd7779ff98b 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py @@ -10,8 +10,7 @@ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder class MidasNet(BaseModel): - """Network for monocular depth estimation. - """ + """Network for monocular depth estimation.""" def __init__(self, path=None, features=256, non_negative=True): """Init. @@ -27,7 +26,9 @@ class MidasNet(BaseModel): use_pretrained = False if path is None else True - self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + self.pretrained, self.scratch = _make_encoder( + backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained + ) self.scratch.refinenet4 = FeatureFusionBlock(features) self.scratch.refinenet3 = FeatureFusionBlock(features) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py index 50e4acb5e53d5fabefe3dde16ab49c33c2b7797c..4d30744c46d36da8f3b3db79863202093baa7e03 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py @@ -6,15 +6,23 @@ import torch import torch.nn as nn from .base_model import BaseModel -from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder +from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder class MidasNet_small(BaseModel): - """Network for monocular depth estimation. - """ - - def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, - blocks={'expand': True}): + """Network for monocular depth estimation.""" + + def __init__( + self, + path=None, + features=64, + backbone="efficientnet_lite3", + non_negative=True, + exportable=True, + channels_last=False, + align_corners=True, + blocks={"expand": True}, + ): """Init. Args: @@ -27,49 +35,57 @@ class MidasNet_small(BaseModel): super(MidasNet_small, self).__init__() use_pretrained = False if path else True - + self.channels_last = channels_last self.blocks = blocks self.backbone = backbone self.groups = 1 - features1=features - features2=features - features3=features - features4=features + features1 = features + features2 = features + features3 = features + features4 = features self.expand = False - if "expand" in self.blocks and self.blocks['expand'] == True: + if "expand" in self.blocks and self.blocks["expand"] == True: self.expand = True - features1=features - features2=features*2 - features3=features*4 - features4=features*8 + features1 = features + features2 = features * 2 + features3 = features * 4 + features4 = features * 8 + + self.pretrained, self.scratch = _make_encoder( + self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable + ) - self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) - - self.scratch.activation = nn.ReLU(False) + self.scratch.activation = nn.ReLU(False) - self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + self.scratch.refinenet4 = FeatureFusionBlock_custom( + features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners + ) + self.scratch.refinenet3 = FeatureFusionBlock_custom( + features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners + ) + self.scratch.refinenet2 = FeatureFusionBlock_custom( + features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners + ) + self.scratch.refinenet1 = FeatureFusionBlock_custom( + features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners + ) - self.scratch.output_conv = nn.Sequential( - nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups), Interpolate(scale_factor=2, mode="bilinear"), - nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), self.scratch.activation, nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), nn.ReLU(True) if non_negative else nn.Identity(), nn.Identity(), ) - + if path: self.load(path) - def forward(self, x): """Forward pass. @@ -79,38 +95,35 @@ class MidasNet_small(BaseModel): Returns: tensor: depth """ - if self.channels_last==True: + if self.channels_last == True: print("self.channels_last = ", self.channels_last) x.contiguous(memory_format=torch.channels_last) - layer_1 = self.pretrained.layer1(x) layer_2 = self.pretrained.layer2(layer_1) layer_3 = self.pretrained.layer3(layer_2) layer_4 = self.pretrained.layer4(layer_3) - + layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) - path_4 = self.scratch.refinenet4(layer_4_rn) path_3 = self.scratch.refinenet3(path_4, layer_3_rn) path_2 = self.scratch.refinenet2(path_3, layer_2_rn) path_1 = self.scratch.refinenet1(path_2, layer_1_rn) - + out = self.scratch.output_conv(path_1) return torch.squeeze(out, dim=1) - def fuse_model(m): prev_previous_type = nn.Identity() - prev_previous_name = '' + prev_previous_name = "" previous_type = nn.Identity() - previous_name = '' + previous_name = "" for name, module in m.named_modules(): if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: # print("FUSED ", prev_previous_name, previous_name, name) @@ -125,4 +138,4 @@ def fuse_model(m): prev_previous_type = previous_type prev_previous_name = previous_name previous_type = type(module) - previous_name = name \ No newline at end of file + previous_name = name diff --git a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py index 350cbc11662633ad7f8968eb10be2e7de6e384e9..aede0fa0c73fddc2cd6cf1ba18f0c5c6c6253962 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py @@ -1,7 +1,8 @@ -import numpy as np -import cv2 import math +import cv2 +import numpy as np + def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): """Rezise the sample to ensure the given size. Keeps aspect ratio. @@ -28,13 +29,9 @@ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): shape[1] = math.ceil(scale * shape[1]) # resize - sample["image"] = cv2.resize( - sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method - ) + sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method) - sample["disparity"] = cv2.resize( - sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST - ) + sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), tuple(shape[::-1]), @@ -46,8 +43,7 @@ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): class Resize(object): - """Resize sample to given size (width, height). - """ + """Resize sample to given size (width, height).""" def __init__( self, @@ -133,24 +129,14 @@ class Resize(object): # fit height scale_width = scale_height else: - raise ValueError( - f"resize_method {self.__resize_method} not implemented" - ) + raise ValueError(f"resize_method {self.__resize_method} not implemented") if self.__resize_method == "lower_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, min_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, min_val=self.__width - ) + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) elif self.__resize_method == "upper_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, max_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, max_val=self.__width - ) + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) elif self.__resize_method == "minimal": new_height = self.constrain_to_multiple_of(scale_height * height) new_width = self.constrain_to_multiple_of(scale_width * width) @@ -160,9 +146,7 @@ class Resize(object): return (new_width, new_height) def __call__(self, sample): - width, height = self.get_size( - sample["image"].shape[1], sample["image"].shape[0] - ) + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) # resize sample sample["image"] = cv2.resize( @@ -180,9 +164,7 @@ class Resize(object): ) if "depth" in sample: - sample["depth"] = cv2.resize( - sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST - ) + sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), @@ -195,8 +177,7 @@ class Resize(object): class NormalizeImage(object): - """Normlize image by given mean and std. - """ + """Normlize image by given mean and std.""" def __init__(self, mean, std): self.__mean = mean @@ -209,8 +190,7 @@ class NormalizeImage(object): class PrepareForNet(object): - """Prepare sample for usage as network input. - """ + """Prepare sample for usage as network input.""" def __init__(self): pass diff --git a/examples/images/diffusion/ldm/modules/midas/midas/vit.py b/examples/images/diffusion/ldm/modules/midas/midas/vit.py index ea46b1be88b261b0dec04f3da0256f5f66f88a74..41bdb566fd4fbac60079b3a214b63c947b1762c5 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/vit.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/vit.py @@ -1,8 +1,9 @@ +import math +import types + +import timm import torch import torch.nn as nn -import timm -import types -import math import torch.nn.functional as F @@ -56,7 +57,7 @@ class Transpose(nn.Module): def forward_vit(pretrained, x): b, c, h, w = x.shape - glob = pretrained.model.forward_flex(x) + pretrained.model.forward_flex(x) layer_1 = pretrained.activations["1"] layer_2 = pretrained.activations["2"] @@ -117,9 +118,7 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w): def forward_flex(self, x): b, c, h, w = x.shape - pos_embed = self._resize_pos_embed( - self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] - ) + pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]) B = x.shape[0] @@ -131,15 +130,11 @@ def forward_flex(self, x): x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) if getattr(self, "dist_token", None) is not None: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks dist_token = self.dist_token.expand(B, -1, -1) x = torch.cat((cls_tokens, dist_token, x), dim=1) else: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) x = x + pos_embed @@ -169,13 +164,9 @@ def get_readout_oper(vit_features, features, use_readout, start_index=1): elif use_readout == "add": readout_oper = [AddReadout(start_index)] * len(features) elif use_readout == "project": - readout_oper = [ - ProjectReadout(vit_features, start_index) for out_feat in features - ] + readout_oper = [ProjectReadout(vit_features, start_index) for out_feat in features] else: - assert ( - False - ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + assert False, "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" return readout_oper @@ -287,9 +278,7 @@ def _make_vit_b16_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) - pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) + pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model) return pretrained @@ -311,24 +300,18 @@ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) + return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout) def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) + return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout) def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): - model = timm.create_model( - "vit_deit_base_distilled_patch16_384", pretrained=pretrained - ) + model = timm.create_model("vit_deit_base_distilled_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks return _make_vit_b16_backbone( @@ -358,12 +341,8 @@ def _make_vit_b_rn50_backbone( pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) else: - pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( - get_activation("1") - ) - pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( - get_activation("2") - ) + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(get_activation("1")) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(get_activation("2")) pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) @@ -419,12 +398,8 @@ def _make_vit_b_rn50_backbone( ), ) else: - pretrained.act_postprocess1 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) - pretrained.act_postprocess2 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) + pretrained.act_postprocess1 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + pretrained.act_postprocess2 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) pretrained.act_postprocess3 = nn.Sequential( readout_oper[2], @@ -468,16 +443,12 @@ def _make_vit_b_rn50_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. - pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) + pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model) return pretrained -def _make_pretrained_vitb_rn50_384( - pretrained, use_readout="ignore", hooks=None, use_vit_only=False -): +def _make_pretrained_vitb_rn50_384(pretrained, use_readout="ignore", hooks=None, use_vit_only=False): model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) hooks = [0, 1, 8, 11] if hooks == None else hooks diff --git a/examples/images/diffusion/ldm/modules/midas/utils.py b/examples/images/diffusion/ldm/modules/midas/utils.py index 9a9d3b5b66370fa98da9e067ba53ead848ea9a59..1428d42b2445e20629b48c969c2f84fb31fd0471 100644 --- a/examples/images/diffusion/ldm/modules/midas/utils.py +++ b/examples/images/diffusion/ldm/modules/midas/utils.py @@ -1,8 +1,9 @@ """Utils for monoDepth.""" -import sys import re -import numpy as np +import sys + import cv2 +import numpy as np import torch @@ -16,7 +17,6 @@ def read_pfm(path): tuple: (data, scale) """ with open(path, "rb") as file: - color = None width = None height = None @@ -74,9 +74,7 @@ def write_pfm(path, image, scale=1): if len(image.shape) == 3 and image.shape[2] == 3: # color image color = True - elif ( - len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 - ): # greyscale + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale color = False else: raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") @@ -135,9 +133,7 @@ def resize_image(img): img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) - img_resized = ( - torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() - ) + img_resized = torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() img_resized = img_resized.unsqueeze(0) return img_resized @@ -156,12 +152,11 @@ def resize_depth(depth, width, height): """ depth = torch.squeeze(depth[0, :, :, :]).to("cpu") - depth_resized = cv2.resize( - depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC - ) + depth_resized = cv2.resize(depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC) return depth_resized + def write_depth(path, depth, bits=1): """Write depth map to pfm and png file. @@ -174,7 +169,7 @@ def write_depth(path, depth, bits=1): depth_min = depth.min() depth_max = depth.max() - max_val = (2**(8*bits))-1 + max_val = (2 ** (8 * bits)) - 1 if depth_max - depth_min > np.finfo("float").eps: out = max_val * (depth - depth_min) / (depth_max - depth_min) diff --git a/examples/images/diffusion/ldm/util.py b/examples/images/diffusion/ldm/util.py index 8c09ca1c72f7ceb3f9d7f9546aae5561baf62b13..9b52b199aa2c687952601bc61618b22b7bec61bf 100644 --- a/examples/images/diffusion/ldm/util.py +++ b/examples/images/diffusion/ldm/util.py @@ -1,11 +1,10 @@ import importlib +from inspect import isfunction -import torch -from torch import optim import numpy as np - -from inspect import isfunction +import torch from PIL import Image, ImageDraw, ImageFont +from torch import optim def log_txt_as_img(wh, xc, size=10): @@ -16,9 +15,9 @@ def log_txt_as_img(wh, xc, size=10): for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) - font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) nc = int(40 * (wh[0] / 256)) - lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)) try: draw.text((0, 0), lines, fill="black", font=font) @@ -39,7 +38,7 @@ def ismap(x): def isimage(x): - if not isinstance(x,torch.Tensor): + if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) @@ -71,7 +70,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): if not "target" in config: - if config == '__is_first_stage__': + if config == "__is_first_stage__": return None elif config == "__is_unconditional__": return None @@ -89,9 +88,18 @@ def get_obj_from_str(string, reload=False): class AdamWwithEMAandWings(optim.Optimizer): # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 - def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using - weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code - ema_power=1., param_names=()): + def __init__( + self, + params, + lr=1.0e-3, + betas=(0.9, 0.999), + eps=1.0e-8, # TODO: check hyperparameters before using + weight_decay=1.0e-2, + amsgrad=False, + ema_decay=0.9999, # ema decay to match previous code + ema_power=1.0, + param_names=(), + ): """AdamW that saves EMA versions of the parameters.""" if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -105,15 +113,22 @@ class AdamWwithEMAandWings(optim.Optimizer): raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0.0 <= ema_decay <= 1.0: raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, - ema_power=ema_power, param_names=param_names) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ema_decay=ema_decay, + ema_power=ema_power, + param_names=param_names, + ) super().__init__(params, defaults) def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: - group.setdefault('amsgrad', False) + group.setdefault("amsgrad", False) @torch.no_grad() def step(self, closure=None): @@ -133,65 +148,66 @@ class AdamWwithEMAandWings(optim.Optimizer): exp_avgs = [] exp_avg_sqs = [] ema_params_with_grad = [] - state_sums = [] max_exp_avg_sqs = [] state_steps = [] - amsgrad = group['amsgrad'] - beta1, beta2 = group['betas'] - ema_decay = group['ema_decay'] - ema_power = group['ema_power'] + amsgrad = group["amsgrad"] + beta1, beta2 = group["betas"] + ema_decay = group["ema_decay"] + ema_power = group["ema_power"] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue params_with_grad.append(p) if p.grad.is_sparse: - raise RuntimeError('AdamW does not support sparse gradients') + raise RuntimeError("AdamW does not support sparse gradients") grads.append(p.grad) state = self.state[p] # State initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) # Exponential moving average of parameter values - state['param_exp_avg'] = p.detach().float().clone() + state["param_exp_avg"] = p.detach().float().clone() - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - ema_params_with_grad.append(state['param_exp_avg']) + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + ema_params_with_grad.append(state["param_exp_avg"]) if amsgrad: - max_exp_avg_sqs.append(state['max_exp_avg_sq']) + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) # update the steps for each param group update - state['step'] += 1 + state["step"] += 1 # record the step after step update - state_steps.append(state['step']) - - optim._functional.adamw(params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps'], - maximize=False) - - cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + state_steps.append(state["step"]) + + optim._functional.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + cur_ema_decay = min(ema_decay, 1 - state["step"] ** -ema_power) for param, ema_param in zip(params_with_grad, ema_params_with_grad): ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) - return loss \ No newline at end of file + return loss diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py index 713029fc677d818b9d63e3d2c6f15a5592d2d3aa..6d44df667fcef3548e95d2ce1e72fc53a5520c81 100644 --- a/examples/images/diffusion/main.py +++ b/examples/images/diffusion/main.py @@ -1,33 +1,28 @@ import argparse -import csv import datetime import glob -import importlib import os import sys import time +from functools import partial +import lightning.pytorch as pl import numpy as np import torch import torchvision -import lightning.pytorch as pl - - -from functools import partial - -from omegaconf import OmegaConf -from packaging import version -from PIL import Image -from prefetch_generator import BackgroundGenerator -from torch.utils.data import DataLoader, Dataset, Subset, random_split from ldm.models.diffusion.ddpm import LatentDiffusion - from lightning.pytorch import seed_everything from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger +from lightning.pytorch.strategies import ColossalAIStrategy, DDPStrategy from lightning.pytorch.trainer import Trainer from lightning.pytorch.utilities import rank_zero_info, rank_zero_only -from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger -from lightning.pytorch.strategies import ColossalAIStrategy,DDPStrategy +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from prefetch_generator import BackgroundGenerator +from torch.utils.data import DataLoader, Dataset + LIGHTNING_PACK_NAME = "lightning.pytorch." from ldm.data.base import Txt2ImgIterableBaseDataset @@ -37,15 +32,15 @@ from ldm.util import instantiate_from_config class DataLoaderX(DataLoader): -# A custom data loader class that inherits from DataLoader + # A custom data loader class that inherits from DataLoader def __iter__(self): # Overriding the __iter__ method of DataLoader to return a BackgroundGenerator - #This is to enable data loading in the background to improve training performance + # This is to enable data loading in the background to improve training performance return BackgroundGenerator(super().__iter__()) def get_parser(**parser_kwargs): - #A function to create an ArgumentParser object and add arguments to it + # A function to create an ArgumentParser object and add arguments to it def str2bool(v): # A helper function to parse boolean values from command line arguments @@ -57,6 +52,7 @@ def get_parser(**parser_kwargs): return False else: raise argparse.ArgumentTypeError("Boolean value expected.") + # Create an ArgumentParser object with specifies kwargs parser = argparse.ArgumentParser(**parser_kwargs) @@ -160,6 +156,7 @@ def get_parser(**parser_kwargs): return parser + # A function that returns the non-default arguments between two objects def nondefault_trainer_args(opt): # create an argument parser @@ -171,6 +168,7 @@ def nondefault_trainer_args(opt): # return all non-default arguments return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + # A dataset wrapper class to create a pytorch dataset from an arbitrary object class WrappedDataset(Dataset): """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" @@ -184,6 +182,7 @@ class WrappedDataset(Dataset): def __getitem__(self, idx): return self.data[idx] + # A function to initialize worker processes def worker_init_fn(_): worker_info = torch.utils.data.get_worker_info() @@ -192,31 +191,33 @@ def worker_init_fn(_): worker_id = worker_info.id if isinstance(dataset, Txt2ImgIterableBaseDataset): - #divide the dataset into equal parts for each worker + # divide the dataset into equal parts for each worker split_size = dataset.num_records // worker_info.num_workers - #set the sample IDs for the current worker + # set the sample IDs for the current worker # reset num_records to the true number to retain reliable length information - dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + dataset.sample_ids = dataset.valid_ids[worker_id * split_size : (worker_id + 1) * split_size] # set the seed for the current worker current_id = np.random.choice(len(np.random.get_state()[1]), 1) return np.random.seed(np.random.get_state()[1][current_id] + worker_id) else: return np.random.seed(np.random.get_state()[1][0] + worker_id) -#Provide functionality for creating data loaders based on provided dataset configurations -class DataModuleFromConfig(pl.LightningDataModule): - def __init__(self, - batch_size, - train=None, - validation=None, - test=None, - predict=None, - wrap=False, - num_workers=None, - shuffle_test_loader=False, - use_worker_init_fn=False, - shuffle_val_dataloader=False): +# Provide functionality for creating data loaders based on provided dataset configurations +class DataModuleFromConfig(pl.LightningDataModule): + def __init__( + self, + batch_size, + train=None, + validation=None, + test=None, + predict=None, + wrap=False, + num_workers=None, + shuffle_test_loader=False, + use_worker_init_fn=False, + shuffle_val_dataloader=False, + ): super().__init__() # Set data module attributes self.batch_size = batch_size @@ -246,43 +247,47 @@ class DataModuleFromConfig(pl.LightningDataModule): def setup(self, stage=None): # Instantiate datasets from the dataset configs self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) - + # If wrap is true, create a WrappedDataset for each dataset if self.wrap: for k in self.datasets: self.datasets[k] = WrappedDataset(self.datasets[k]) def _train_dataloader(self): - #Check if the train dataset is iterable - is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) - #Set the worker initialization function of the dataset is iterable or use_worker_init_fn is True + # Check if the train dataset is iterable + is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset) + # Set the worker initialization function of the dataset is iterable or use_worker_init_fn is True if is_iterable_dataset or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None # Return a DataLoaderX object for the train dataset - return DataLoaderX(self.datasets["train"], - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False if is_iterable_dataset else True, - worker_init_fn=init_fn) + return DataLoaderX( + self.datasets["train"], + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn, + ) def _val_dataloader(self, shuffle=False): - #Check if the validation dataset is iterable - if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + # Check if the validation dataset is iterable + if isinstance(self.datasets["validation"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None # Return a DataLoaderX object for the validation dataset - return DataLoaderX(self.datasets["validation"], - batch_size=self.batch_size, - num_workers=self.num_workers, - worker_init_fn=init_fn, - shuffle=shuffle) + return DataLoaderX( + self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) def _test_dataloader(self, shuffle=False): # Check if the test dataset is iterable - is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset) # Set the worker initialization function if the dataset is iterable or use_worker_init_fn is True if is_iterable_dataset or self.use_worker_init_fn: init_fn = worker_init_fn @@ -292,21 +297,22 @@ class DataModuleFromConfig(pl.LightningDataModule): # do not shuffle dataloader for iterable dataset shuffle = shuffle and (not is_iterable_dataset) - return DataLoaderX(self.datasets["test"], - batch_size=self.batch_size, - num_workers=self.num_workers, - worker_init_fn=init_fn, - shuffle=shuffle) + return DataLoaderX( + self.datasets["test"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) def _predict_dataloader(self, shuffle=False): - if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + if isinstance(self.datasets["predict"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None - return DataLoaderX(self.datasets["predict"], - batch_size=self.batch_size, - num_workers=self.num_workers, - worker_init_fn=init_fn) + return DataLoaderX( + self.datasets["predict"], batch_size=self.batch_size, num_workers=self.num_workers, worker_init_fn=init_fn + ) class SetupCallback(Callback): @@ -338,10 +344,10 @@ class SetupCallback(Callback): os.makedirs(self.ckptdir, exist_ok=True) os.makedirs(self.cfgdir, exist_ok=True) - #Create trainstep checkpoint directory if necessary + # Create trainstep checkpoint directory if necessary if "callbacks" in self.lightning_config: - if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: - os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) + if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]: + os.makedirs(os.path.join(self.ckptdir, "trainstep_checkpoints"), exist_ok=True) print("Project config") print(OmegaConf.to_yaml(self.config)) OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) @@ -349,8 +355,10 @@ class SetupCallback(Callback): # Save project config and lightning config as YAML files print("Lightning config") print(OmegaConf.to_yaml(self.lightning_config)) - OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), - os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + OmegaConf.save( + OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)), + ) # Remove log directory if resuming training and directory already exists else: @@ -373,24 +381,25 @@ class SetupCallback(Callback): # PyTorch Lightning callback for logging images during training and validation of a deep learning model class ImageLogger(Callback): - - def __init__(self, - batch_frequency, # Frequency of batches on which to log images - max_images, # Maximum number of images to log - clamp=True, # Whether to clamp pixel values to [-1,1] - increase_log_steps=True, # Whether to increase frequency of log steps exponentially - rescale=True, # Whether to rescale pixel values to [0,1] - disabled=False, # Whether to disable logging - log_on_batch_idx=False, # Whether to log on batch index instead of global step - log_first_step=False, # Whether to log on the first step - log_images_kwargs=None): # Additional keyword arguments to pass to log_images method + def __init__( + self, + batch_frequency, # Frequency of batches on which to log images + max_images, # Maximum number of images to log + clamp=True, # Whether to clamp pixel values to [-1,1] + increase_log_steps=True, # Whether to increase frequency of log steps exponentially + rescale=True, # Whether to rescale pixel values to [0,1] + disabled=False, # Whether to disable logging + log_on_batch_idx=False, # Whether to log on batch index instead of global step + log_first_step=False, # Whether to log on the first step + log_images_kwargs=None, + ): # Additional keyword arguments to pass to log_images method super().__init__() self.rescale = rescale self.batch_freq = batch_frequency self.max_images = max_images self.logger_log_images = { # Dictionary of logger classes and their corresponding logging methods - pl.loggers.CSVLogger: self._testtube, + pl.loggers.CSVLogger: self._testtube, } # Create a list of exponentially increasing log steps, starting from 1 and ending at batch_frequency self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] @@ -402,37 +411,39 @@ class ImageLogger(Callback): self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} self.log_first_step = log_first_step - @rank_zero_only # Ensure that only the first process in distributed training executes this method - def _testtube(self, # The PyTorch Lightning module - pl_module, # A dictionary of images to log. - images, # - batch_idx, # The batch index. - split # The split (train/val) on which to log the images - ): - # Method for logging images using test-tube logger + @rank_zero_only # Ensure that only the first process in distributed training executes this method + def _testtube( + self, # The PyTorch Lightning module + pl_module, # A dictionary of images to log. + images, # + batch_idx, # The batch index. + split, # The split (train/val) on which to log the images + ): + # Method for logging images using test-tube logger for k in images: grid = torchvision.utils.make_grid(images[k]) - grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w tag = f"{split}/{k}" # Add image grid to logger's experiment pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step) @rank_zero_only - def log_local(self, - save_dir, - split, # The split (train/val) on which to log the images - images, # A dictionary of images to log - global_step, # The global step - current_epoch, # The current epoch. - batch_idx - ): - # Method for saving image grids to local file system + def log_local( + self, + save_dir, + split, # The split (train/val) on which to log the images + images, # A dictionary of images to log + global_step, # The global step + current_epoch, # The current epoch. + batch_idx, + ): + # Method for saving image grids to local file system root = os.path.join(save_dir, "images", split) for k in images: grid = torchvision.utils.make_grid(images[k], nrow=4) if self.rescale: - grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) grid = grid.numpy() grid = (grid * 255).astype(np.uint8) @@ -443,11 +454,15 @@ class ImageLogger(Callback): Image.fromarray(grid).save(path) def log_img(self, pl_module, batch, batch_idx, split="train"): - #Function for logging images to both the logger and local file system. + # Function for logging images to both the logger and local file system. check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step # check if it's time to log an image batch - if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 - hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0): + if ( + self.check_frequency(check_idx) + and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0 + and callable(pl_module.log_images) + and self.max_images > 0 + ): # Get logger type and check if training mode is on logger = type(pl_module.logger) @@ -466,11 +481,12 @@ class ImageLogger(Callback): if isinstance(images[k], torch.Tensor): images[k] = images[k].detach().cpu() if self.clamp: - images[k] = torch.clamp(images[k], -1., 1.) + images[k] = torch.clamp(images[k], -1.0, 1.0) # Log images locally to file system - self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, - batch_idx) + self.log_local( + pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx + ) # log the images using the logger logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) @@ -482,13 +498,13 @@ class ImageLogger(Callback): # The function checks if it's time to log an image batch def check_frequency(self, check_idx): - if ((check_idx % self.batch_freq) == 0 or - (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step): + if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( + check_idx > 0 or self.log_first_step + ): try: self.log_steps.pop(0) except IndexError as e: print(e) - pass return True return False @@ -503,7 +519,7 @@ class ImageLogger(Callback): if not self.disabled and pl_module.global_step > 0: self.log_img(pl_module, batch, batch_idx, split="val") # log gradients during calibration if necessary - if hasattr(pl_module, 'calibrate_grad_norm'): + if hasattr(pl_module, "calibrate_grad_norm"): if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: self.log_gradients(trainer, pl_module, batch_idx=batch_idx) @@ -514,7 +530,7 @@ class CUDACallback(Callback): def on_train_start(self, trainer, pl_module): rank_zero_info("Training is starting") - #the method is called at the end of each training epoch + # the method is called at the end of each training epoch def on_train_end(self, trainer, pl_module): rank_zero_info("Training is ending") @@ -595,9 +611,11 @@ if __name__ == "__main__": opt, unknown = parser.parse_known_args() # Verify the arguments are both specified if opt.name and opt.resume: - raise ValueError("-n/--name and -r/--resume cannot be specified both." - "If you want to resume training in a new log folder, " - "use -n/--name in combination with --resume_from_checkpoint") + raise ValueError( + "-n/--name and -r/--resume cannot be specified both." + "If you want to resume training in a new log folder, " + "use -n/--name in combination with --resume_from_checkpoint" + ) # Check if the "resume" option is specified, resume training from the checkpoint if it is true ckpt = None @@ -646,7 +664,7 @@ if __name__ == "__main__": # Sets the seed for the random number generator to ensure reproducibility seed_everything(opt.seed) - # Initialize and save configuration using teh OmegaConf library. + # Initialize and save configuration using teh OmegaConf library. try: # init and save configs configs = [OmegaConf.load(cfg) for cfg in opt.base] @@ -676,7 +694,7 @@ if __name__ == "__main__": config.model["params"].update({"use_fp16": False}) if ckpt is not None: - #If a checkpoint path is specified in the ckpt variable, the code updates the "ckpt" key in the "params" dictionary of the config.model configuration with the value of ckpt + # If a checkpoint path is specified in the ckpt variable, the code updates the "ckpt" key in the "params" dictionary of the config.model configuration with the value of ckpt config.model["params"].update({"ckpt": ckpt}) rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"])) @@ -688,17 +706,12 @@ if __name__ == "__main__": # Default logger configs to log training metrics during the training process. default_logger_cfgs = { "wandb": { - "name": nowname, - "save_dir": logdir, - "offline": opt.debug, - "id": nowname, - } - , - "tensorboard": { - "save_dir": logdir, - "name": "diff_tb", - "log_graph": True - } + "name": nowname, + "save_dir": logdir, + "offline": opt.debug, + "id": nowname, + }, + "tensorboard": {"save_dir": logdir, "name": "diff_tb", "log_graph": True}, } # Set up the logger for TensorBoard @@ -722,11 +735,11 @@ if __name__ == "__main__": # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to # specify which metric is used to determine best models default_modelckpt_cfg = { - "dirpath": ckptdir, - "filename": "{epoch:06}", - "verbose": True, - "save_last": True, - } + "dirpath": ckptdir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, + } if hasattr(model, "monitor"): default_modelckpt_cfg["monitor"] = model.monitor default_modelckpt_cfg["save_top_k"] = 3 @@ -736,48 +749,47 @@ if __name__ == "__main__": else: modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) - if version.parse(pl.__version__) < version.parse('1.4.0'): + if version.parse(pl.__version__) < version.parse("1.4.0"): trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg) - #Create an empty OmegaConf configuration object + # Create an empty OmegaConf configuration object callbacks_cfg = OmegaConf.create() - - #Instantiate items according to the configs + + # Instantiate items according to the configs trainer_kwargs.setdefault("callbacks", []) setup_callback_config = { - "resume": opt.resume, # resume training if applicable - "now": now, - "logdir": logdir, # directory to save the log file - "ckptdir": ckptdir, # directory to save the checkpoint file - "cfgdir": cfgdir, # directory to save the configuration file - "config": config, # configuration dictionary - "lightning_config": lightning_config, # LightningModule configuration - } + "resume": opt.resume, # resume training if applicable + "now": now, + "logdir": logdir, # directory to save the log file + "ckptdir": ckptdir, # directory to save the checkpoint file + "cfgdir": cfgdir, # directory to save the configuration file + "config": config, # configuration dictionary + "lightning_config": lightning_config, # LightningModule configuration + } trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config)) - + image_logger_config = { - - "batch_frequency": 750, # how frequently to log images - "max_images": 4, # maximum number of images to log - "clamp": True # whether to clamp pixel values to [0,1] - } + "batch_frequency": 750, # how frequently to log images + "max_images": 4, # maximum number of images to log + "clamp": True, # whether to clamp pixel values to [0,1] + } trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config)) - + learning_rate_logger_config = { - "logging_interval": "step", # logging frequency (either 'step' or 'epoch') - # "log_momentum": True # whether to log momentum (currently commented out) - } + "logging_interval": "step", # logging frequency (either 'step' or 'epoch') + # "log_momentum": True # whether to log momentum (currently commented out) + } trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config)) - - metrics_over_trainsteps_checkpoint_config= { - "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + + metrics_over_trainsteps_checkpoint_config = { + "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"), "filename": "{epoch:06}-{step:09}", "verbose": True, - 'save_top_k': -1, - 'every_n_train_steps': 10000, - 'save_weights_only': True - } + "save_top_k": -1, + "every_n_train_steps": 10000, + "save_weights_only": True, + } trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config)) trainer_kwargs["callbacks"].append(CUDACallback()) @@ -805,7 +817,7 @@ if __name__ == "__main__": ngpu = trainer_config["devices"] else: ngpu = 1 - if 'accumulate_grad_batches' in lightning_config.trainer: + if "accumulate_grad_batches" in lightning_config.trainer: accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches else: accumulate_grad_batches = 1 @@ -814,8 +826,10 @@ if __name__ == "__main__": if opt.scale_lr: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr rank_zero_info( - "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)" - .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( + model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr + ) + ) else: model.learning_rate = base_lr rank_zero_info("++++ NOT USING LR SCALING ++++") @@ -832,9 +846,11 @@ if __name__ == "__main__": def divein(*args, **kwargs): if trainer.global_rank == 0: import pudb + pudb.set_trace() import signal + # Assign melk to SIGUSR1 signal and divein to SIGUSR2 signal signal.signal(signal.SIGUSR1, melk) signal.signal(signal.SIGUSR2, divein) diff --git a/examples/images/diffusion/requirements.txt b/examples/images/diffusion/requirements.txt index 59d027fcf60f74e6023531ec4a68d72602c91681..54c47cb5974c33848a6c615fab960db85d8114b2 100644 --- a/examples/images/diffusion/requirements.txt +++ b/examples/images/diffusion/requirements.txt @@ -7,12 +7,12 @@ imageio-ffmpeg==0.4.2 torchmetrics==0.7 omegaconf==2.1.1 test-tube>=0.7.5 -streamlit>=0.73.1 +streamlit>=1.11.1 einops==0.3.0 transformers webdataset==0.2.5 open-clip-torch==2.7.0 -gradio==3.11 +gradio==3.34.0 lightning==1.9.0 datasets colossalai diff --git a/examples/images/diffusion/scripts/download_first_stages.sh b/examples/images/diffusion/scripts/download_first_stages.sh index a8d79e99ccdff0a8d8762f23f3c0642401f32f6c..50dab5de5b90b1143bf928773d39e339851a1a54 100755 --- a/examples/images/diffusion/scripts/download_first_stages.sh +++ b/examples/images/diffusion/scripts/download_first_stages.sh @@ -38,4 +38,4 @@ unzip -o model.zip cd ../vq-f16 unzip -o model.zip -cd ../.. \ No newline at end of file +cd ../.. diff --git a/examples/images/diffusion/scripts/img2img.py b/examples/images/diffusion/scripts/img2img.py index 877538d4733dd06ab68d14b5205491fd516cfae2..4c386113dcc3deffa2ac4676256bdd7d5211b351 100644 --- a/examples/images/diffusion/scripts/img2img.py +++ b/examples/images/diffusion/scripts/img2img.py @@ -1,28 +1,30 @@ """make variations of input image""" -import argparse, os +import argparse +import os +from contextlib import nullcontext +from itertools import islice + +import numpy as np import PIL import torch -import numpy as np +from einops import rearrange, repeat from omegaconf import OmegaConf from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange, repeat -from torchvision.utils import make_grid from torch import autocast -from contextlib import nullcontext +from torchvision.utils import make_grid +from tqdm import tqdm, trange + try: from lightning.pytorch import seed_everything except: from pytorch_lightning import seed_everything -from imwatermark import WatermarkEncoder - -from scripts.txt2img import put_watermark -from ldm.util import instantiate_from_config +from imwatermark import WatermarkEncoder from ldm.models.diffusion.ddim import DDIMSampler -from utils import replace_module, getModelSize +from ldm.util import instantiate_from_config +from scripts.txt2img import put_watermark +from utils import replace_module def chunk(it, size): @@ -58,7 +60,7 @@ def load_img(path): image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) - return 2. * image - 1. + return 2.0 * image - 1.0 def main(): @@ -69,22 +71,13 @@ def main(): type=str, nargs="?", default="a painting of a virus monster playing guitar", - help="the prompt to render" + help="the prompt to render", ) - parser.add_argument( - "--init-img", - type=str, - nargs="?", - help="path to the input image" - ) + parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image") parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/img2img-samples" + "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples" ) parser.add_argument( @@ -96,7 +89,7 @@ def main(): parser.add_argument( "--fixed_code", - action='store_true', + action="store_true", help="if enabled, uses the same starting code across all samples ", ) @@ -177,11 +170,7 @@ def main(): help="the seed (for reproducible sampling)", ) parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" + "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast" ) parser.add_argument( "--use_int8", @@ -204,7 +193,7 @@ def main(): model = replace_module(model) # # to compute the model size # getModelSize(model) - + sampler = DDIMSampler(model) os.makedirs(opt.outdir, exist_ok=True) @@ -213,7 +202,7 @@ def main(): print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") wm = "SDV2" wm_encoder = WatermarkEncoder() - wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + wm_encoder.set_watermark("bytes", wm.encode("utf-8")) batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size @@ -235,12 +224,12 @@ def main(): assert os.path.isfile(opt.init_img) init_image = load_img(opt.init_img).to(device) - init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) - assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' + assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]" t_enc = int(opt.strength * opt.ddim_steps) print(f"target t_enc is {t_enc} steps") @@ -261,14 +250,19 @@ def main(): # encode (scaled latent) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device)) # decode it - samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, ) + samples = sampler.decode( + z_enc, + c, + t_enc, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + ) x_samples = model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") img = Image.fromarray(x_sample.astype(np.uint8)) img = put_watermark(img, wm_encoder) img.save(os.path.join(sample_path, f"{base_count:05}.png")) @@ -277,14 +271,14 @@ def main(): # additionally, save as grid grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = rearrange(grid, "n b c h w -> (n b) c h w") grid = make_grid(grid, nrow=n_rows) # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() grid = Image.fromarray(grid.astype(np.uint8)) grid = put_watermark(grid, wm_encoder) - grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid.save(os.path.join(outpath, f"grid-{grid_count:04}.png")) grid_count += 1 print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/examples/images/diffusion/scripts/inpaint.py b/examples/images/diffusion/scripts/inpaint.py index d6e6387a9a3b0afa73fae8af25f43a8ba856240e..afffcf1685e6bcb22d99024b8f89b4a222ccb8d0 100644 --- a/examples/images/diffusion/scripts/inpaint.py +++ b/examples/images/diffusion/scripts/inpaint.py @@ -1,32 +1,35 @@ -import argparse, os, sys, glob -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm +import argparse +import glob +import os + import numpy as np import torch -from main import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler +from main import instantiate_from_config +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm def make_batch(image, mask, device): image = np.array(Image.open(image).convert("RGB")) - image = image.astype(np.float32)/255.0 - image = image[None].transpose(0,3,1,2) + image = image.astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) mask = np.array(Image.open(mask).convert("L")) - mask = mask.astype(np.float32)/255.0 - mask = mask[None,None] + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) - masked_image = (1-mask)*image + masked_image = (1 - mask) * image batch = {"image": image, "mask": mask, "masked_image": masked_image} for k in batch: batch[k] = batch[k].to(device=device) - batch[k] = batch[k]*2.0-1.0 + batch[k] = batch[k] * 2.0 - 1.0 return batch @@ -58,8 +61,7 @@ if __name__ == "__main__": config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") model = instantiate_from_config(config.model) - model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], - strict=False) + model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) @@ -74,25 +76,19 @@ if __name__ == "__main__": # encode masked image and concat downsampled mask c = model.cond_stage_model.encode(batch["masked_image"]) - cc = torch.nn.functional.interpolate(batch["mask"], - size=c.shape[-2:]) + cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:]) c = torch.cat((c, cc), dim=1) - shape = (c.shape[1]-1,)+c.shape[2:] - samples_ddim, _ = sampler.sample(S=opt.steps, - conditioning=c, - batch_size=c.shape[0], - shape=shape, - verbose=False) + shape = (c.shape[1] - 1,) + c.shape[2:] + samples_ddim, _ = sampler.sample( + S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False + ) x_samples_ddim = model.decode_first_stage(samples_ddim) - image = torch.clamp((batch["image"]+1.0)/2.0, - min=0.0, max=1.0) - mask = torch.clamp((batch["mask"]+1.0)/2.0, - min=0.0, max=1.0) - predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, - min=0.0, max=1.0) + image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0) + mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - inpainted = (1-mask)*image+mask*predicted_image - inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + inpainted = (1 - mask) * image + mask * predicted_image + inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 Image.fromarray(inpainted.astype(np.uint8)).save(outpath) diff --git a/examples/images/diffusion/scripts/knn2img.py b/examples/images/diffusion/scripts/knn2img.py index e6eaaecab53eac9c97051c9a5cb457a240679725..763811665bbcbb708ba307e1b6fb6030461f718d 100644 --- a/examples/images/diffusion/scripts/knn2img.py +++ b/examples/images/diffusion/scripts/knn2img.py @@ -1,22 +1,22 @@ -import argparse, os, sys, glob -import clip -import torch -import torch.nn as nn -import numpy as np -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange, repeat -from torchvision.utils import make_grid -import scann +import argparse +import glob +import os import time +from itertools import islice from multiprocessing import cpu_count -from ldm.util import instantiate_from_config, parallel_data_prefetch +import numpy as np +import scann +import torch +from einops import rearrange from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder +from ldm.util import instantiate_from_config, parallel_data_prefetch +from omegaconf import OmegaConf +from PIL import Image +from torchvision.utils import make_grid +from tqdm import tqdm, trange DATABASES = [ "openimages", @@ -59,29 +59,24 @@ def load_model_from_config(config, ckpt, verbose=False): class Searcher(object): - def __init__(self, database, retriever_version='ViT-L/14'): + def __init__(self, database, retriever_version="ViT-L/14"): assert database in DATABASES # self.database = self.load_database(database) self.database_name = database - self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' - self.database_path = f'data/rdm/retrieval_databases/{self.database_name}' + self.searcher_savedir = f"data/rdm/searchers/{self.database_name}" + self.database_path = f"data/rdm/retrieval_databases/{self.database_name}" self.retriever = self.load_retriever(version=retriever_version) - self.database = {'embedding': [], - 'img_id': [], - 'patch_coords': []} + self.database = {"embedding": [], "img_id": [], "patch_coords": []} self.load_database() self.load_searcher() - def train_searcher(self, k, - metric='dot_product', - searcher_savedir=None): - - print('Start training searcher') - searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / - np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], - k, metric) + def train_searcher(self, k, metric="dot_product", searcher_savedir=None): + print("Start training searcher") + searcher = scann.scann_ops_pybind.builder( + self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis], k, metric + ) self.searcher = searcher.score_brute_force().build() - print('Finish training searcher') + print("Finish training searcher") if searcher_savedir is not None: print(f'Save trained searcher under "{searcher_savedir}"') @@ -91,36 +86,40 @@ class Searcher(object): def load_single_file(self, saved_embeddings): compressed = np.load(saved_embeddings) self.database = {key: compressed[key] for key in compressed.files} - print('Finished loading of clip embeddings.') + print("Finished loading of clip embeddings.") def load_multi_files(self, data_archive): out_data = {key: [] for key in self.database} - for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."): for key in d.files: out_data[key].append(d[key]) return out_data def load_database(self): - print(f'Load saved patch embedding from "{self.database_path}"') - file_content = glob.glob(os.path.join(self.database_path, '*.npz')) + file_content = glob.glob(os.path.join(self.database_path, "*.npz")) if len(file_content) == 1: self.load_single_file(file_content[0]) elif len(file_content) > 1: data = [np.load(f) for f in file_content] - prefetched_data = parallel_data_prefetch(self.load_multi_files, data, - n_proc=min(len(data), cpu_count()), target_data_type='dict') + prefetched_data = parallel_data_prefetch( + self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict" + ) - self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in - self.database} + self.database = { + key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database + } else: raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') - def load_retriever(self, version='ViT-L/14', ): + def load_retriever( + self, + version="ViT-L/14", + ): model = FrozenClipImageEmbedder(model=version) if torch.cuda.is_available(): model.cuda() @@ -128,14 +127,14 @@ class Searcher(object): return model def load_searcher(self): - print(f'load searcher for database {self.database_name} from {self.searcher_savedir}') + print(f"load searcher for database {self.database_name} from {self.searcher_savedir}") self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir) - print('Finished loading searcher.') + print("Finished loading searcher.") def search(self, x, k): - if self.searcher is None and self.database['embedding'].shape[0] < 2e4: - self.train_searcher(k) # quickly fit searcher on the fly for small databases - assert self.searcher is not None, 'Cannot search with uninitialized searcher' + if self.searcher is None and self.database["embedding"].shape[0] < 2e4: + self.train_searcher(k) # quickly fit searcher on the fly for small databases + assert self.searcher is not None, "Cannot search with uninitialized searcher" if isinstance(x, torch.Tensor): x = x.detach().cpu().numpy() if len(x.shape) == 3: @@ -146,17 +145,19 @@ class Searcher(object): nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k) end = time.time() - out_embeddings = self.database['embedding'][nns] - out_img_ids = self.database['img_id'][nns] - out_pc = self.database['patch_coords'][nns] + out_embeddings = self.database["embedding"][nns] + out_img_ids = self.database["img_id"][nns] + out_pc = self.database["patch_coords"][nns] - out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], - 'img_ids': out_img_ids, - 'patch_coords': out_pc, - 'queries': x, - 'exec_time': end - start, - 'nns': nns, - 'q_embeddings': query_embeddings} + out = { + "nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + "img_ids": out_img_ids, + "patch_coords": out_pc, + "queries": x, + "exec_time": end - start, + "nns": nns, + "q_embeddings": query_embeddings, + } return out @@ -173,20 +174,16 @@ if __name__ == "__main__": type=str, nargs="?", default="a painting of a virus monster playing guitar", - help="the prompt to render" + help="the prompt to render", ) parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/txt2img-samples" + "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples" ) parser.add_argument( "--skip_grid", - action='store_true', + action="store_true", help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", ) @@ -206,7 +203,7 @@ if __name__ == "__main__": parser.add_argument( "--plms", - action='store_true', + action="store_true", help="use plms sampling", ) @@ -287,14 +284,14 @@ if __name__ == "__main__": parser.add_argument( "--database", type=str, - default='artbench-surrealism', + default="artbench-surrealism", choices=DATABASES, help="The database used for the search, only applied when --use_neighbors=True", ) parser.add_argument( "--use_neighbors", default=False, - action='store_true', + action="store_true", help="Include neighbors in addition to text prompt for conditioning", ) parser.add_argument( @@ -358,41 +355,43 @@ if __name__ == "__main__": uc = None if searcher is not None: nn_dict = searcher(c, opt.knn) - c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + c = torch.cat([c, torch.from_numpy(nn_dict["nn_embeddings"]).cuda()], dim=1) if opt.scale != 1.0: uc = torch.zeros_like(c) if isinstance(prompts, tuple): prompts = list(prompts) shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=c.shape[0], - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - ) + samples_ddim, _ = sampler.sample( + S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(x_sample.astype(np.uint8)).save( - os.path.join(sample_path, f"{base_count:05}.png")) + os.path.join(sample_path, f"{base_count:05}.png") + ) base_count += 1 all_samples.append(x_samples_ddim) if not opt.skip_grid: # additionally, save as grid grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = rearrange(grid, "n b c h w -> (n b) c h w") grid = make_grid(grid, nrow=n_rows) # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png")) grid_count += 1 print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/examples/images/diffusion/scripts/sample_diffusion.py b/examples/images/diffusion/scripts/sample_diffusion.py index 876fe3c3642fcc8c7209e4f763c0134166615f78..740aae2435d26ad1ffc671609e3814601290f4cf 100644 --- a/examples/images/diffusion/scripts/sample_diffusion.py +++ b/examples/images/diffusion/scripts/sample_diffusion.py @@ -1,21 +1,26 @@ -import argparse, os, sys, glob, datetime, yaml -import torch +import argparse +import datetime +import glob +import os +import sys import time -import numpy as np -from tqdm import trange +import numpy as np +import torch +import yaml +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config from omegaconf import OmegaConf from PIL import Image +from tqdm import trange -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.util import instantiate_from_config +rescale = lambda x: (x + 1.0) / 2.0 -rescale = lambda x: (x + 1.) / 2. def custom_to_pil(x): x = x.detach().cpu() - x = torch.clamp(x, -1., 1.) - x = (x + 1.) / 2. + x = torch.clamp(x, -1.0, 1.0) + x = (x + 1.0) / 2.0 x = x.permute(1, 2, 0).numpy() x = (255 * x).astype(np.uint8) x = Image.fromarray(x) @@ -51,49 +56,51 @@ def logs2pil(logs, keys=["sample"]): @torch.no_grad() -def convsample(model, shape, return_intermediates=True, - verbose=True, - make_prog_row=False): - - +def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False): if not make_prog_row: - return model.p_sample_loop(None, shape, - return_intermediates=return_intermediates, verbose=verbose) + return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose) else: - return model.progressive_denoising( - None, shape, verbose=True - ) + return model.progressive_denoising(None, shape, verbose=True) @torch.no_grad() -def convsample_ddim(model, steps, shape, eta=1.0 - ): +def convsample_ddim(model, steps, shape, eta=1.0): ddim = DDIMSampler(model) bs = shape[0] shape = shape[1:] - samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) + samples, intermediates = ddim.sample( + steps, + batch_size=bs, + shape=shape, + eta=eta, + verbose=False, + ) return samples, intermediates @torch.no_grad() -def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): - - +def make_convolutional_sample( + model, + batch_size, + vanilla=False, + custom_steps=None, + eta=1.0, +): log = dict() - shape = [batch_size, - model.model.diffusion_model.in_channels, - model.model.diffusion_model.image_size, - model.model.diffusion_model.image_size] + shape = [ + batch_size, + model.model.diffusion_model.in_channels, + model.model.diffusion_model.image_size, + model.model.diffusion_model.image_size, + ] with model.ema_scope("Plotting"): t0 = time.time() if vanilla: - sample, progrow = convsample(model, shape, - make_prog_row=True) + sample, progrow = convsample(model, shape, make_prog_row=True) else: - sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, - eta=eta) + sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta) t1 = time.time() @@ -101,32 +108,32 @@ def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=Non log["sample"] = x_sample log["time"] = t1 - t0 - log['throughput'] = sample.shape[0] / (t1 - t0) + log["throughput"] = sample.shape[0] / (t1 - t0) print(f'Throughput for this batch: {log["throughput"]}') return log + def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): if vanilla: - print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') + print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.") else: - print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') - + print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}") tstart = time.time() - n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 + n_saved = len(glob.glob(os.path.join(logdir, "*.png"))) - 1 # path = logdir if model.cond_stage_model is None: all_images = [] print(f"Running unconditional sampling for {n_samples} samples") for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): - logs = make_convolutional_sample(model, batch_size=batch_size, - vanilla=vanilla, custom_steps=custom_steps, - eta=eta) + logs = make_convolutional_sample( + model, batch_size=batch_size, vanilla=vanilla, custom_steps=custom_steps, eta=eta + ) n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") all_images.extend([custom_to_np(logs["sample"])]) if n_saved >= n_samples: - print(f'Finish after generating {n_saved} samples') + print(f"Finish after generating {n_saved} samples") break all_img = np.concatenate(all_images, axis=0) all_img = all_img[:n_samples] @@ -135,7 +142,7 @@ def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None np.savez(nppath, all_img) else: - raise NotImplementedError('Currently only sampling for unconditional models supported.') + raise NotImplementedError("Currently only sampling for unconditional models supported.") print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") @@ -168,58 +175,33 @@ def get_parser(): nargs="?", help="load from logdir or checkpoint in logdir", ) - parser.add_argument( - "-n", - "--n_samples", - type=int, - nargs="?", - help="number of samples to draw", - default=50000 - ) + parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000) parser.add_argument( "-e", "--eta", type=float, nargs="?", help="eta for ddim sampling (0.0 yields deterministic sampling)", - default=1.0 + default=1.0, ) parser.add_argument( "-v", "--vanilla_sample", default=False, - action='store_true', + action="store_true", help="vanilla sampling (default option is DDIM sampling)?", ) + parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none") parser.add_argument( - "-l", - "--logdir", - type=str, - nargs="?", - help="extra logdir", - default="none" - ) - parser.add_argument( - "-c", - "--custom_steps", - type=int, - nargs="?", - help="number of steps for ddim and fastdpm sampling", - default=50 - ) - parser.add_argument( - "--batch_size", - type=int, - nargs="?", - help="the bs", - default=10 + "-c", "--custom_steps", type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50 ) + parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10) return parser def load_model_from_config(config, sd): model = instantiate_from_config(config) - model.load_state_dict(sd,strict=False) + model.load_state_dict(sd, strict=False) model.cuda() model.eval() return model @@ -233,8 +215,7 @@ def load_model(config, ckpt, gpu, eval_mode): else: pl_sd = {"state_dict": None} global_step = None - model = load_model_from_config(config.model, - pl_sd["state_dict"]) + model = load_model_from_config(config.model, pl_sd["state_dict"]) return model, global_step @@ -253,9 +234,9 @@ if __name__ == "__main__": if os.path.isfile(opt.resume): # paths = opt.resume.split("/") try: - logdir = '/'.join(opt.resume.split('/')[:-1]) + logdir = "/".join(opt.resume.split("/")[:-1]) # idx = len(paths)-paths[::-1].index("logs")+1 - print(f'Logdir is {logdir}') + print(f"Logdir is {logdir}") except ValueError: paths = opt.resume.split("/") idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt @@ -278,7 +259,8 @@ if __name__ == "__main__": if opt.logdir != "none": locallog = logdir.split(os.sep)[-1] - if locallog == "": locallog = logdir.split(os.sep)[-2] + if locallog == "": + locallog = logdir.split(os.sep)[-2] print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") logdir = os.path.join(opt.logdir, locallog) @@ -301,13 +283,19 @@ if __name__ == "__main__": sampling_file = os.path.join(logdir, "sampling_config.yaml") sampling_conf = vars(opt) - with open(sampling_file, 'w') as f: + with open(sampling_file, "w") as f: yaml.dump(sampling_conf, f, default_flow_style=False) print(sampling_conf) - - run(model, imglogdir, eta=opt.eta, - vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, - batch_size=opt.batch_size, nplog=numpylogdir) + run( + model, + imglogdir, + eta=opt.eta, + vanilla=opt.vanilla_sample, + n_samples=opt.n_samples, + custom_steps=opt.custom_steps, + batch_size=opt.batch_size, + nplog=numpylogdir, + ) print("done.") diff --git a/examples/images/diffusion/scripts/tests/test_checkpoint.py b/examples/images/diffusion/scripts/tests/test_checkpoint.py index 13622c4989fd48d44c49c875de8417b0ee7710cf..c0af17bdecaa262ef4777c5d0b09f25c40832615 100644 --- a/examples/images/diffusion/scripts/tests/test_checkpoint.py +++ b/examples/images/diffusion/scripts/tests/test_checkpoint.py @@ -1,28 +1,18 @@ -import os -import sys -from copy import deepcopy - +import torch import yaml -from datetime import datetime - from diffusers import StableDiffusionPipeline -import torch - -from main import get_parser from ldm.modules.diffusionmodules.openaimodel import UNetModel if __name__ == "__main__": with torch.no_grad(): yaml_path = "../../train_colossalai.yaml" - with open(yaml_path, 'r', encoding='utf-8') as f: + with open(yaml_path, "r", encoding="utf-8") as f: config = f.read() base_config = yaml.load(config, Loader=yaml.FullLoader) - unet_config = base_config['model']['params']['unet_config'] + unet_config = base_config["model"]["params"]["unet_config"] diffusion_model = UNetModel(**unet_config).to("cuda:0") - pipe = StableDiffusionPipeline.from_pretrained( - "/data/scratch/diffuser/stable-diffusion-v1-4" - ).to("cuda:0") + pipe = StableDiffusionPipeline.from_pretrained("/data/scratch/diffuser/stable-diffusion-v1-4").to("cuda:0") dif_model_2 = pipe.unet random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0") @@ -35,4 +25,4 @@ if __name__ == "__main__": out_1 = diffusion_model(random_input_, time_stamp, context_) out_2 = dif_model_2(random_input_2, time_stamp2, context_2) print(out_1.shape) - print(out_2['sample'].shape) \ No newline at end of file + print(out_2["sample"].shape) diff --git a/examples/images/diffusion/scripts/tests/test_watermark.py b/examples/images/diffusion/scripts/tests/test_watermark.py index f93f8a6e70763c0e284157bc8225827520b2f5ef..9bfc9fc7d9cbfe29a356bf24f29791568845c73c 100644 --- a/examples/images/diffusion/scripts/tests/test_watermark.py +++ b/examples/images/diffusion/scripts/tests/test_watermark.py @@ -5,14 +5,14 @@ from imwatermark import WatermarkDecoder def testit(img_path): bgr = cv2.imread(img_path) - decoder = WatermarkDecoder('bytes', 136) - watermark = decoder.decode(bgr, 'dwtDct') + decoder = WatermarkDecoder("bytes", 136) + watermark = decoder.decode(bgr, "dwtDct") try: - dec = watermark.decode('utf-8') + dec = watermark.decode("utf-8") except: dec = "null" print(dec) if __name__ == "__main__": - fire.Fire(testit) \ No newline at end of file + fire.Fire(testit) diff --git a/examples/images/diffusion/scripts/train_searcher.py b/examples/images/diffusion/scripts/train_searcher.py index 1e7904889c0145f9fb740fd4ae8e45c08728b255..1df0baa7e5cf89f46ba40b36f5937581c6a0df88 100644 --- a/examples/images/diffusion/scripts/train_searcher.py +++ b/examples/images/diffusion/scripts/train_searcher.py @@ -1,33 +1,39 @@ -import os, sys -import numpy as np -import scann import argparse import glob +import os +import sys from multiprocessing import cpu_count -from tqdm import tqdm +import numpy as np +import scann from ldm.util import parallel_data_prefetch +from tqdm import tqdm def search_bruteforce(searcher): return searcher.score_brute_force().build() -def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, - partioning_trainsize, num_leaves, num_leaves_to_search): - return searcher.tree(num_leaves=num_leaves, - num_leaves_to_search=num_leaves_to_search, - training_sample_size=partioning_trainsize). \ - score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() +def search_partioned_ah( + searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search +): + return ( + searcher.tree( + num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=partioning_trainsize + ) + .score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold) + .reorder(reorder_k) + .build() + ) def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): - return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( - reorder_k).build() - -def load_datapool(dpath): + return ( + searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() + ) +def load_datapool(dpath): def load_single_file(saved_embeddings): compressed = np.load(saved_embeddings) database = {key: compressed[key] for key in compressed.files} @@ -35,23 +41,26 @@ def load_datapool(dpath): def load_multi_files(data_archive): database = {key: [] for key in data_archive[0].files} - for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."): for key in d.files: database[key].append(d[key]) return database print(f'Load saved patch embedding from "{dpath}"') - file_content = glob.glob(os.path.join(dpath, '*.npz')) + file_content = glob.glob(os.path.join(dpath, "*.npz")) if len(file_content) == 1: data_pool = load_single_file(file_content[0]) elif len(file_content) > 1: data = [np.load(f) for f in file_content] - prefetched_data = parallel_data_prefetch(load_multi_files, data, - n_proc=min(len(data), cpu_count()), target_data_type='dict') + prefetched_data = parallel_data_prefetch( + load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict" + ) - data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + data_pool = { + key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys() + } else: raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') @@ -59,16 +68,17 @@ def load_datapool(dpath): return data_pool -def train_searcher(opt, - metric='dot_product', - partioning_trainsize=None, - reorder_k=None, - # todo tune - aiq_thld=0.2, - dims_per_block=2, - num_leaves=None, - num_leaves_to_search=None,): - +def train_searcher( + opt, + metric="dot_product", + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None, +): data_pool = load_datapool(opt.database) k = opt.knn @@ -77,71 +87,83 @@ def train_searcher(opt, # normalize # embeddings = - searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) - pool_size = data_pool['embedding'].shape[0] - - print(*(['#'] * 100)) - print('Initializing scaNN searcher with the following values:') - print(f'k: {k}') - print(f'metric: {metric}') - print(f'reorder_k: {reorder_k}') - print(f'anisotropic_quantization_threshold: {aiq_thld}') - print(f'dims_per_block: {dims_per_block}') - print(*(['#'] * 100)) - print('Start training searcher....') - print(f'N samples in pool is {pool_size}') + searcher = scann.scann_ops_pybind.builder( + data_pool["embedding"] / np.linalg.norm(data_pool["embedding"], axis=1)[:, np.newaxis], k, metric + ) + pool_size = data_pool["embedding"].shape[0] + + print(*(["#"] * 100)) + print("Initializing scaNN searcher with the following values:") + print(f"k: {k}") + print(f"metric: {metric}") + print(f"reorder_k: {reorder_k}") + print(f"anisotropic_quantization_threshold: {aiq_thld}") + print(f"dims_per_block: {dims_per_block}") + print(*(["#"] * 100)) + print("Start training searcher....") + print(f"N samples in pool is {pool_size}") # this reflects the recommended design choices proposed at # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md if pool_size < 2e4: - print('Using brute force search.') + print("Using brute force search.") searcher = search_bruteforce(searcher) elif 2e4 <= pool_size and pool_size < 1e5: - print('Using asymmetric hashing search and reordering.') + print("Using asymmetric hashing search and reordering.") searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) else: - print('Using using partioning, asymmetric hashing search and reordering.') + print("Using using partioning, asymmetric hashing search and reordering.") if not partioning_trainsize: - partioning_trainsize = data_pool['embedding'].shape[0] // 10 + partioning_trainsize = data_pool["embedding"].shape[0] // 10 if not num_leaves: num_leaves = int(np.sqrt(pool_size)) if not num_leaves_to_search: num_leaves_to_search = max(num_leaves // 20, 1) - print('Partitioning params:') - print(f'num_leaves: {num_leaves}') - print(f'num_leaves_to_search: {num_leaves_to_search}') + print("Partitioning params:") + print(f"num_leaves: {num_leaves}") + print(f"num_leaves_to_search: {num_leaves_to_search}") # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) - searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, - partioning_trainsize, num_leaves, num_leaves_to_search) + searcher = search_partioned_ah( + searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search + ) - print('Finish training searcher') + print("Finish training searcher") searcher_savedir = opt.target_path os.makedirs(searcher_savedir, exist_ok=True) searcher.serialize(searcher_savedir) print(f'Saved trained searcher under "{searcher_savedir}"') -if __name__ == '__main__': + +if __name__ == "__main__": sys.path.append(os.getcwd()) parser = argparse.ArgumentParser() - parser.add_argument('--database', - '-d', - default='data/rdm/retrieval_databases/openimages', - type=str, - help='path to folder containing the clip feature of the database') - parser.add_argument('--target_path', - '-t', - default='data/rdm/searchers/openimages', - type=str, - help='path to the target folder where the searcher shall be stored.') - parser.add_argument('--knn', - '-k', - default=20, - type=int, - help='number of nearest neighbors, for which the searcher shall be optimized') - - opt, _ = parser.parse_known_args() - - train_searcher(opt,) \ No newline at end of file + parser.add_argument( + "--database", + "-d", + default="data/rdm/retrieval_databases/openimages", + type=str, + help="path to folder containing the clip feature of the database", + ) + parser.add_argument( + "--target_path", + "-t", + default="data/rdm/searchers/openimages", + type=str, + help="path to the target folder where the searcher shall be stored.", + ) + parser.add_argument( + "--knn", + "-k", + default=20, + type=int, + help="number of nearest neighbors, for which the searcher shall be optimized", + ) + + opt, _ = parser.parse_known_args() + + train_searcher( + opt, + ) diff --git a/examples/images/diffusion/scripts/txt2img.py b/examples/images/diffusion/scripts/txt2img.py index 364ebac6c67b62532b5bf0c26187dc366bec81af..feb17b9f77ae1220a08eb774ea86025b5392579b 100644 --- a/examples/images/diffusion/scripts/txt2img.py +++ b/examples/images/diffusion/scripts/txt2img.py @@ -1,29 +1,34 @@ -import argparse, os +import argparse +import os +from itertools import islice + import cv2 -import torch import numpy as np +import torch +from einops import rearrange from omegaconf import OmegaConf from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange from torchvision.utils import make_grid +from tqdm import tqdm, trange + try: from lightning.pytorch import seed_everything except: from pytorch_lightning import seed_everything -from torch import autocast + from contextlib import nullcontext -from imwatermark import WatermarkEncoder -from ldm.util import instantiate_from_config +from imwatermark import WatermarkEncoder from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from utils import replace_module, getModelSize +from ldm.models.diffusion.plms import PLMSSampler +from ldm.util import instantiate_from_config +from torch import autocast +from utils import replace_module torch.set_grad_enabled(False) + def chunk(it, size): it = iter(it) return iter(lambda: tuple(islice(it, size)), ()) @@ -55,14 +60,10 @@ def parse_args(): type=str, nargs="?", default="a professional photograph of an astronaut riding a triceratops", - help="the prompt to render" + help="the prompt to render", ) parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/txt2img-samples" + "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples" ) parser.add_argument( "--steps", @@ -72,17 +73,17 @@ def parse_args(): ) parser.add_argument( "--plms", - action='store_true', + action="store_true", help="use plms sampling", ) parser.add_argument( "--dpm", - action='store_true', + action="store_true", help="use DPM (2) sampler", ) parser.add_argument( "--fixed_code", - action='store_true', + action="store_true", help="if enabled, uses the same starting code across all samples ", ) parser.add_argument( @@ -162,11 +163,7 @@ def parse_args(): help="the seed (for reproducible sampling)", ) parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" + "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast" ) parser.add_argument( "--repeat", @@ -187,7 +184,7 @@ def parse_args(): def put_watermark(img, wm_encoder=None): if wm_encoder is not None: img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) - img = wm_encoder.encode(img, 'dwtDct') + img = wm_encoder.encode(img, "dwtDct") img = Image.fromarray(img[:, :, ::-1]) return img @@ -197,17 +194,17 @@ def main(opt): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") - + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) - + # quantize model if opt.use_int8: model = replace_module(model) # # to compute the model size # getModelSize(model) - + if opt.plms: sampler = PLMSSampler(model) elif opt.dpm: @@ -221,7 +218,7 @@ def main(opt): print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") wm = "SDV2" wm_encoder = WatermarkEncoder() - wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + wm_encoder.set_watermark("bytes", wm.encode("utf-8")) batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size @@ -248,56 +245,55 @@ def main(opt): start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) precision_scope = autocast if opt.precision == "autocast" else nullcontext - with torch.no_grad(), \ - precision_scope("cuda"), \ - model.ema_scope(): - all_samples = list() - for n in trange(opt.n_iter, desc="Sampling"): - for prompts in tqdm(data, desc="data"): - uc = None - if opt.scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - samples, _ = sampler.sample(S=opt.steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - x_T=start_code) - - x_samples = model.decode_first_stage(samples) - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - - for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - img = Image.fromarray(x_sample.astype(np.uint8)) - img = put_watermark(img, wm_encoder) - img.save(os.path.join(sample_path, f"{base_count:05}.png")) - base_count += 1 - sample_count += 1 - - all_samples.append(x_samples) - - # additionally, save as grid - grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=n_rows) - - # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - grid = Image.fromarray(grid.astype(np.uint8)) - grid = put_watermark(grid, wm_encoder) - grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) - grid_count += 1 - - print(f"Your samples are ready and waiting for you here: \n{outpath} \n" - f" \nEnjoy.") + with torch.no_grad(), precision_scope("cuda"), model.ema_scope(): + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples, _ = sampler.sample( + S=opt.steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code, + ) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + sample_count += 1 + + all_samples.append(x_samples) + + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, "n b c h w -> (n b) c h w") + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f"grid-{grid_count:04}.png")) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.") if __name__ == "__main__": diff --git a/examples/images/diffusion/scripts/utils.py b/examples/images/diffusion/scripts/utils.py index c954b22ca19045c985e2f0bc8da32fb79474996d..92ed0b4dfd0afa23b913c655ecdbc475cb084d96 100644 --- a/examples/images/diffusion/scripts/utils.py +++ b/examples/images/diffusion/scripts/utils.py @@ -1,6 +1,7 @@ import bitsandbytes as bnb -import torch.nn as nn import torch +import torch.nn as nn + class Linear8bit(nn.Linear): def __init__( @@ -12,11 +13,9 @@ class Linear8bit(nn.Linear): memory_efficient_backward=False, threshold=6.0, weight_data=None, - bias_data=None + bias_data=None, ): - super(Linear8bit, self).__init__( - input_features, output_features, bias - ) + super(Linear8bit, self).__init__(input_features, output_features, bias) self.state = bnb.MatmulLtState() self.bias = bias_data self.state.threshold = threshold @@ -24,13 +23,12 @@ class Linear8bit(nn.Linear): self.state.memory_efficient_backward = memory_efficient_backward if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - + self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False)) self.weight = weight_data self.quant() - - def quant(self): + def quant(self): weight = self.weight.data.contiguous().half().cuda() CB, _, SCB, _, _ = bnb.functional.double_quant(weight) delattr(self, "weight") @@ -41,32 +39,34 @@ class Linear8bit(nn.Linear): def forward(self, x): self.state.is_training = self.training - + if self.bias is not None and self.bias.dtype != torch.float16: self.bias.data = self.bias.data.half() - + self.state.CB = self.weight.data self.state.SCB = self.SCB.data - + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) del self.state.CxB return out + def replace_module(model): for name, module in model.named_children(): if len(list(module.children())) > 0: replace_module(module) - if isinstance(module, nn.Linear) and "out_proj" not in name: + if isinstance(module, nn.Linear) and "out_proj" not in name: model._modules[name] = Linear8bit( - input_features=module.in_features, - output_features=module.out_features, - threshold=6.0, - weight_data=module.weight, - bias_data=module.bias, - ) + input_features=module.in_features, + output_features=module.out_features, + threshold=6.0, + weight_data=module.weight, + bias_data=module.bias, + ) return model + def getModelSize(model): param_size = 0 param_sum = 0 @@ -79,5 +79,5 @@ def getModelSize(model): buffer_size += buffer.nelement() * buffer.element_size() buffer_sum += buffer.nelement() all_size = (param_size + buffer_size) / 1024 / 1024 - print('Model Size: {:.3f}MB'.format(all_size)) + print("Model Size: {:.3f}MB".format(all_size)) return (param_size, param_sum, buffer_size, buffer_sum, all_size) diff --git a/examples/images/diffusion/setup.py b/examples/images/diffusion/setup.py index a24d541676407eee1bea271179ffd1d80c6a8e79..13d9f892780146ec5c2cdfc8b331fa695f6f460a 100644 --- a/examples/images/diffusion/setup.py +++ b/examples/images/diffusion/setup.py @@ -1,13 +1,13 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( - name='latent-diffusion', - version='0.0.1', - description='', + name="latent-diffusion", + version="0.0.1", + description="", packages=find_packages(), install_requires=[ - 'torch', - 'numpy', - 'tqdm', + "torch", + "numpy", + "tqdm", ], -) \ No newline at end of file +) diff --git a/examples/images/diffusion/train_colossalai.sh b/examples/images/diffusion/train_colossalai.sh index 7f1a1bd14615a66e948879ffa49b4aae8a542d0a..c56ed7876e5a8c0e6d6df3a41bdfff0b09e8c9be 100755 --- a/examples/images/diffusion/train_colossalai.sh +++ b/examples/images/diffusion/train_colossalai.sh @@ -3,4 +3,3 @@ TRANSFORMERS_OFFLINE=1 DIFFUSERS_OFFLINE=1 python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt - diff --git a/examples/images/diffusion/train_ddp.sh b/examples/images/diffusion/train_ddp.sh index 78fe765488c6451e8eeb23f217d49f750aabfd03..8304d6fa8b4fe8c576bb8e1d4a584df7b56fd36e 100644 --- a/examples/images/diffusion/train_ddp.sh +++ b/examples/images/diffusion/train_ddp.sh @@ -1,5 +1,5 @@ -HF_DATASETS_OFFLINE=1 -TRANSFORMERS_OFFLINE=1 -DIFFUSERS_OFFLINE=1 +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 +DIFFUSERS_OFFLINE=1 python main.py --logdir /tmp -t -b /configs/train_ddp.yaml diff --git a/examples/images/dreambooth/README.md b/examples/images/dreambooth/README.md index 7c117d841e24ccbb1e5d8cc7856df044986cc03e..6716052897a63883f3efa3039f70f5ba2fe19c5c 100644 --- a/examples/images/dreambooth/README.md +++ b/examples/images/dreambooth/README.md @@ -37,7 +37,7 @@ The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Mode ## Training -We provide the script `colossalai.sh` to run the training task with colossalai. Meanwhile, we also provided traditional training process of dreambooth, `dreambooth.sh`, for possible comparation. For instance, the script of training process for [stable-diffusion-v1-4] model can be modified into: +We provide the script `colossalai.sh` to run the training task with colossalai. Meanwhile, we also provided traditional training process of dreambooth, `dreambooth.sh`, for possible comparison. For instance, the script of training process for [stable-diffusion-v1-4] model can be modified into: ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" @@ -92,6 +92,29 @@ torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \ --placement="cuda" ``` +## New API +We have modified our previous implementation of Dreambooth with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find the new API in `train_dreambooth_colossalai.py`. +We have also offer a shell script `test_ci.sh` for you to go through all our plugins for the booster. +For more information about the booster API you can refer to https://colossalai.org/docs/basics/booster_api/. + +## Performance + +| Strategy | #GPU | Batch Size | GPU RAM(GB) | speedup | +|:--------------:|:----:|:----------:|:-----------:|:-------:| +| Traditional | 1 | 16 | oom | \ | +| Traditional | 1 | 8 | 61.81 | 1 | +| torch_ddp | 4 | 16 | oom | \ | +| torch_ddp | 4 | 8 | 41.97 | 0.97 | +| gemini | 4 | 16 | 53.29 | \ | +| gemini | 4 | 8 | 29.36 | 2.00 | +| low_level_zero | 4 | 16 | 52.80 | \ | +| low_level_zero | 4 | 8 | 28.87 | 2.02 | + +The evaluation is performed on 4 Nvidia A100 GPUs with 80GB memory each, with GPU 0 & 1, 2 & 3 connected with NVLink. +We finetuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model with 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared +the memory cost and the throughput for the plugins. + + ## Inference Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. `--instance_prompt="a photo of sks dog" ` in the above example) in your prompt. @@ -116,7 +139,7 @@ You may contact us or participate in the following ways: 1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! 2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). 3. Join the Colossal-AI community on -[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +[Slack](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack), and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. 4. Send your official proposal to email contact@hpcaitech.com diff --git a/examples/images/dreambooth/colossalai.sh b/examples/images/dreambooth/colossalai.sh index 227d8b8bdb0410d29ae61b1f51ff4d6022bc7dd0..db4562dbc921a929840eda0cec92a51932a1ee2d 100755 --- a/examples/images/dreambooth/colossalai.sh +++ b/examples/images/dreambooth/colossalai.sh @@ -1,22 +1,18 @@ -export MODEL_NAME= -export INSTANCE_DIR= -export CLASS_DIR="path-to-class-images" -export OUTPUT_DIR="path-to-save-model" - -HF_DATASETS_OFFLINE=1 -TRANSFORMERS_OFFLINE=1 +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 DIFFUSERS_OFFLINE=1 -torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --instance_data_dir=$INSTANCE_DIR \ - --output_dir=$OUTPUT_DIR \ - --instance_prompt="a photo of a dog" \ +torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \ + --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \ + --instance_data_dir="/data/dreambooth/Teyvat/data" \ + --output_dir="./weight_output" \ + --instance_prompt="a picture of a dog" \ --resolution=512 \ + --plugin="gemini" \ --train_batch_size=1 \ - --gradient_accumulation_steps=1 \ --learning_rate=5e-6 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --placement="cuda" \ + --test_run=True \ + --placement="auto" \ diff --git a/examples/images/dreambooth/debug.py b/examples/images/dreambooth/debug.py index 33219b2caa298c626f9b3b50e8fc8645c922b9ca..8ce4dc3bbd80fd6bd704a76941edd03e086103b3 100644 --- a/examples/images/dreambooth/debug.py +++ b/examples/images/dreambooth/debug.py @@ -1,16 +1,16 @@ -''' +""" torchrun --standalone --nproc_per_node=1 debug.py -''' +""" from diffusers import AutoencoderKL import colossalai -from colossalai.zero import ColoInitContext, post_process_colo_init_ctx +from colossalai.zero import ColoInitContext path = "/data/scratch/diffuser/stable-diffusion-v1-4" colossalai.launch_from_torch(config={}) -with ColoInitContext(device='cpu'): +with ColoInitContext(device="cpu"): vae = AutoencoderKL.from_pretrained( path, subfolder="vae", diff --git a/examples/images/dreambooth/dreambooth.sh b/examples/images/dreambooth/dreambooth.sh index e063bc8279c53784a16df0a0f8c76bf785ecf46e..f6b8f5e1b87eb307f9bcf9a61b27b161915a580a 100644 --- a/examples/images/dreambooth/dreambooth.sh +++ b/examples/images/dreambooth/dreambooth.sh @@ -1,7 +1,7 @@ python train_dreambooth.py \ - --pretrained_model_name_or_path= ## Your Model Path \ - --instance_data_dir= ## Your Training Input Pics Path \ - --output_dir="path-to-save-model" \ + --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \ + --instance_data_dir="/data/dreambooth/Teyvat/data" \ + --output_dir="./weight_output" \ --instance_prompt="a photo of a dog" \ --resolution=512 \ --train_batch_size=1 \ diff --git a/examples/images/dreambooth/inference.py b/examples/images/dreambooth/inference.py index c342821c783003b20b4ddb765c21a99f9397331f..ff317827aff7472e20950caef3c284fa200e5caf 100644 --- a/examples/images/dreambooth/inference.py +++ b/examples/images/dreambooth/inference.py @@ -1,7 +1,7 @@ -from diffusers import StableDiffusionPipeline, DiffusionPipeline import torch +from diffusers import DiffusionPipeline -model_id = +model_id = "" print(f"Loading model... from{model_id}") pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..b0a96ec7007507d004a8c46b8d84256764f5c61b 100644 --- a/examples/images/dreambooth/test_ci.sh +++ b/examples/images/dreambooth/test_ci.sh @@ -0,0 +1,26 @@ +#!/bin/bash +set -xe +echo "this test is slow" + +# pip install -r requirements.txt + +# HF_DATASETS_OFFLINE=1 +# TRANSFORMERS_OFFLINE=1 +# DIFFUSERS_OFFLINE=1 + +# # "torch_ddp" "torch_ddp_fp16" "low_level_zero" +# for plugin in "gemini"; do +# torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \ +# --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \ +# --instance_data_dir="/data/dreambooth/Teyvat/data" \ +# --output_dir="./weight_output" \ +# --instance_prompt="a picture of a dog" \ +# --resolution=512 \ +# --plugin=$plugin \ +# --train_batch_size=1 \ +# --learning_rate=5e-6 \ +# --lr_scheduler="constant" \ +# --lr_warmup_steps=0 \ +# --test_run=True \ +# --num_class_images=200 +# don diff --git a/examples/images/dreambooth/train_dreambooth.py b/examples/images/dreambooth/train_dreambooth.py index b989955f7fb70c43a8daa5f2d87fbc0a3a3f0465..9b66089b2752fbef259efbc898b3729fab9cb7f1 100644 --- a/examples/images/dreambooth/train_dreambooth.py +++ b/examples/images/dreambooth/train_dreambooth.py @@ -104,8 +104,10 @@ def parse_args(input_args=None): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), ) parser.add_argument( "--output_dir", @@ -118,17 +120,18 @@ def parse_args(input_args=None): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" ) - parser.add_argument("--center_crop", - action="store_true", - help="Whether to center crop images before resizing to resolution") parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -165,16 +168,17 @@ def parse_args(input_args=None): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), - ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -192,8 +196,10 @@ def parse_args(input_args=None): "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", @@ -203,7 +209,8 @@ def parse_args(input_args=None): help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -269,12 +276,14 @@ class DreamBoothDataset(Dataset): else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -350,7 +359,8 @@ def main(args): if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future.") + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) if args.seed is not None: set_seed(args.seed) @@ -380,9 +390,9 @@ def main(args): sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) - for example in tqdm(sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process): + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): @@ -456,8 +466,9 @@ def main(args): text_encoder.gradient_checkpointing_enable() if args.scale_lr: - args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * - accelerator.num_processes) + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: @@ -470,8 +481,9 @@ def main(args): else: optimizer_class = torch.optim.AdamW - params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -506,9 +518,7 @@ def main(args): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( - { - "input_ids": input_ids - }, + {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt", @@ -520,11 +530,9 @@ def main(args): } return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=1) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -542,10 +550,12 @@ def main(args): if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, - lr_scheduler) + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": @@ -641,8 +651,11 @@ def main(args): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index e6159e1058b97f21135e60ffe88e535455e993d3..1a7f8da7f7d0a9ad950bd70a7fc874c231b28ad3 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -2,10 +2,12 @@ import argparse import hashlib import math import os +import shutil from pathlib import Path from typing import Optional import torch +import torch.distributed as dist import torch.nn.functional as F import torch.utils.checkpoint from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel @@ -18,12 +20,11 @@ from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import colossalai -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, GeminiAdamOptimizer -from colossalai.zero.gemini import get_static_torch_model disable_existing_loggers() logger = get_dist_logger() @@ -58,6 +59,13 @@ def parse_args(input_args=None): required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--externel_unet_path", + type=str, + default=None, + required=False, + help="Path to the externel unet model.", + ) parser.add_argument( "--revision", type=str, @@ -109,8 +117,10 @@ def parse_args(input_args=None): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), ) parser.add_argument( "--output_dir", @@ -123,26 +133,29 @@ def parse_args(input_args=None): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), ) parser.add_argument( - "--placement", - type=str, - default="cpu", - help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + "--offload_optim_frac", + type=float, + default=1.0, + help="Fraction of optimizer states to be offloaded. Valid when using colossalai as dist plan.", ) parser.add_argument( "--center_crop", default=False, action="store_true", - help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping."), + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -173,32 +186,44 @@ def parse_args(input_args=None): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument("--test_run", default=False, help="Whether to use a smaller dataset for test run.") parser.add_argument( "--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) parser.add_argument( "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", @@ -208,7 +233,8 @@ def parse_args(input_args=None): help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -250,6 +276,7 @@ class DreamBoothDataset(Dataset): class_prompt=None, size=512, center_crop=False, + test=False, ): self.size = size self.center_crop = center_crop @@ -260,6 +287,8 @@ class DreamBoothDataset(Dataset): raise ValueError("Instance images root doesn't exists.") self.instance_images_path = list(Path(instance_data_root).iterdir()) + if test: + self.instance_images_path = self.instance_images_path[:10] self.num_instance_images = len(self.instance_images_path) self.instance_prompt = instance_prompt self._length = self.num_instance_images @@ -274,12 +303,14 @@ class DreamBoothDataset(Dataset): else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -339,26 +370,14 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: return f"{organization}/{model_id}" -# Gemini + ZeRO DDP -def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"): - from colossalai.nn.parallel import GeminiDDP - - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=placememt_policy, - pin_memory=True, - search_range_mb=64) - return model - - def main(args): if args.seed is None: colossalai.launch_from_torch(config={}) else: colossalai.launch_from_torch(config={}, seed=args.seed) - local_rank = gpc.get_local_rank(ParallelMode.DATA) - world_size = gpc.get_world_size(ParallelMode.DATA) + local_rank = dist.get_rank() + world_size = dist.get_world_size() if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) @@ -385,14 +404,14 @@ def main(args): pipeline.to(get_current_device()) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not local_rank == 0, + sample_dataloader, + desc="Generating class images", + disable=not local_rank == 0, ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): - hash_image = hashlib.sha1(image.tobytes()).hexdigest() + hash_image = hashlib.sha256(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) @@ -452,12 +471,16 @@ def main(args): revision=args.revision, ) - logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) - with ColoInitContext(device=get_current_device()): - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - low_cpu_mem_usage=False) + if args.externel_unet_path is None: + logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False + ) + else: + logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0]) + unet = UNet2DConditionModel.from_pretrained( + args.externel_unet_path, revision=args.revision, low_cpu_mem_usage=False + ) vae.requires_grad_(False) text_encoder.requires_grad_(False) @@ -468,10 +491,24 @@ def main(args): if args.scale_lr: args.learning_rate = args.learning_rate * args.train_batch_size * world_size - unet = gemini_zero_dpp(unet, args.placement) + # Use Booster API to use Gemini/Zero with ColossalAI + + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + + booster = Booster(plugin=plugin, **booster_kwargs) # config optimizer for colossalai zero - optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm) + optimizer = HybridAdam( + unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm + ) # load noise_scheduler noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") @@ -486,6 +523,7 @@ def main(args): tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, + test=args.test_run, ) def collate_fn(examples): @@ -502,9 +540,7 @@ def main(args): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( - { - "input_ids": input_ids - }, + {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt", @@ -516,11 +552,9 @@ def main(args): } return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=1) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -554,6 +588,8 @@ def main(args): # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler) + # Train! total_batch_size = args.train_batch_size * world_size @@ -637,37 +673,26 @@ def main(args): logs = { "loss": loss.detach().item(), "lr": optimizer.param_groups[0]["lr"], - } # lr_scheduler.get_last_lr()[0]} + } # lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step % args.save_steps == 0: torch.cuda.synchronize() - torch_unet = get_static_torch_model(unet) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin")) if local_rank == 0: - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=torch_unet, - revision=args.revision, - ) - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - pipeline.save_pretrained(save_path) + if not os.path.exists(os.path.join(save_path, "config.json")): + shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path) logger.info(f"Saving model checkpoint to {save_path}", ranks=[0]) if global_step >= args.max_train_steps: break - torch.cuda.synchronize() - unet = get_static_torch_model(unet) + booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin")) + logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}") if local_rank == 0: - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=unet, - revision=args.revision, - ) - - pipeline.save_pretrained(args.output_dir) - logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0]) - + if not os.path.exists(os.path.join(args.output_dir, "config.json")): + shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir) if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index 1b2fc778d5ed05de78b59464c9bd231d0d411a21..ea6dde8bb5788d7dce32768e798285d7fdf2671a 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -2,6 +2,7 @@ import argparse import hashlib import math import os +import shutil from pathlib import Path from typing import Optional @@ -20,12 +21,13 @@ from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import colossalai -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, GeminiAdamOptimizer -from colossalai.zero.gemini import get_static_torch_model disable_existing_loggers() logger = get_dist_logger() @@ -60,6 +62,13 @@ def parse_args(input_args=None): required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--externel_unet_path", + type=str, + default=None, + required=False, + help="Path to the externel unet model.", + ) parser.add_argument( "--revision", type=str, @@ -111,8 +120,10 @@ def parse_args(input_args=None): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), ) parser.add_argument( "--output_dir", @@ -125,8 +136,10 @@ def parse_args(input_args=None): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), ) parser.add_argument( "--placement", @@ -138,13 +151,14 @@ def parse_args(input_args=None): "--center_crop", default=False, action="store_true", - help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping."), + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -175,16 +189,17 @@ def parse_args(input_args=None): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") @@ -195,12 +210,22 @@ def parse_args(input_args=None): default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) parser.add_argument( "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", @@ -210,7 +235,8 @@ def parse_args(input_args=None): help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -276,12 +302,14 @@ class DreamBoothDataset(Dataset): else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -341,18 +369,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: return f"{organization}/{model_id}" -# Gemini + ZeRO DDP -def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"): - from colossalai.nn.parallel import GeminiDDP - - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=placememt_policy, - pin_memory=True, - search_range_mb=64) - return model - - def main(args): if args.seed is None: colossalai.launch_from_torch(config={}) @@ -387,14 +403,14 @@ def main(args): pipeline.to(get_current_device()) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not local_rank == 0, + sample_dataloader, + desc="Generating class images", + disable=not local_rank == 0, ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): - hash_image = hashlib.sha1(image.tobytes()).hexdigest() + hash_image = hashlib.sha256(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) @@ -454,32 +470,38 @@ def main(args): revision=args.revision, ) - logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) - with ColoInitContext(device=get_current_device()): - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - low_cpu_mem_usage=False) - unet.requires_grad_(False) - - # Set correct lora layers - lora_attn_procs = {} - for name in unet.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = unet.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = unet.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim) - - unet.set_attn_processor(lora_attn_procs) - lora_layers = AttnProcsLayers(unet.attn_processors) + if args.externel_unet_path is None: + logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False + ) + else: + logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0]) + unet = UNet2DConditionModel.from_pretrained( + args.externel_unet_path, revision=args.revision, low_cpu_mem_usage=False + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False + ) + unet.requires_grad_(False) + + # Set correct lora layers + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + + unet.set_attn_processor(lora_attn_procs) + AttnProcsLayers(unet.attn_processors) vae.requires_grad_(False) text_encoder.requires_grad_(False) @@ -490,10 +512,24 @@ def main(args): if args.scale_lr: args.learning_rate = args.learning_rate * args.train_batch_size * world_size - unet = gemini_zero_dpp(unet, args.placement) + # Use Booster API to use Gemini/Zero with ColossalAI + + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + + booster = Booster(plugin=plugin, **booster_kwargs) # config optimizer for colossalai zero - optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm) + optimizer = HybridAdam( + unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm + ) # load noise_scheduler noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") @@ -524,9 +560,7 @@ def main(args): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( - { - "input_ids": input_ids - }, + {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt", @@ -538,11 +572,9 @@ def main(args): } return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=1) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -576,6 +608,8 @@ def main(args): # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler) + # Train! total_batch_size = args.train_batch_size * world_size @@ -659,28 +693,26 @@ def main(args): logs = { "loss": loss.detach().item(), "lr": optimizer.param_groups[0]["lr"], - } # lr_scheduler.get_last_lr()[0]} + } # lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step % args.save_steps == 0: torch.cuda.synchronize() - torch_unet = get_static_torch_model(unet) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin")) if local_rank == 0: - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - torch_unet = torch_unet.to(torch.float32) - torch_unet.save_attn_procs(save_path) + if not os.path.exists(os.path.join(save_path, "config.json")): + shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path) logger.info(f"Saving model checkpoint to {save_path}", ranks=[0]) if global_step >= args.max_train_steps: break - torch.cuda.synchronize() - torch_unet = get_static_torch_model(unet) + booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin")) + logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}") if local_rank == 0: - torch_unet = torch_unet.to(torch.float32) - torch_unet.save_attn_procs(save_path) - logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0]) - + if not os.path.exists(os.path.join(args.output_dir, "config.json")): + shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir) if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) diff --git a/examples/images/dreambooth/train_dreambooth_inpaint.py b/examples/images/dreambooth/train_dreambooth_inpaint.py index 774cd4c458e9c02548a1165376857eb8398c4311..32f1b4959879b094123ced41621cdf214bacb104 100644 --- a/examples/images/dreambooth/train_dreambooth_inpaint.py +++ b/examples/images/dreambooth/train_dreambooth_inpaint.py @@ -126,8 +126,10 @@ def parse_args(): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If not have enough images, additional images will be" - " sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If not have enough images, additional images will be" + " sampled with class_prompt." + ), ) parser.add_argument( "--output_dir", @@ -140,17 +142,18 @@ def parse_args(): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" ) - parser.add_argument("--center_crop", - action="store_true", - help="Whether to center crop images before resizing to resolution") parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -186,16 +189,17 @@ def parse_args(): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -213,17 +217,21 @@ def parse_args(): "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], - help=("Whether to use mixed precision. Choose" - "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." - "and an Nvidia Ampere GPU."), + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -283,12 +291,14 @@ class DreamBoothDataset(Dataset): else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -369,7 +379,8 @@ def main(): if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future.") + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) if args.seed is not None: set_seed(args.seed) @@ -382,25 +393,25 @@ def main(): if cur_class_images < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 - pipeline = StableDiffusionInpaintPipeline.from_pretrained(args.pretrained_model_name_or_path, - torch_dtype=torch_dtype, - safety_checker=None) + pipeline = StableDiffusionInpaintPipeline.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None + ) pipeline.set_progress_bar_config(disable=True) num_new_images = args.num_class_images - cur_class_images logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, - batch_size=args.sample_batch_size, - num_workers=1) + sample_dataloader = torch.utils.data.DataLoader( + sample_dataset, batch_size=args.sample_batch_size, num_workers=1 + ) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) transform_to_pil = transforms.ToPILImage() - for example in tqdm(sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process): + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): bsz = len(example["prompt"]) fake_images = torch.rand((3, args.resolution, args.resolution)) transform_to_pil = transforms.ToPILImage() @@ -457,8 +468,9 @@ def main(): text_encoder.gradient_checkpointing_enable() if args.scale_lr: - args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * - accelerator.num_processes) + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: @@ -471,8 +483,9 @@ def main(): else: optimizer_class = torch.optim.AdamW - params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -494,10 +507,12 @@ def main(): ) def collate_fn(examples): - image_transforms = transforms.Compose([ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), - ]) + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + ] + ) input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -545,10 +560,9 @@ def main(): batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images} return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -566,10 +580,12 @@ def main(): if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, - lr_scheduler) + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) weight_dtype = torch.float32 if args.mixed_precision == "fp16": @@ -622,16 +638,19 @@ def main(): latents = latents * 0.18215 # Convert masked images to latent space - masked_latents = vae.encode(batch["masked_images"].reshape( - batch["pixel_values"].shape).to(dtype=weight_dtype)).latent_dist.sample() + masked_latents = vae.encode( + batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype) + ).latent_dist.sample() masked_latents = masked_latents * 0.18215 masks = batch["masks"] # resize the mask to latents shape as we concatenate the mask to the latents - mask = torch.stack([ - torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) - for mask in masks - ]) + mask = torch.stack( + [ + torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) + for mask in masks + ] + ) mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) # Sample noise that we'll add to the latents @@ -680,8 +699,11 @@ def main(): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() diff --git a/examples/tutorial/new_api/torch_ddp/.gitignore b/examples/images/resnet/.gitignore similarity index 100% rename from examples/tutorial/new_api/torch_ddp/.gitignore rename to examples/images/resnet/.gitignore diff --git a/examples/images/resnet/README.md b/examples/images/resnet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9a7493ea31a6761a9ad1528eec5c2b103e68edd6 --- /dev/null +++ b/examples/images/resnet/README.md @@ -0,0 +1,56 @@ +# Train ResNet on CIFAR-10 from scratch + +## 🚀 Quick Start + +This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch. + +- Training Arguments + - `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `low_level_zero`. Defaults to `torch_ddp`. + - `-r`, `--resume`: Resume from checkpoint file path. Defaults to `-1`, which means not resuming. + - `-c`, `--checkpoint`: The folder to save checkpoints. Defaults to `./checkpoint`. + - `-i`, `--interval`: Epoch interval to save checkpoints. Defaults to `5`. If set to `0`, no checkpoint will be saved. + - `--target_acc`: Target accuracy. Raise exception if not reached. Defaults to `None`. + +- Eval Arguments + - `-e`, `--epoch`: select the epoch to evaluate + - `-c`, `--checkpoint`: the folder where checkpoints are found + +### Install requirements + +```bash +pip install -r requirements.txt +``` + +### Train +The folders will be created automatically. +```bash +# train with torch DDP with fp32 +colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32 + +# train with torch DDP with mixed precision training +colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 -p torch_ddp_fp16 + +# train with low level zero +colossalai run --nproc_per_node 2 train.py -c ./ckpt-low_level_zero -p low_level_zero +``` + +### Eval + +```bash +# evaluate fp32 training +python eval.py -c ./ckpt-fp32 -e 80 + +# evaluate fp16 mixed precision training +python eval.py -c ./ckpt-fp16 -e 80 + +# evaluate low level zero training +python eval.py -c ./ckpt-low_level_zero -e 80 +``` + +Expected accuracy performance will be: + +| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | Booster Gemini | +| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | -------------- | +| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | 84.60% | + +**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** diff --git a/examples/images/resnet/eval.py b/examples/images/resnet/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..526e41a2850fe253991cb0636d98368a5970c408 --- /dev/null +++ b/examples/images/resnet/eval.py @@ -0,0 +1,47 @@ +import argparse + +import torch +import torchvision +import torchvision.transforms as transforms + +# ============================== +# Parse Arguments +# ============================== +parser = argparse.ArgumentParser() +parser.add_argument("-e", "--epoch", type=int, default=80, help="resume from the epoch's checkpoint") +parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") +args = parser.parse_args() + +# ============================== +# Prepare Test Dataset +# ============================== +# CIFAR-10 dataset +test_dataset = torchvision.datasets.CIFAR10(root="./data/", train=False, transform=transforms.ToTensor()) + +# Data loader +test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False) + +# ============================== +# Load Model +# ============================== +model = torchvision.models.resnet18(num_classes=10).cuda() +state_dict = torch.load(f"{args.checkpoint}/model_{args.epoch}.pth") +model.load_state_dict(state_dict) + +# ============================== +# Run Evaluation +# ============================== +model.eval() + +with torch.no_grad(): + correct = 0 + total = 0 + for images, labels in test_loader: + images = images.cuda() + labels = labels.cuda() + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + print("Accuracy of the model on the test images: {} %".format(100 * correct / total)) diff --git a/examples/images/resnet/requirements.txt b/examples/images/resnet/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..46b7da7d487077cddc8a846e7a87eea06350ea98 --- /dev/null +++ b/examples/images/resnet/requirements.txt @@ -0,0 +1,5 @@ +colossalai +torch +torchvision +tqdm +pytest diff --git a/examples/images/resnet/test_ci.sh b/examples/images/resnet/test_ci.sh new file mode 100755 index 0000000000000000000000000000000000000000..b3fb67830dda424717eebbed893debffd0f41c97 --- /dev/null +++ b/examples/images/resnet/test_ci.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -xe + +export DATA=/data/scratch/cifar-10 + +pip install -r requirements.txt + +# TODO: skip ci test due to time limits, train.py needs to be rewritten. + +# for plugin in "torch_ddp" "torch_ddp_fp16" "low_level_zero"; do +# colossalai run --nproc_per_node 4 train.py --interval 0 --target_acc 0.84 --plugin $plugin +# done diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py new file mode 100644 index 0000000000000000000000000000000000000000..13df516d4189864a03a2fef12980b1911e2a2c57 --- /dev/null +++ b/examples/images/resnet/train.py @@ -0,0 +1,207 @@ +import argparse +import os +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from torch.optim import Optimizer +from torch.optim.lr_scheduler import MultiStepLR +from torch.utils.data import DataLoader +from tqdm import tqdm + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 80 +LEARNING_RATE = 1e-3 + + +def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): + # transform + transform_train = transforms.Compose( + [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()] + ) + transform_test = transforms.ToTensor() + + # CIFAR-10 dataset + data_path = os.environ.get("DATA", "./data") + with coordinator.priority_execution(): + train_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=True, transform=transform_train, download=True + ) + test_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=False, transform=transform_test, download=True + ) + + # Data loader + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) + test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False) + return train_dataloader, test_dataloader + + +@torch.no_grad() +def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: + model.eval() + correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + for images, labels in test_dataloader: + images = images.cuda() + labels = labels.cuda() + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + dist.all_reduce(correct) + dist.all_reduce(total) + accuracy = correct.item() / total.item() + if coordinator.is_master(): + print(f"Accuracy of the model on the test images: {accuracy * 100:.2f} %") + return accuracy + + +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: nn.Module, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + model.train() + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: + for images, labels in pbar: + images = images.cuda() + labels = labels.cuda() + # Forward pass + outputs = model(images) + loss = criterion(outputs, labels) + + # Backward and optimize + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + + # Print log info + pbar.set_postfix({"loss": loss.item()}) + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + # FIXME(ver217): gemini is not supported resnet now + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "low_level_zero", "gemini"], + help="plugin to use", + ) + parser.add_argument("-r", "--resume", type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") + parser.add_argument("-i", "--interval", type=int, default=5, help="interval of saving checkpoint") + parser.add_argument( + "--target_acc", type=float, default=None, help="target accuracy. Raise exception if not reached" + ) + args = parser.parse_args() + + # ============================== + # Prepare Checkpoint Directory + # ============================== + if args.interval > 0: + Path(args.checkpoint).mkdir(parents=True, exist_ok=True) + + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}) + coordinator = DistCoordinator() + + # update the learning rate with linear scaling + # old_gpu_num / old_lr = new_gpu_num / new_lr + global LEARNING_RATE + LEARNING_RATE *= coordinator.world_size + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + train_dataloader, test_dataloader = build_dataloader(100, coordinator, plugin) + + # ==================================== + # Prepare model, optimizer, criterion + # ==================================== + # resent50 + model = torchvision.models.resnet18(num_classes=10) + + # Loss and optimizer + criterion = nn.CrossEntropyLoss() + optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE) + + # lr scheduler + lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3) + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=criterion, lr_scheduler=lr_scheduler + ) + + # ============================== + # Resume from checkpoint + # ============================== + if args.resume >= 0: + booster.load_model(model, f"{args.checkpoint}/model_{args.resume}.pth") + booster.load_optimizer(optimizer, f"{args.checkpoint}/optimizer_{args.resume}.pth") + booster.load_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{args.resume}.pth") + + # ============================== + # Train model + # ============================== + start_epoch = args.resume if args.resume >= 0 else 0 + for epoch in range(start_epoch, NUM_EPOCHS): + train_epoch(epoch, model, optimizer, criterion, train_dataloader, booster, coordinator) + lr_scheduler.step() + + # save checkpoint + if args.interval > 0 and (epoch + 1) % args.interval == 0: + booster.save_model(model, f"{args.checkpoint}/model_{epoch + 1}.pth") + booster.save_optimizer(optimizer, f"{args.checkpoint}/optimizer_{epoch + 1}.pth") + booster.save_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{epoch + 1}.pth") + + accuracy = evaluate(model, test_dataloader, coordinator) + if args.target_acc is not None: + assert accuracy >= args.target_acc, f"Accuracy {accuracy} is lower than target accuracy {args.target_acc}" + + +if __name__ == "__main__": + main() diff --git a/examples/images/vit/README.md b/examples/images/vit/README.md index 4423d85d19e0549a04138097aa7526d8f13d71a4..33c6454ad92cbdb01b715bd11c0c4e3e8d0b7a2d 100644 --- a/examples/images/vit/README.md +++ b/examples/images/vit/README.md @@ -1,61 +1,28 @@ -# Vision Transformer with ColoTensor +## Overview -# Overview +Vision Transformer is a class of Transformer model tailored for computer vision tasks. It was first proposed in paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) and achieved SOTA results on various tasks at that time. -In this example, we will run Vision Transformer with ColoTensor. +In our example, we are using pretrained weights of ViT loaded from HuggingFace. +We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin (DDP), LowLevelZeroPlugin (Zero1/Zero2), GeminiPlugin (Gemini) and HybridParallelPlugin (any combination of tensor/pipeline/data parallel). -We use model **ViTForImageClassification** from Hugging Face [Link](https://huggingface.co/docs/transformers/model_doc/vit) for unit test. -You can change world size or decide whether use DDP in our code. +## Run Demo -We use model **vision_transformer** from timm [Link](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) for training example. - -(2022/6/28) The default configuration now supports 2DP+2TP with gradient accumulation and checkpoint support. Zero is not supported at present. - -# Requirement - -Install colossalai version >= 0.1.11 - -## Unit test -To run unit test, you should install pytest, transformers with: -```shell -pip install pytest transformers +By running the following script: +```bash +bash run_demo.sh ``` +You will finetune a a [ViT-base](https://huggingface.co/google/vit-base-patch16-224) model on this [dataset](https://huggingface.co/datasets/beans), with more than 8000 images of bean leaves. This dataset is for image classification task and there are 3 labels: ['angular_leaf_spot', 'bean_rust', 'healthy']. -## Training example -To run training example with ViT-S, you should install **NVIDIA DALI** from [Link](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html) for dataloader support. -You also need to install timm and titans for model/dataloader support with: -```shell -pip install timm titans -``` +The script can be modified if you want to try another set of hyperparameters or change to another ViT model with different size. -### Data preparation -You can download the ImageNet dataset from the [ImageNet official website](https://www.image-net.org/download.php). You should get the raw images after downloading the dataset. As we use **NVIDIA DALI** to read data, we use the TFRecords dataset instead of raw Imagenet dataset. This offers better speedup to IO. If you don't have TFRecords dataset, follow [imagenet-tools](https://github.com/ver217/imagenet-tools) to build one. +The demo code refers to this [blog](https://huggingface.co/blog/fine-tune-vit). -Before you start training, you need to set the environment variable `DATA` so that the script knows where to fetch the data for DALI dataloader. -```shell -export DATA=/path/to/ILSVRC2012 -``` -# How to run +## Run Benchmark -## Unit test -In your terminal -```shell -pytest test_vit.py +You can run benchmark for ViT model by running the following script: +```bash +bash run_benchmark.sh ``` - -This will evaluate models with different **world_size** and **use_ddp**. - -## Training example -Modify the settings in run.sh according to your environment. -For example, if you set `--nproc_per_node=8` in `run.sh` and `TP_WORLD_SIZE=2` in your config file, -data parallel size will be automatically calculated as 4. -Thus, the parallel strategy is set to 4DP+2TP. - -Then in your terminal -```shell -sh run.sh -``` - -This will start ViT-S training with ImageNet. +The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing. diff --git a/examples/images/vit/args.py b/examples/images/vit/args.py new file mode 100644 index 0000000000000000000000000000000000000000..9de4743ef94d87a53fe29f2bd19d31866f0d54cf --- /dev/null +++ b/examples/images/vit/args.py @@ -0,0 +1,86 @@ +import argparse + + +def parse_demo_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name_or_path", + type=str, + default="google/vit-base-patch16-224", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_path", type=str, default="./output_model", help="The path of your saved model after finetuning." + ) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'.", + ) + parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs.") + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument( + "--tp_size", + type=int, + default=1, + help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.", + ) + parser.add_argument( + "--pp_size", + type=int, + default=1, + help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=3e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--warmup_ratio", type=float, default=0.3, help="Ratio of warmup steps against total training steps." + ) + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.") + parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + + args = parser.parse_args() + return args + + +def parse_benchmark_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_name_or_path", + type=str, + default="google/vit-base-patch16-224", + help="Path to a pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'.", + ) + parser.add_argument( + "--batch_size", type=int, default=8, help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument("--num_labels", type=int, default=10, help="Number of labels for classification.") + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.") + parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).") + args = parser.parse_args() + + return args diff --git a/examples/images/vit/configs/vit_1d_tp2.py b/examples/images/vit/configs/vit_1d_tp2.py deleted file mode 100644 index fbf399f2e50daaa70f52726ccfa6f4e035ce7380..0000000000000000000000000000000000000000 --- a/examples/images/vit/configs/vit_1d_tp2.py +++ /dev/null @@ -1,32 +0,0 @@ -from colossalai.amp import AMP_TYPE - -# hyperparameters -# BATCH_SIZE is as per GPU -# global batch size = BATCH_SIZE x data parallel size -BATCH_SIZE = 256 -LEARNING_RATE = 3e-3 -WEIGHT_DECAY = 0.3 -NUM_EPOCHS = 300 -WARMUP_EPOCHS = 32 - -# model config -IMG_SIZE = 224 -PATCH_SIZE = 16 -HIDDEN_SIZE = 384 -DEPTH = 12 -NUM_HEADS = 6 -MLP_RATIO = 4 -NUM_CLASSES = 1000 -CHECKPOINT = False -SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token - -USE_DDP = True -TP_WORLD_SIZE = 2 -TP_TYPE = 'row' -parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),) - -fp16 = dict(mode=AMP_TYPE.NAIVE) -clip_grad_norm = 1.0 -gradient_accumulation = 8 - -LOG_PATH = "./log" diff --git a/examples/images/vit/configs/vit_1d_tp2_ci.py b/examples/images/vit/configs/vit_1d_tp2_ci.py deleted file mode 100644 index e491e4ada45e25f57c8e6c93df41e36794f0b420..0000000000000000000000000000000000000000 --- a/examples/images/vit/configs/vit_1d_tp2_ci.py +++ /dev/null @@ -1,32 +0,0 @@ -from colossalai.amp import AMP_TYPE - -# hyperparameters -# BATCH_SIZE is as per GPU -# global batch size = BATCH_SIZE x data parallel size -BATCH_SIZE = 8 -LEARNING_RATE = 3e-3 -WEIGHT_DECAY = 0.3 -NUM_EPOCHS = 3 -WARMUP_EPOCHS = 1 - -# model config -IMG_SIZE = 224 -PATCH_SIZE = 16 -HIDDEN_SIZE = 32 -DEPTH = 2 -NUM_HEADS = 4 -MLP_RATIO = 4 -NUM_CLASSES = 10 -CHECKPOINT = False -SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token - -USE_DDP = True -TP_WORLD_SIZE = 2 -TP_TYPE = 'row' -parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),) - -fp16 = dict(mode=AMP_TYPE.NAIVE) -clip_grad_norm = 1.0 -gradient_accumulation = 2 - -LOG_PATH = "./log_ci" diff --git a/examples/images/vit/data.py b/examples/images/vit/data.py new file mode 100644 index 0000000000000000000000000000000000000000..5361fe9a3bad098ef2fc89ff0a6a9ae15d649089 --- /dev/null +++ b/examples/images/vit/data.py @@ -0,0 +1,36 @@ +import torch +from datasets import load_dataset +from torch.utils.data import Dataset + + +class BeansDataset(Dataset): + def __init__(self, image_processor, tp_size=1, split="train"): + super().__init__() + self.image_processor = image_processor + self.ds = load_dataset("beans")[split] + self.label_names = self.ds.features["labels"].names + while len(self.label_names) % tp_size != 0: + # ensure that the number of labels is multiple of tp_size + self.label_names.append(f"pad_label_{len(self.label_names)}") + self.num_labels = len(self.label_names) + self.inputs = [] + for example in self.ds: + self.inputs.append(self.process_example(example)) + + def __len__(self): + return len(self.inputs) + + def __getitem__(self, idx): + return self.inputs[idx] + + def process_example(self, example): + input = self.image_processor(example["image"], return_tensors="pt") + input["labels"] = example["labels"] + return input + + +def beans_collator(batch): + return { + "pixel_values": torch.cat([data["pixel_values"] for data in batch], dim=0), + "labels": torch.tensor([data["labels"] for data in batch], dtype=torch.int64), + } diff --git a/examples/images/vit/requirements.txt b/examples/images/vit/requirements.txt index 1f69794ebe700eda3b4f37e17d9253978a938125..69e41c61cd67fb85adb0ca0330c91d69035d7ed3 100644 --- a/examples/images/vit/requirements.txt +++ b/examples/images/vit/requirements.txt @@ -1,8 +1,6 @@ colossalai >= 0.1.12 torch >= 1.8.1 numpy>=1.24.1 -timm>=0.6.12 -titans>=0.0.7 tqdm>=4.61.2 -transformers>=4.25.1 -nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist +transformers>=4.20.0 +datasets diff --git a/examples/images/vit/run.sh b/examples/images/vit/run.sh deleted file mode 100644 index 84fe58f11a6a7c25242eab1714fd94698509be9a..0000000000000000000000000000000000000000 --- a/examples/images/vit/run.sh +++ /dev/null @@ -1,15 +0,0 @@ -export DATA=/data/scratch/imagenet/tf_records -export OMP_NUM_THREADS=4 - -# resume -# CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ -# --nproc_per_node 4 train.py \ -# --config configs/vit_1d_tp2.py \ -# --resume_from checkpoint/epoch_10 \ -# --master_port 29598 | tee ./out 2>&1 - -# train -CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ ---nproc_per_node 4 train.py \ ---config configs/vit_1d_tp2.py \ ---master_port 29598 | tee ./out 2>&1 diff --git a/examples/images/vit/run_benchmark.sh b/examples/images/vit/run_benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..ad41a283711cf10424f5944f6d2682f95444b012 --- /dev/null +++ b/examples/images/vit/run_benchmark.sh @@ -0,0 +1,24 @@ +set -xe +pip install -r requirements.txt + +export BS=8 +export MEMCAP=0 +export GPUNUM=1 + +for BS in 8 32 +do +for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel" +do + +MODEL_PATH="google/vit-base-patch16-224" +colossalai run \ + --nproc_per_node ${GPUNUM} \ + --master_port 29505 \ + vit_benchmark.py \ + --model_name_or_path ${MODEL_PATH} \ + --mem_cap ${MEMCAP} \ + --plugin ${PLUGIN} \ + --batch_size ${BS} + +done +done diff --git a/examples/images/vit/run_demo.sh b/examples/images/vit/run_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..8eead0661454fbc6a0297231b39a009f7427c259 --- /dev/null +++ b/examples/images/vit/run_demo.sh @@ -0,0 +1,51 @@ +set -xe +pip install -r requirements.txt + +# model name or path +MODEL="google/vit-base-patch16-224" + +# path for saving model +OUTPUT_PATH="./output_model" + +# plugin(training strategy) +# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"/"hybrid_parallel" +PLUGIN="gemini" +#PLUGIN="hybrid_parallel" + +# configuration of parallel group sizes, only used when setting PLUGIN to "hybrid_parallel" +TP_SIZE=2 +PP_SIZE=2 + +# number of gpus to use +GPUNUM=4 + +# batch size per data parallel group +BS=16 + +# learning rate +LR="2e-4" + +# number of epoch +EPOCH=3 + +# weight decay +WEIGHT_DECAY=0.05 + +# ratio of warmup steps +WARMUP_RATIO=0.3 + +# run the script for demo +colossalai run \ + --nproc_per_node ${GPUNUM} \ + --master_port 29505 \ + vit_train_demo.py \ + --model_name_or_path ${MODEL} \ + --output_path ${OUTPUT_PATH} \ + --plugin ${PLUGIN} \ + --batch_size ${BS} \ + --tp_size ${TP_SIZE} \ + --pp_size ${PP_SIZE} \ + --num_epoch ${EPOCH} \ + --learning_rate ${LR} \ + --weight_decay ${WEIGHT_DECAY} \ + --warmup_ratio ${WARMUP_RATIO} diff --git a/examples/images/vit/test_ci.sh b/examples/images/vit/test_ci.sh index 41d25ee23521d85040efa510a6b792344c0a62bb..fc1f2b7a2ee074607e8a82d778451407478eec3e 100644 --- a/examples/images/vit/test_ci.sh +++ b/examples/images/vit/test_ci.sh @@ -1,9 +1,16 @@ -export OMP_NUM_THREADS=4 - +set -xe pip install -r requirements.txt -# train +BS=8 +for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel" +do + colossalai run \ ---nproc_per_node 4 train.py \ ---config configs/vit_1d_tp2_ci.py \ ---dummy_data + --nproc_per_node 4 \ + --master_port 29505 \ + vit_benchmark.py \ + --model_name_or_path "google/vit-base-patch16-224" \ + --plugin ${PLUGIN} \ + --batch_size ${BS} + +done diff --git a/examples/images/vit/test_vit.py b/examples/images/vit/test_vit.py deleted file mode 100644 index c0ae35bca87169c2813a3c796dd84064202742ac..0000000000000000000000000000000000000000 --- a/examples/images/vit/test_vit.py +++ /dev/null @@ -1,160 +0,0 @@ -import os -import random - -import numpy as np -import pytest -import torch -from torch.nn.parallel import DistributedDataParallel as DDP -from vit import get_training_components - -import colossalai -from colossalai.context import ParallelMode -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn.parallel.data_parallel import ColoDDP -from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext - - -def set_seed(seed): - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - - -def tensor_equal(A, B): - return torch.allclose(A, B, rtol=1e-3, atol=1e-1) - - -def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor): - assert tensor.ndim == shard.ndim - if tensor.shape == shard.shape: - return tensor_equal(tensor, shard) - else: - dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)) - if dims_not_eq.numel() == 1: - # 1D shard - dim = dims_not_eq.item() - world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - return tensor_equal(tensor.chunk(world_size, dim)[rank], shard) - else: - raise - - -# Only for all Linear, it's 1d_row split because Linear will be transposed when calculating. -# But for other layers, it's 1d_col split. -# Layernorm is not supported for now. -# patch_embeddings.projection has nn.Conv2d -# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182 -def init_1d_row_for_linear_weight_spec(model, world_size: int): - pg = ProcessGroup(tp_degree=world_size) - spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - with DistSpecManager.no_grad(): - for n, p in model.named_parameters(): - if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n: - p.set_process_group(pg) - p.set_tensor_spec(*spec) - - -# Similarly, it's col split for Linear but row split for others. -def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): - pg = ProcessGroup(tp_degree=world_size) - spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - with DistSpecManager.no_grad(): - for n, p in model.named_parameters(): - if ('weight' in n - or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n: - p.set_process_group(pg) - p.set_tensor_spec(*spec) - - -def check_param_equal(model, torch_model): - for p, torch_p in zip(model.parameters(), torch_model.parameters()): - assert tensor_shard_equal(torch_p, p) - - -def check_grad_equal(model, torch_model): - for p, torch_p in zip(model.parameters(), torch_model.parameters()): - if (torch_p.grad.shape == p.grad.shape): - assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True - else: - dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape)) - dim = dims_not_eq.item() - world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0) == True - - -def run_vit(init_spec_func, use_ddp): - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components() - with ColoInitContext(device=get_current_device()): - model = model_builder() - model = model.cuda() - torch_model = model_builder().cuda() - if use_ddp: - model = ColoDDP(model) - torch_model = DDP(torch_model, - device_ids=[gpc.get_global_rank()], - process_group=gpc.get_group(ParallelMode.DATA)) - for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p) - - world_size = torch.distributed.get_world_size() - init_spec_func(model, world_size) - - check_param_equal(model, torch_model) - model.train() - torch_model.train() - set_seed(gpc.get_local_rank(ParallelMode.DATA)) - - optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) - torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) - - for i, image_dict in enumerate(train_dataloader): - if use_ddp: - model.zero_grad() - else: - optimizer.zero_grad() - logits = model(image_dict['pixel_values']) - torch_logits = torch_model(image_dict['pixel_values']) - assert tensor_equal(torch_logits.logits, logits.logits) - loss = criterion(logits.logits, image_dict['label']) - torch_loss = criterion(torch_logits.logits, image_dict['label']) - if use_ddp: - model.backward(loss) - else: - loss.backward() - torch_loss.backward() - check_grad_equal(model, torch_model) - optimizer.step() - torch_optimizer.step() - check_param_equal(model, torch_model) - break - - -def run_dist(rank, world_size, port, use_ddp): - if use_ddp and world_size == 1: - return - tp_world_size = world_size // 2 if use_ddp else world_size - config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_vit(init_1d_row_for_linear_weight_spec, use_ddp) - run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@pytest.mark.parametrize('use_ddp', [False, True]) -@rerun_if_address_is_in_use() -def test_vit(world_size, use_ddp): - spawn(run_dist, world_size, use_ddp=use_ddp) - - -if __name__ == '__main__': - test_vit(1, False) diff --git a/examples/images/vit/train.py b/examples/images/vit/train.py deleted file mode 100644 index b42cf2bedc6bf738c17230c9fa40fcfc3418ba99..0000000000000000000000000000000000000000 --- a/examples/images/vit/train.py +++ /dev/null @@ -1,174 +0,0 @@ -import os - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from timm.models.vision_transformer import _create_vision_transformer -from titans.dataloader.imagenet import build_dali_imagenet -from tqdm import tqdm -from vit import DummyDataLoader - -import colossalai -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn import CrossEntropyLoss -from colossalai.nn._ops import * -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel.data_parallel import ColoDDP -from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext - - -def init_1d_row_for_linear_weight_spec(model, world_size: int): - pg = ProcessGroup(tp_degree=world_size) - spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - with DistSpecManager.no_grad(): - for n, p in model.named_parameters(): - if 'weight' in n and 'norm' not in n and 'patch_embed.proj.weight' not in n: - p.set_process_group(pg) - p.set_tensor_spec(*spec) - - -# Similarly, it's col split for Linear but row split for others. -def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): - pg = ProcessGroup(tp_degree=world_size) - spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - with DistSpecManager.no_grad(): - for n, p in model.named_parameters(): - if ('weight' in n or 'bias' in n) and 'norm' not in n and ('patch_embed.proj.weight' not in n - and 'patch_embed.proj.bias' not in n): - p.set_process_group(pg) - p.set_tensor_spec(*spec) - - -def init_spec_func(model, tp_type): - world_size = torch.distributed.get_world_size() - if tp_type == 'row': - init_1d_row_for_linear_weight_spec(model, world_size) - elif tp_type == 'col': - init_1d_col_for_linear_weight_bias_spec(model, world_size) - else: - raise NotImplemented - - -def train_imagenet(): - - parser = colossalai.get_default_parser() - parser.add_argument('--resume_from', default=False, action='store_true') - parser.add_argument('--dummy_data', default=False, action='store_true') - - args = parser.parse_args() - colossalai.launch_from_torch(config=args.config) - use_ddp = gpc.config.USE_DDP - - disable_existing_loggers() - - logger = get_dist_logger() - if hasattr(gpc.config, 'LOG_PATH'): - if gpc.get_global_rank() == 0: - log_path = gpc.config.LOG_PATH - if not os.path.exists(log_path): - os.mkdir(log_path) - logger.log_to_file(log_path) - - logger.info('Build data loader', ranks=[0]) - if not args.dummy_data: - root = os.environ['DATA'] - train_dataloader, test_dataloader = build_dali_imagenet(root, - train_batch_size=gpc.config.BATCH_SIZE, - test_batch_size=gpc.config.BATCH_SIZE) - else: - train_dataloader = DummyDataLoader(length=10, - batch_size=gpc.config.BATCH_SIZE, - category=gpc.config.NUM_CLASSES, - image_size=gpc.config.IMG_SIZE, - return_dict=False) - test_dataloader = DummyDataLoader(length=5, - batch_size=gpc.config.BATCH_SIZE, - category=gpc.config.NUM_CLASSES, - image_size=gpc.config.IMG_SIZE, - return_dict=False) - - logger.info('Build model', ranks=[0]) - - model_kwargs = dict(img_size=gpc.config.IMG_SIZE, - patch_size=gpc.config.PATCH_SIZE, - embed_dim=gpc.config.HIDDEN_SIZE, - depth=gpc.config.DEPTH, - num_heads=gpc.config.NUM_HEADS, - mlp_ratio=gpc.config.MLP_RATIO, - num_classes=gpc.config.NUM_CLASSES, - drop_rate=0.1, - attn_drop_rate=0.1, - weight_init='jax') - - with ColoInitContext(device=get_current_device()): - model = _create_vision_transformer('vit_small_patch16_224', pretrained=False, **model_kwargs) - init_spec_func(model, gpc.config.TP_TYPE) - - world_size = torch.distributed.get_world_size() - model = ColoDDP(module=model, process_group=ProcessGroup(tp_degree=world_size)) - logger.info('Build criterion, optimizer, lr_scheduler', ranks=[0]) - optimizer = HybridAdam(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) - - criterion = CrossEntropyLoss() - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, - total_steps=gpc.config.NUM_EPOCHS, - warmup_steps=gpc.config.WARMUP_EPOCHS) - - start_epoch = 0 - if args.resume_from: - load_model = torch.load(args.resume_from + '_model.pth') - start_epoch = load_model['epoch'] - model.load_state_dict(load_model['model']) - load_optim = torch.load(args.resume_from + '_optim_rank_{}.pth'.format(dist.get_rank())) - optimizer.load_state_dict(load_optim['optim']) - - for epoch in range(start_epoch, gpc.config.NUM_EPOCHS): - model.train() - for index, (x, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False): - x, y = x.cuda(), y.cuda() - output = model(x) - loss = criterion(output, y) - loss = loss / gpc.config.gradient_accumulation - if use_ddp: - model.backward(loss) - else: - loss.backward() - if (index + 1) % gpc.config.gradient_accumulation == 0: - optimizer.step() - if use_ddp: - model.zero_grad() - else: - optimizer.zero_grad() - - logger.info( - f"Finish Train Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {loss.item():.3f} lr: {optimizer.state_dict()['param_groups'][0]['lr']}", - ranks=[0]) - - model.eval() - test_loss = 0 - correct = 0 - test_sum = 0 - with torch.no_grad(): - for index, (x, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False): - x, y = x.cuda(), y.cuda() - output = model(x) - test_loss += F.cross_entropy(output, y, reduction='sum').item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(y.view_as(pred)).sum().item() - test_sum += y.size(0) - - test_loss /= test_sum - logger.info( - f"Finish Test Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {test_loss:.3f} Accuracy: [{correct}/{test_sum}]({correct/test_sum:.3f})", - ranks=[0]) - - lr_scheduler.step() - - -if __name__ == '__main__': - train_imagenet() diff --git a/examples/images/vit/vit.py b/examples/images/vit/vit.py deleted file mode 100644 index f22e8ea90cecca131c4314d7eb7fb02c74cb371c..0000000000000000000000000000000000000000 --- a/examples/images/vit/vit.py +++ /dev/null @@ -1,95 +0,0 @@ -from abc import ABC, abstractmethod - -import torch -import torch.nn as nn -from transformers import ViTConfig, ViTForImageClassification - -from colossalai.utils.cuda import get_current_device - - -class DummyDataGenerator(ABC): - - def __init__(self, length=10): - self.length = length - - @abstractmethod - def generate(self): - pass - - def __iter__(self): - self.step = 0 - return self - - def __next__(self): - if self.step < self.length: - self.step += 1 - return self.generate() - else: - raise StopIteration - - def __len__(self): - return self.length - - -class DummyDataLoader(DummyDataGenerator): - - def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True): - super().__init__(length) - self.batch_size = batch_size - self.channel = channel - self.category = category - self.image_size = image_size - self.return_dict = return_dict - - def generate(self): - image_dict = {} - image_dict['pixel_values'] = torch.rand( - self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1 - image_dict['label'] = torch.randint(self.category, (self.batch_size,), - dtype=torch.int64, - device=get_current_device()) - if not self.return_dict: - return image_dict['pixel_values'], image_dict['label'] - return image_dict - - -class ViTCVModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - image_size=224, - patch_size=16, - num_channels=3, - num_labels=8, - checkpoint=False): - super().__init__() - self.checkpoint = checkpoint - self.model = ViTForImageClassification( - ViTConfig(hidden_size=hidden_size, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - image_size=image_size, - patch_size=patch_size, - num_channels=num_channels, - num_labels=num_labels)) - if checkpoint: - self.model.gradient_checkpointing_enable() - - def forward(self, pixel_values): - return self.model(pixel_values=pixel_values) - - -def vit_base_s(checkpoint=True): - return ViTCVModel(checkpoint=checkpoint) - - -def vit_base_micro(checkpoint=True): - return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint) - - -def get_training_components(): - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..b770bc9cfb952e945176f027e32b3a19766972b5 --- /dev/null +++ b/examples/images/vit/vit_benchmark.py @@ -0,0 +1,152 @@ +import time + +import torch +import transformers +from args import parse_benchmark_args +from tqdm import tqdm +from transformers import ViTConfig, ViTForImageClassification + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam + + +def format_num(num: int, bytes=False): + """Scale bytes to its proper format, e.g. 1253656 => '1.20MB'""" + factor = 1024 if bytes else 1000 + suffix = "B" if bytes else "" + for unit in ["", " K", " M", " G", " T", " P"]: + if num < factor: + return f"{num:.2f}{unit}{suffix}" + num /= factor + + +def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224): + pixel_values = torch.randn( + batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float + ) + labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64) + return dict(pixel_values=pixel_values, labels=labels) + + +def colo_memory_cap(size_in_GB): + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + + cuda_capacity = colo_device_memory_capacity(get_current_device()) + if size_in_GB * (1024**3) < cuda_capacity: + colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) + print(f"Limiting GPU memory usage to {size_in_GB} GB") + + +def main(): + args = parse_benchmark_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + world_size = coordinator.world_size + + # Manage loggers + disable_existing_loggers() + logger = get_dist_logger() + if coordinator.is_master(): + transformers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + + # Whether to set limit on memory capacity + if args.mem_cap > 0: + colo_memory_cap(args.mem_cap) + + # Build ViT model + config = ViTConfig.from_pretrained(args.model_name_or_path) + model = ViTForImageClassification(config) + logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + + # Enable gradient checkpointing + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + + # Set plugin + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == "hybrid_parallel": + plugin = HybridParallelPlugin( + tp_size=2, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision="fp16", + initial_scale=1, + ) + logger.info(f"Set plugin as {args.plugin}", ranks=[0]) + + # Set optimizer + optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size)) + + # Set criterion (loss function) + def criterion(outputs, inputs): + return outputs.loss + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion) + + # Start training. + logger.info(f"Start testing", ranks=[0]) + + torch.cuda.synchronize() + model.train() + start_time = time.time() + + with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar: + for _ in pbar: + optimizer.zero_grad() + batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224) + + if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: + # run pipeline forward backward + batch = iter([batch]) + outputs = booster.execute_pipeline( + batch, model, criterion, optimizer, return_loss=True, return_outputs=True + ) + else: + outputs = model(**batch) + loss = criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + + optimizer.step() + + torch.cuda.synchronize() + + # Compute Statistics + end_time = time.time() + throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time)) + max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True) + + logger.info( + f"Testing finished, " + f"batch size per gpu: {args.batch_size}, " + f"plugin: {args.plugin}, " + f"throughput: {throughput}, " + f"maximum memory usage per gpu: {max_mem}.", + ranks=[0], + ) + + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..81009b3707b69eb2d9a99e0e9b1532d491091cb8 --- /dev/null +++ b/examples/images/vit/vit_train_demo.py @@ -0,0 +1,242 @@ +from typing import Any, Callable, Iterator + +import torch +import torch.distributed as dist +import torch.nn as nn +import transformers +from args import parse_demo_args +from data import BeansDataset, beans_collator +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +def run_forward_backward( + model: nn.Module, + optimizer: Optimizer, + criterion: Callable[[Any, Any], torch.Tensor], + data_iter: Iterator, + booster: Booster, +): + if optimizer is not None: + optimizer.zero_grad() + if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: + # run pipeline forward backward when enabling pp in hybrid parallel plugin + output_dict = booster.execute_pipeline( + data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True + ) + loss, outputs = output_dict["loss"], output_dict["outputs"] + else: + batch = next(data_iter) + batch = move_to_cuda(batch, torch.cuda.current_device()) + outputs = model(**batch) + loss = criterion(outputs, None) + if optimizer is not None: + booster.backward(loss, optimizer) + + return loss, outputs + + +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable[[Any, Any], torch.Tensor], + lr_scheduler: LRScheduler, + dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + torch.cuda.synchronize() + + num_steps = len(dataloader) + data_iter = iter(dataloader) + enable_pbar = coordinator.is_master() + if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: + # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar + tp_rank = dist.get_rank(booster.plugin.tp_group) + dp_rank = dist.get_rank(booster.plugin.dp_group) + enable_pbar = tp_rank == 0 and dp_rank == 0 and booster.plugin.stage_manager.is_last_stage() + + model.train() + + with tqdm(range(num_steps), desc=f"Epoch [{epoch + 1}]", disable=not enable_pbar) as pbar: + for _ in pbar: + loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster) + optimizer.step() + lr_scheduler.step() + + # Print batch loss + if enable_pbar: + pbar.set_postfix({"loss": loss.item()}) + + +@torch.no_grad() +def evaluate_model( + epoch: int, + model: nn.Module, + criterion: Callable[[Any, Any], torch.Tensor], + eval_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + torch.cuda.synchronize() + model.eval() + accum_loss = torch.zeros(1, device=torch.cuda.current_device()) + total_num = torch.zeros(1, device=torch.cuda.current_device()) + accum_correct = torch.zeros(1, device=torch.cuda.current_device()) + + for batch in eval_dataloader: + batch = move_to_cuda(batch, torch.cuda.current_device()) + loss, outputs = run_forward_backward(model, None, criterion, iter([batch]), booster) + + to_accum = True + if isinstance(booster.plugin, HybridParallelPlugin): + # when using hybrid parallel, loss is only collected from last stage of pipeline with tp_rank == 0 + to_accum = to_accum and (dist.get_rank(booster.plugin.tp_group) == 0) + if booster.plugin.pp_size > 1: + to_accum = to_accum and booster.plugin.stage_manager.is_last_stage() + + if to_accum: + accum_loss += loss / len(eval_dataloader) + logits = outputs["logits"] + preds = torch.argmax(logits, dim=1) + + labels = batch["labels"] + total_num += batch["labels"].shape[0] + accum_correct += torch.sum(preds == labels) + + dist.all_reduce(accum_loss) + dist.all_reduce(total_num) + dist.all_reduce(accum_correct) + avg_loss = "{:.4f}".format(accum_loss.item()) + accuracy = "{:.4f}".format(accum_correct.item() / total_num.item()) + if coordinator.is_master(): + print( + f"Evaluation result for epoch {epoch + 1}: \ + average_loss={avg_loss}, \ + accuracy={accuracy}." + ) + + +def main(): + args = parse_demo_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + world_size = coordinator.world_size + + # Manage loggers + disable_existing_loggers() + logger = get_dist_logger() + if coordinator.is_master(): + transformers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + + # Reset tp_size and pp_size to 1 if not using hybrid parallel. + if args.plugin != "hybrid_parallel": + args.tp_size = 1 + args.pp_size = 1 + + # Prepare Dataset + image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path) + train_dataset = BeansDataset(image_processor, args.tp_size, split="train") + eval_dataset = BeansDataset(image_processor, args.tp_size, split="validation") + num_labels = train_dataset.num_labels + + # Load pretrained ViT model + config = ViTConfig.from_pretrained(args.model_name_or_path) + config.num_labels = num_labels + config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)} + config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)} + model = ViTForImageClassification.from_pretrained( + args.model_name_or_path, config=config, ignore_mismatched_sizes=True + ) + logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + + # Enable gradient checkpointing + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + + # Set plugin + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == "hybrid_parallel": + plugin = HybridParallelPlugin( + tp_size=args.tp_size, + pp_size=args.pp_size, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision="fp16", + initial_scale=1, + ) + else: + raise ValueError(f"Plugin with name {args.plugin} is not supported!") + logger.info(f"Set plugin as {args.plugin}", ranks=[0]) + + # Prepare dataloader + train_dataloader = plugin.prepare_dataloader( + train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=beans_collator + ) + eval_dataloader = plugin.prepare_dataloader( + eval_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=beans_collator + ) + + # Set optimizer + optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) + + # Set criterion (loss function) + def criterion(outputs, inputs): + return outputs.loss + + # Set lr scheduler + total_steps = len(train_dataloader) * args.num_epoch + num_warmup_steps = int(args.warmup_ratio * total_steps) + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, total_steps=(len(train_dataloader) * args.num_epoch), warmup_steps=num_warmup_steps + ) + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost( + model=model, optimizer=optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler + ) + + # Finetuning + logger.info(f"Start finetuning", ranks=[0]) + for epoch in range(args.num_epoch): + train_epoch(epoch, model, optimizer, criterion, lr_scheduler, train_dataloader, booster, coordinator) + evaluate_model(epoch, model, criterion, eval_dataloader, booster, coordinator) + logger.info(f"Finish finetuning", ranks=[0]) + + # Save the finetuned model + booster.save_model(model, args.output_path, shard=True) + logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py new file mode 100644 index 0000000000000000000000000000000000000000..738f43dc06199e50352b031bb0f015ff2da9c63a --- /dev/null +++ b/examples/inference/bench_bloom.py @@ -0,0 +1,100 @@ +import argparse +import os +import time + +import torch +from transformers import BloomForCausalLM, BloomTokenizerFast + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 # float16 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) + + +def bench_bloom(args): + model_path = args.path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + + tokenizer = BloomTokenizerFast.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) + model = model.half() + + # init TPInferEngine and shard the original model + # To benchmark torch original, comment out the line of optimizing model + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + + # prepare data for generation + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + input_tokens = { + "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), + "attention_mask": torch.ones((max_batch_size, max_input_len)), + } + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + print(f" input_tokens[{t}].shape: {input_tokens[t].shape}") + + iters = 10 + times = [] + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s") + times.append((end - start) / (out_len - max_input_len)) + + print_perf_stats(times, model.config, max_batch_size) + + +def check_bloom(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + bench_bloom(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom(args): + spawn(check_bloom, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + + args = parser.parse_args() + + test_bloom(args) diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..90d49f6a264a2deceade4e80268da861694617f6 --- /dev/null +++ b/examples/inference/bench_llama.py @@ -0,0 +1,103 @@ +import argparse +import os +import time + +import torch +from torch.profiler import ProfilerActivity, profile, record_function +from transformers import LlamaForCausalLM, LlamaTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + + +def run_llama_test(args): + llama_model_path = args.path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + + tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) + tokenizer.pad_token_id = tokenizer.unk_token_id + model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) + model = model.half() + model_config = model.config + + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), + } + + iters = 10 + times = [] + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + times.append((end - start) / (out_len - max_input_len)) + + print("outputs, ", len(outputs)) + print_perf_stats(times, model_config, max_batch_size) + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("model_inference"): + torch.cuda.synchronize() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + +def check_llama(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_llama_test(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(args): + spawn(check_llama, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + + args = parser.parse_args() + + test_llama(args) diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py new file mode 100644 index 0000000000000000000000000000000000000000..43e118cc0aa55146a1e5d209668e0ed26bb37b74 --- /dev/null +++ b/examples/inference/gptq_bloom.py @@ -0,0 +1,123 @@ +import argparse +import logging +import os +import time + +import torch +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +from auto_gptq.nn_modules.qlinear import GeneralQuantLinear +from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 # float16 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) + + +def bench_bloom(args): + + pretrained_model_dir = args.path + quantized_model_dir = args.quantized_path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + + tokenizer = BloomTokenizerFast.from_pretrained(pretrained_model_dir) + tokenizer.pad_token = tokenizer.eos_token + + # load quantized model to the first GPU + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, + device=torch.cuda.current_device(), + inject_fused_attention=False) + + model = model.half() + + model_config = model.config + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), + "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + } + + # init TPInferEngine and shard the original model + # To benchmark torch original, comment out the line of optimizing model + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, + inference_only=True, + inference_gptq=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + + # prepare data for generation + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + input_tokens = { + "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), + "attention_mask": torch.ones((max_batch_size, max_input_len)) + } + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + # print(f" input_tokens[{t}].shape: {input_tokens[t].shape}") + + iters = 10 + times = [] + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s") + times.append((end - start) / (out_len - max_input_len)) + + print_perf_stats(times, model_config, max_batch_size) + + +def check_bloom(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + bench_bloom(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom(args): + spawn(check_bloom, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) + parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') + parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') + parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') + parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + + args = parser.parse_args() + + test_bloom(args) diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..1bdee448c7428df9e3f7fb32e2df5db2f8384826 --- /dev/null +++ b/examples/inference/gptq_llama.py @@ -0,0 +1,107 @@ +import argparse +import os +import time + +import torch +from auto_gptq import AutoGPTQForCausalLM +from transformers import LlamaTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) + + +def run_llama_test(args): + pretrained_model_dir = args.path + quantized_model_dir = args.quantized_path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + + tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) + tokenizer.pad_token_id = tokenizer.eos_token_id + + # load quantized model to the first GPU + model = AutoGPTQForCausalLM.from_quantized( + quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False + ) + + init_to_get_rotary(model.model.model, base=10000) + + model_config = model.config + shard_config = ShardConfig( + enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True + ) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), + } + + iters = 10 + times = [] + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s") + times.append((end - start) / (out_len - max_input_len)) + + print_perf_stats(times, model_config, max_batch_size) + + +def check_llama(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_llama_test(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(args): + spawn(check_llama, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + + args = parser.parse_args() + + test_llama(args) diff --git a/examples/inference/serving/ray_serve/Colossal_Inference_rayserve.py b/examples/inference/serving/ray_serve/Colossal_Inference_rayserve.py new file mode 100644 index 0000000000000000000000000000000000000000..51d520ebbcf65011f2dea49fa3050d0c61c1e181 --- /dev/null +++ b/examples/inference/serving/ray_serve/Colossal_Inference_rayserve.py @@ -0,0 +1,151 @@ +import logging +import os +from typing import Any, List, Union + +import ray +import ray.util.collective as collective +import starlette +import torch +from pydantic import BaseModel +from ray import serve +from ray.serve import Application +from transformers import AutoModelForCausalLM, AutoTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import free_port + +ray_serve_logger = logging.getLogger("ray.serve") + + +class GenConfigArgs(BaseModel): + """Config for generation""" + + path: str + tp_size: int = 2 + max_batch_size: int = 4 + max_input_len: int = 128 + max_output_len: int = 32 + + +def log_cuda_info(scope_name: str): + ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}") + ray_serve_logger.info( + f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}" + ) + if torch.cuda.is_available(): + ray_serve_logger.info( + f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}" + ) + else: + ray_serve_logger.info(f" {scope_name}: cuda is not available!") + + +@ray.remote(num_gpus=1) +class Worker: + def __init__(self, model_path: str, tp_size: int, max_batch_size: int, max_input_len: int, max_output_len: int): + log_cuda_info("Worker.init") + self.tp_size = tp_size + self.model_path = model_path + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + + def setup(self, world_size, rank, port): + # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully + collective.init_collective_group(world_size, rank, "nccl", "default") + # initialize and set distributed environment + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..") + log_cuda_info("Worker.setup") + + # Load model + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.model = AutoModelForCausalLM.from_pretrained( + self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 + ) + + shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) + self.infer_engine = TPInferEngine( + self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len + ) + self.generate_kwargs = dict(max_new_tokens=self.max_output_len, do_sample=False) + + return True + + def generate(self, text: Union[str, List[str]]) -> str: + input_tokens = self.tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True) + ray_serve_logger.info(f"text: {text},\ninput_tokens: {input_tokens}") + + model_output = self.infer_engine.generate(input_tokens, **self.generate_kwargs) + ray_serve_logger.info(f"model_output.shape: {model_output.shape}") + + text_output = [] + for i in range(len(model_output)): + text_output.append(self.tokenizer.decode(model_output[i])) + ray_serve_logger.info(f"output: {text_output}") + + return text_output + + +@serve.deployment( + ray_actor_options={"num_cpus": 1, "num_gpus": 0}, + max_concurrent_queries=5, + autoscaling_config={ + "target_num_ongoing_requests_per_replica": 1, + "min_replicas": 1, + "initial_replicas": 1, + "max_replicas": 1, + }, +) +class Driver: + def __init__(self, config: GenConfigArgs): + log_cuda_info("Driver:init") + model_path = config.path + tp_size = config.tp_size + + self.num_workers = tp_size + self.workers = [] + init_rets = [] + + # Just grab a free port on localhost + # NOTE workers in this communication group listen to the same port + available_port = free_port() + + for i in range(self.num_workers): + worker_name = "worker_idx_{}".format(i) + w = Worker.options(name=worker_name).remote( + model_path, self.num_workers, config.max_batch_size, config.max_input_len, config.max_output_len + ) + self.workers.append(w) + init_rets.append(w.setup.remote(self.num_workers, i, available_port)) + _options = { + "group_name": "default_driver", + "world_size": self.num_workers, + "ranks": [i for i in range(self.num_workers)], + "backend": "nccl", + } + collective.create_collective_group(self.workers, **_options) + _ = ray.get(init_rets) + + # set batch wait delay in seconds and maximum number of sequences in a batch + @serve.batch(batch_wait_timeout_s=0.8, max_batch_size=4) + async def batch_generate(self, requests: List[str]): + ray_serve_logger.info(f"Driver.batch_generate: requests length: {len(requests)}\n requests: {requests}") + results = ray.get([w.generate.remote(requests) for w in self.workers]) + text_res = results[0] # get any one of the copies + return text_res + + async def __call__(self, request: starlette.requests.Request) -> Any: + return await self.batch_generate(request.query_params["text"]) + + +def app(args: GenConfigArgs) -> Application: + print(args) + if args.path is None or not os.path.exists(args.path): + raise ValueError("Model path not provided or invalid path!") + + return Driver.options(name="Colossal-Inference-Driver").bind(config=args) diff --git a/examples/inference/serving/ray_serve/README.md b/examples/inference/serving/ray_serve/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1d408238760b883f26a4f0aa700273750f79c6f7 --- /dev/null +++ b/examples/inference/serving/ray_serve/README.md @@ -0,0 +1,86 @@ +# Colossal-Inference with Ray Serve + +This example is used for demonstrating and testing the deployment of Colossal Inference from `colossalai.inference` with [Ray Serve](https://docs.ray.io/en/latest/serve/index.html). It imports inference modules from colossalai and is based on https://github.com/hpcaitech/ColossalAI/tree/a22706337a57dd1c98b95739dd09d98bd55947a0. + +Single-gpu inference as well as multiple-gpu inference (i.e. tensor parallel) serving are supported. + +## Installation + +### Conda Environment +```bash +# create a new conda env with python 3.8 +conda create -n ray_test python=3.8.18 + +# use torch1.13+cuda11.6 +pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 + +# install ray from wheels +pip install -U "ray[default,serve]" + +# install cuda toolkit (e.g. nvcc, etc) +conda install -c "nvidia/label/cuda-11.6.2" cuda-toolkit + +# install cuDNN, cuTENSOR, and NCCL +conda install -c conda-forge cupy cudnn cutensor nccl cuda-version=11.6 + +# install colossalai with PyTorch extensions +cd +CUDA_EXT=1 pip install -e . + +# install other dependencies +pip install triton==2.0.0.dev20221202 +pip install transformers +``` + +## Launch Ray Serve and run the app +### Method #1. CLI command + +Under the current directory, we could launch the app by the following command: +```bash +RAY_DEDUP_LOGS=0 serve run Colossal_Inference_rayserve:app path="PATH_TO_YOUR_MODEL_DIR" +``` + +By default, Ray deduplicates logs across cluster. Here we set `RAY_DEDUP_LOGS=0` to disable log deduplication, enabling each actor to log information in CLI. `serve run` runs an application from the specified import path. The formats should be `:`. + +Then we could send requests by running python script in another window: +```bash +python send_request.py +``` + +### Method #2. Run inside script + +We could also launch ray serve and run the app inside a single script by making some modifications: +To avoid ray handler from raising error in serializing pydantic objects, we'll replace the config class from `class GenConfigArgs(BaseModel)` to +```python +from dataclasses import dataclass +@dataclass +class GenConfigArgs: + # attributes remain unchanged +``` +Comment out the app builder +```python +# def app(args: GenConfigArgs) -> Application: +# ... +# return Driver.options(name="Colossal-Inference-Driver").bind(config=args) +``` +And attach the following lines to the end of the file, +```python +from ray.serve.handle import DeploymentHandle, DeploymentResponse + +app = Driver.bind(config=GenConfigArgs(path="")) +handle: DeploymentHandle = serve.run(app).options(use_new_handle_api=True) +response: DeploymentResponse = handle.batch_generate.remote(requests="Introduce some landmarks in Beijing") +print(response.result()) +``` +Then we could run the script +```python +python Colossal_Inference_rayserve.py +``` + +### Terminate Ray Serve +Ray serve and the application would terminate automatically as you choose the second method to run any job in the script. If you choose the first method (serve run), you might want to apply `ctrl+c` to shut down the application, or use `serve shutdown` to shut down serve and deletes all applications on the ray cluster. + +To make sure all the active Ray processes are killed, run +```bash +ray stop +``` diff --git a/examples/inference/serving/ray_serve/send_request.py b/examples/inference/serving/ray_serve/send_request.py new file mode 100644 index 0000000000000000000000000000000000000000..3bab1764a1a59b191e3076afe3c464748049d69f --- /dev/null +++ b/examples/inference/serving/ray_serve/send_request.py @@ -0,0 +1,15 @@ +import ray +import requests + + +@ray.remote +def send_query(text): + resp = requests.get("http://localhost:8000/?text={}".format(text)) + return resp.text + + +test_sentence = "Introduce some landmarks in Beijing" + +result = ray.get(send_query.remote(test_sentence)) +print("Result returned:") +print(result) diff --git a/examples/inference/serving/ray_serve/send_requests.py b/examples/inference/serving/ray_serve/send_requests.py new file mode 100644 index 0000000000000000000000000000000000000000..bee3b6b68c85e1e7b20756031ad41f2f4c471f2f --- /dev/null +++ b/examples/inference/serving/ray_serve/send_requests.py @@ -0,0 +1,27 @@ +import ray +import requests + + +@ray.remote +def send_query(text): + resp = requests.get("http://localhost:8000/?text={}".format(text)) + return resp.text + + +test_sentences = [ + "Introduce some landmarks in Beijing", + "What is the weather today", + "Coding requires practice and patience", + "Rainy days inspire cozy reading", + "Laughter is contagious and heartwarming", + "Hiking mountains builds strength and resilience", + "Family bonds grow stronger with time", + "Science unlocks mysteries of the universe", + "Music soothes the soul and ignites passion", + "Artistic expression knows no boundaries", +] + +results = ray.get([send_query.remote(text) for text in test_sentences]) +print("Result returned:") +for res in results: + print(res) diff --git a/examples/inference/serving/test_ci.sh b/examples/inference/serving/test_ci.sh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/inference/serving/torch_serve/Colossal_Inference_Handler.py b/examples/inference/serving/torch_serve/Colossal_Inference_Handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d30501efea5f4861b4da1bc3bb3cba8fd233ba --- /dev/null +++ b/examples/inference/serving/torch_serve/Colossal_Inference_Handler.py @@ -0,0 +1,193 @@ +import logging +import os +import zipfile +from abc import ABC + +import torch +import transformers +from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM +from ts.torch_handler.base_handler import BaseHandler + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import free_port + +logger = logging.getLogger(__name__) +logger.info("Transformers version %s", transformers.__version__) +logger.info("ColossalAI version %s", colossalai.__version__) + + +class ColossalInferenceHandler(BaseHandler, ABC): + """ + Transformers handler class for testing + """ + + def __init__(self): + super(ColossalInferenceHandler, self).__init__() + self.infer_engine = None + self.max_batch_size = None + self.max_input_len = None + self.max_output_len = None + self.tokenizer = None + self.initialized = False + + def initialize(self, ctx): + """Expected behaviour: the sharded Bloom/Llama model is loaded. + + Args: + ctx (context): It is a JSON Object containing information + pertaining to the model artefacts parameters. + """ + if ctx is not None or not hasattr(ctx, "model_yaml_config"): + logger.error("Context ctx and model-config are not appropriately passed in.") + + self.manifest = ctx.manifest + gpu_id = ctx.system_properties.get("gpu_id", -1) + model_dir = ctx.system_properties.get("model_dir") + + # Inference configs are collected together in model yaml config for handler use + inference_config = ctx.model_yaml_config["handler"] + self.inference_config = inference_config + logger.info(self.inference_config) + + self.tp_size = self.inference_config.get("tp_size", 1) + self.max_batch_size = self.inference_config.get("max_batch_size", 4) + self.max_input_len = self.inference_config.get("max_input_len", 1024) + self.max_output_len = self.inference_config.get("max_output_len", 128) + + self.device = torch.device("cuda:" + str(gpu_id) if torch.cuda.is_available() and gpu_id >= 0 else "cpu") + logger.info(f"Device set to {self.device}") + logger.info(f"torch.cuda.device_count() {torch.cuda.device_count()}") + + # Unpacking from model_dir + model_dir_path = os.path.join(model_dir, "model") + with zipfile.ZipFile(model_dir + "/model.zip", "r") as zip_ref: + zip_ref.extractall(model_dir_path) + logger.info(f"Loading {self.inference_config['model_type']} pretrain model and tokenizer") + if self.inference_config["model_type"] == "bloom": + self.model = BloomForCausalLM.from_pretrained( + model_dir_path, + ) + self.tokenizer = BloomTokenizerFast.from_pretrained(model_dir_path, return_tensors="pt") + elif self.inference_config["model_type"] == "llama": + self.model = LlamaForCausalLM.from_pretrained( + model_dir_path, + ) + self.tokenizer = AutoTokenizer.from_pretrained(model_dir_path, return_tensors="pt") + else: + logger.warning(f"Model type {self.inference_config['model_type']} not supported yet.") + + logger.info("Transformer model from path %s loaded successfully", model_dir) + + # NOTE world_size, rank, host, port here are used to launch colossalai dist environment + # This world_size is different from the world size of TorchServe + world_size = int(os.getenv("WORLD_SIZE", self.tp_size)) + assert world_size == 1, "Colossal-Inference with tensor parallel is not supported on TorchServe for now" + rank = int(os.getenv("RANK", gpu_id)) + local_rank = int(os.getenv("LOCAL_RANK", gpu_id)) + host = os.getenv("MASTER_ADDR", "localhost") + port = os.getenv("MASTER_PORT", free_port()) # use a random free port + + logger.info( + f" world_size {world_size}" f" local_rank {local_rank}" f" rank {rank}" f" host {host}" f" port {port}" + ) + + torch.cuda.set_device(self.device) + self.model.half() + self.model.cuda() + self.model.eval() + + colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl") + logger.info("Initializing TPInferEngine ...") + shard_config = ShardConfig(enable_tensor_parallelism=True if self.tp_size > 1 else False, inference_only=True) + self.infer_engine = TPInferEngine( + self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len + ) + logger.info("TPInferEngine initialized successfully") + + self.model = self.infer_engine.model + self.initialized = True + + def preprocess(self, requests): + """Basic text preprocessing, based on the user's chocie of application mode. + Args: + requests: The Input data in the form of text is passed on to the preprocess + function. + Returns: + list : The preprocess function returns a list of Tensor for the size of the word tokens. + """ + logger.info("Pre-processing requests") + input_ids_batch = None + attention_mask_batch = None + for idx, data in enumerate(requests): + input_text = data.get("data") + if input_text is None: + input_text = data.get("body") + if isinstance(input_text, (bytes, bytearray)): + input_text = input_text.decode("utf-8") + + logger.info("Received text: '%s'", input_text) + + inputs = self.tokenizer.encode_plus( + input_text, + max_length=self.max_input_len, + padding=True, + add_special_tokens=True, + return_tensors="pt", + truncation=True, + ) + + input_ids = inputs["input_ids"].to(self.device) + attention_mask = inputs["attention_mask"].to(self.device) + # making a batch out of the recieved requests + # attention masks are passed for cases where input tokens are padded. + if input_ids.shape is not None: + if input_ids_batch is None: + input_ids_batch = input_ids + attention_mask_batch = attention_mask + else: + input_ids_batch = torch.cat((input_ids_batch, input_ids), 0) + attention_mask_batch = torch.cat((attention_mask_batch, attention_mask), 0) + return (input_ids_batch, attention_mask_batch) + + def inference(self, input_batch): + """Predict the class (or classes) of the received text using the + serialized transformers checkpoint. + Args: + input_batch (list): List of Text Tensors from the pre-process function is passed here + Returns: + list : It returns a list of the predicted value for the input text + """ + input_ids_batch, attention_mask_batch = input_batch + inferences = [] + + do_sample = self.inference_config.get("do_sample", True) + top_p = self.inference_config.get("top_p", 0.95 if do_sample else 1.0) + top_k = self.inference_config.get("top_k", 60 if do_sample else 50) + input_ids_batch = input_ids_batch.to(self.device) + outputs = self.infer_engine.generate( + dict(input_ids=input_ids_batch, attention_mask=attention_mask_batch), + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + ) + + for i, _ in enumerate(outputs): + inferences.append(self.tokenizer.decode(outputs[i], skip_special_tokens=True)) + + # For testing only + logger.info( + f"Generated text: {inferences}", + ) + + return inferences + + def postprocess(self, inference_output): + """Post Process Function converts the predicted response into Torchserve readable format. + Args: + inference_output (list): It contains the predicted response of the input text. + Returns: + (list): Returns a list of the Predictions and Explanations. + """ + return inference_output diff --git a/examples/inference/serving/torch_serve/README.md b/examples/inference/serving/torch_serve/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6bd145bc30ae7ea91a7ccea58ecaffbc1741d234 --- /dev/null +++ b/examples/inference/serving/torch_serve/README.md @@ -0,0 +1,109 @@ +# Colossal-Inference with TorchServe + +## Overview + +This demo is used for testing and demonstrating the usage of Colossal Inference from `colossalai.inference` with deployment with TorchServe. It imports inference modules from colossalai and is based on +https://github.com/hpcaitech/ColossalAI/tree/3e05c07bb8921f2a8f9736b6f6673d4e9f1697d0. For now, single-gpu inference serving is supported. + +## Environment for testing +### Option #1: Use Conda Env +Records to create a conda env to test locally as follows. We might want to use docker or configure env on cloud platform later. + +*NOTE*: It requires the installation of jdk and the set of `JAVA_HOME`. We recommend to install open-jdk-17 (Please refer to https://openjdk.org/projects/jdk/17/) + +```bash +# use python 3.8 or 3.9 +conda create -n infer python=3.9 + +# use torch 1.13+cuda11.6 for inference +pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 + +# conda cuda toolkit (e.g. nvcc, etc) +conda install -c "nvidia/label/cuda-11.6.2" cuda-toolkit + +# install colossalai with PyTorch extensions +cd +pip install -r requirements/requirements.txt +pip install -r requirements/requirements-test.txt +CUDA_EXT=1 pip install -e . + +# install torchserve +cd +python ./ts_scripts/install_dependencies.py --cuda=cu116 +pip install torchserve torch-model-archiver torch-workflow-archiver +``` + +### Option #2: Use Docker +To use the stable diffusion Docker image, you can build using the provided the [Dockerfile](./docker/Dockerfile). + +```bash +# build from dockerfile +cd ColossalAI/examples/inference/serving/torch_serve/docker +docker build -t hpcaitech/colossal-infer-ts:0.2.0 . +``` + +Once you have the image ready, you can launch the image with the following command + +```bash +cd ColossalAI/examples/inference/serving/torch_serve + +# run the docker container +docker run --rm \ + -it --gpus all \ + --name \ + -v :/data/scratch \ + -w \ + hpcaitech/colossal-infer-ts:0.2.0 \ + /bin/bash +``` + +## Steps to deploy a model + +### 1.download/prepare a model +We will download a bloom model, and then zip the downloaded model. You could download the model from [HuggingFace](https://huggingface.co/models) manually, or you might want to refer to this script [download_model.py](https://github.com/pytorch/serve/blob/c3ca2599b4d36d2b61302064b02eab1b65e1908d/examples/large_models/utils/Download_model.py) provided by pytorch-serve team to help you download a snapshot of the model. + +```bash +# download snapshots +cd /examples/large_models/utils/ +huggingface-cli login +python download_model.py --model_name bigscience/bloom-560m -o + +# zip the model repo +cd /models--bigscience--bloom-560m/snapshots/ +zip -r //model.zip * +``` + +> **_NOTE:_** The torch archiver and server will use `/tmp/` folder. Depending on the limit of disk quota, using torch-model-archiver might cause OSError "Disk quota exceeded". To prevent the OSError, set tmp dir environment variable as follows: +`export TMPDIR=/tmp` and `export TEMP=/tmp`, +or use relatively small models (as we did) for local testing. + +### 2. Archive the model +With torch archiver, we will pack the model file (.zip) as well as handler file (.py) together into a .mar file. And then in serving process these files will be unpacked by TorchServe. Revelant model configs and inference configs can be set in `model-config.yaml`. +```bash +cd ./ColossalAI/examples/inference/serving/torch_serve +# create a folder under the current directory to store the packed model created by torch archiver +mkdir model_store +torch-model-archiver --model-name bloom --version 0.1 --handler Colossal_Inference_Handler.py --config-file model-config.yaml --extra-files /model.zip --export-path ./model_store/ +``` + +### 3. Launch serving + +Modify `load_models` in config.properties to select the model(s) stored in directory to be deployed. By default we use `load_models=all` to load and deploy all the models (.mar) we have. + +```bash +torchserve --start --ncs --ts-config config.properties +``` +We could set inference, management, and metrics addresses and other TorchServe settings in `config.properties`. + +TorchServe will create a folder `logs/` under the current directory to store ts, model, and metrics logs. + +### 4. Run inference + +```bash +# check inference status +curl http://0.0.0.0:8084/ping + +curl -X POST http://localhost:8084/predictions/bloom -T sample_text.txt +``` + +To stop TorchServe, run `torchserve --stop` diff --git a/examples/inference/serving/torch_serve/config.properties b/examples/inference/serving/torch_serve/config.properties new file mode 100644 index 0000000000000000000000000000000000000000..7f2b882a11a71923a75db894c46451dcc7c7172d --- /dev/null +++ b/examples/inference/serving/torch_serve/config.properties @@ -0,0 +1,10 @@ +inference_address=http://0.0.0.0:8084 +management_address=http://0.0.0.0:8085 +metrics_address=http://0.0.0.0:8086 +enable_envvars_config=true +install_py_dep_per_model=true +number_of_gpu=1 +load_models=all +max_response_size=655350000 +default_response_timeout=6000 +model_store=./model_store diff --git a/examples/inference/serving/torch_serve/docker/Dockerfile b/examples/inference/serving/torch_serve/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..6d780a84747f9f798c4fe695a98cd247b7006564 --- /dev/null +++ b/examples/inference/serving/torch_serve/docker/Dockerfile @@ -0,0 +1,57 @@ +FROM hpcaitech/pytorch-cuda:1.13.0-11.6.0 + +# enable passwordless ssh +RUN mkdir ~/.ssh && \ + printf "Host * \n ForwardAgent yes\nHost *\n StrictHostKeyChecking no" > ~/.ssh/config && \ + ssh-keygen -t rsa -N "" -f ~/.ssh/id_rsa && \ + cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys + +# install curl +RUN apt-get update && \ + apt-get -y install curl && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Download and extract OpenJDK 17 +ENV JAVA_HOME /opt/openjdk-17 +RUN apt-get update && \ + apt-get install -y wget && \ + wget -q https://download.java.net/openjdk/jdk17/ri/openjdk-17+35_linux-x64_bin.tar.gz -O /tmp/openjdk.tar.gz && \ + mkdir -p $JAVA_HOME && \ + tar xzf /tmp/openjdk.tar.gz -C $JAVA_HOME --strip-components=1 && \ + rm /tmp/openjdk.tar.gz && \ + apt-get purge -y --auto-remove wget && \ + rm -rf /var/lib/apt/lists/* + +ENV PATH $JAVA_HOME/bin:$PATH +RUN export JAVA_HOME +RUN java -version + +# install ninja +RUN apt-get update && \ + apt-get install -y --no-install-recommends ninja-build && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# install colossalai +ARG VERSION=main +RUN git clone -b ${VERSION} https://github.com/hpcaitech/ColossalAI.git && \ + cd ./ColossalAI && \ + git checkout 3e05c07bb8921f2a8f9736b6f6673d4e9f1697d0 && \ + CUDA_EXT=1 pip install -v --no-cache-dir . + +# install titans +RUN pip install --no-cache-dir titans + +# install transformers +RUN pip install --no-cache-dir transformers + +# install triton +RUN pip install --no-cache-dir triton==2.0.0.dev20221202 + +# install torchserve +ARG VERSION=master +RUN git clone -b ${VERSION} https://github.com/pytorch/serve.git && \ + cd ./serve && \ + python ./ts_scripts/install_dependencies.py --cuda=cu116 && \ + pip install torchserve torch-model-archiver torch-workflow-archiver diff --git a/examples/inference/serving/torch_serve/model-config.yaml b/examples/inference/serving/torch_serve/model-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f86d424beee7428bc19be9b328bdd109cf71dc0 --- /dev/null +++ b/examples/inference/serving/torch_serve/model-config.yaml @@ -0,0 +1,16 @@ +# TS frontend parameters settings +minWorkers: 1 # minimum number of workers of a model +maxWorkers: 1 # maximum number of workers of a model +batchSize: 8 # batch size of a model +maxBatchDelay: 100 # maximum delay of a batch (ms) +responseTimeout: 120 # timeout of a specific model's response (*in sec) +deviceType: "gpu" +# deviceIds: [0, 1] # seting CUDA_VISIBLE_DEVICES + +handler: + mode: "text_generation" + model_type: "bloom" + tp_size: 1 + max_batch_size: 8 + max_input_len: 1024 + max_output_len: 128 diff --git a/examples/inference/serving/torch_serve/sample_text.txt b/examples/inference/serving/torch_serve/sample_text.txt new file mode 100644 index 0000000000000000000000000000000000000000..18d8729f21b4d14dce06764c5b9751fb2c81a6c0 --- /dev/null +++ b/examples/inference/serving/torch_serve/sample_text.txt @@ -0,0 +1 @@ +Introduce some landmarks in Beijing diff --git a/examples/language/bert/README.md b/examples/language/bert/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6601edb7960eeeb8702c6761a4c8428563313aa6 --- /dev/null +++ b/examples/language/bert/README.md @@ -0,0 +1,44 @@ +## Overview + +This directory includes two parts: Using the Booster API finetune Huggingface Bert and AlBert models and benchmarking Bert and AlBert models with different Booster Plugin. + +## Finetune +``` +bash test_ci.sh +``` + +### Bert-Finetune Results + +| Plugin | Accuracy | F1-score | GPU number | +| -------------- | -------- | -------- | -------- | +| torch_ddp | 84.4% | 88.6% | 2 | +| torch_ddp_fp16 | 84.7% | 88.8% | 2 | +| gemini | 84.0% | 88.4% | 2 | +| hybrid_parallel | 84.5% | 88.6% | 4 | + + +## Benchmark +``` +bash benchmark.sh +``` + +Now include these metrics in benchmark: CUDA mem occupy, throughput and the number of model parameters. If you have custom metrics, you can add them to benchmark_util. + +### Results + +#### Bert + +| | max cuda mem | throughput(sample/s) | params | +| :-----| -----------: | :--------: | :----: | +| ddp | 21.44 GB | 3.0 | 82M | +| ddp_fp16 | 16.26 GB | 11.3 | 82M | +| gemini | 11.0 GB | 12.9 | 82M | +| low_level_zero | 11.29 G | 14.7 | 82M | + +#### AlBert +| | max cuda mem | throughput(sample/s) | params | +| :-----| -----------: | :--------: | :----: | +| ddp | OOM | | | +| ddp_fp16 | OOM | | | +| gemini | 69.39 G | 1.3 | 208M | +| low_level_zero | 56.89 G | 1.4 | 208M | diff --git a/examples/language/bert/benchmark.py b/examples/language/bert/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..10bd367fda5b629ac25ddd64e7c808b04eef040d --- /dev/null +++ b/examples/language/bert/benchmark.py @@ -0,0 +1,172 @@ +import argparse + +import torch +from benchmark_utils import benchmark +from torch.utils.data import DataLoader, Dataset +from transformers import ( + AlbertConfig, + AlbertForSequenceClassification, + BertConfig, + BertForSequenceClassification, + get_linear_schedule_with_warmup, +) + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 3 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 +SEQ_LEN = 512 +VOCAB_SIZE = 1000 +NUM_LABELS = 10 +DATASET_LEN = 1000 + + +class RandintDataset(Dataset): + def __init__(self, dataset_length: int, sequence_length: int, vocab_size: int, n_class: int): + self._sequence_length = sequence_length + self._vocab_size = vocab_size + self._n_class = n_class + self._dataset_length = dataset_length + self._datas = torch.randint( + low=0, + high=self._vocab_size, + size=( + self._dataset_length, + self._sequence_length, + ), + dtype=torch.long, + ) + self._labels = torch.randint(low=0, high=self._n_class, size=(self._dataset_length, 1), dtype=torch.long) + + def __len__(self): + return self._dataset_length + + def __getitem__(self, idx): + return self._datas[idx], self._labels[idx] + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) + parser.add_argument( + "--model_type", + type=str, + default="bert", + help="bert or albert", + ) + + args = parser.parse_args() + + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}, seed=42) + coordinator = DistCoordinator() + + # local_batch_size = BATCH_SIZE // coordinator.world_size + lr = LEARNING_RATE * coordinator.world_size + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + + train_dataset = RandintDataset( + dataset_length=DATASET_LEN, sequence_length=SEQ_LEN, vocab_size=VOCAB_SIZE, n_class=NUM_LABELS + ) + train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE) + + # ==================================== + # Prepare model, optimizer + # ==================================== + # bert pretrained model + + if args.model_type == "bert": + cfg = BertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS) + model = BertForSequenceClassification(cfg) + elif args.model_type == "albert": + cfg = AlbertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS) + model = AlbertForSequenceClassification(cfg) + else: + raise RuntimeError + + # optimizer + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) + + # lr scheduler + total_steps = len(train_dataloader) * NUM_EPOCHS + num_warmup_steps = int(WARMUP_FRACTION * total_steps) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + ) + + # criterion + criterion = lambda inputs: inputs[0] + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler) + + # ============================== + # Benchmark model + # ============================== + + results = benchmark( + model, booster, optimizer, lr_scheduler, train_dataloader, criterion=criterion, epoch_num=NUM_EPOCHS + ) + + coordinator.print_on_master(results) + + +if __name__ == "__main__": + main() diff --git a/examples/language/bert/benchmark.sh b/examples/language/bert/benchmark.sh new file mode 100755 index 0000000000000000000000000000000000000000..9453d1373f2f8e30adb967a9518739d747347c13 --- /dev/null +++ b/examples/language/bert/benchmark.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -xe + +pip install -r requirements.txt + +for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do + torchrun --standalone --nproc_per_node 2 benchmark.py --plugin $plugin --model_type "bert" + torchrun --standalone --nproc_per_node 2 benchmark.py --plugin $plugin --model_type "albert" +done diff --git a/examples/language/bert/benchmark_utils.py b/examples/language/bert/benchmark_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..04d55cb2e7b6d868f0afb1c7211328c29ff3e3f3 --- /dev/null +++ b/examples/language/bert/benchmark_utils.py @@ -0,0 +1,149 @@ +import inspect +from logging import getLogger +from time import time +from typing import Callable + +import torch +import yaml +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator + +logger = getLogger("colossalai-booster-benchmark") +_INVALID = float("nan") + + +def format_num(num: int, bytes=False): + """Scale bytes to its proper format, e.g. 1253656 => '1.20MB'""" + factor = 1024 if bytes else 1000 + suffix = "B" if bytes else "" + for unit in ["", " K", " M", " G", " T", " P"]: + if num < factor: + return f"{num:.2f}{unit}{suffix}" + num /= factor + + +def _is_valid(val): + return val == val + + +def get_call_arg_names(module_or_fn): + if isinstance(module_or_fn, torch.nn.Module): + return inspect.getfullargspec(module_or_fn.forward)[0][1:] + return inspect.getfullargspec(module_or_fn)[0] + + +def measure_params(model): + num_params = _INVALID + + try: + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + except AttributeError as e: + logger.error(f"Unable to measure model params due to error: {e}") + + return num_params + + +def warm_up( + model, + booster, + dataloader, + criterion, + optimizer, + lr_scheduler, + num_runs=10, +): + for i, data in enumerate(dataloader): + if i > num_runs: + break + inputs, labels = data[0].cuda(), data[1].cuda() + outputs = model(inputs, labels=labels) + loss = criterion(outputs) + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + +def fmt(d: dict): + return yaml.dump(d) + + +def benchmark( + model: torch.nn.Module, + booster: Booster, + optimizer: torch.optim.Optimizer, + lr_scheduler: LRScheduler, + dataloader: DataLoader, + criterion: Callable = None, + warm_up_fn=warm_up, + epoch_num: int = 3, + batch_size: int = 32, + warm_up_steps: int = 3, +): + results = {} + model_device = torch.cuda.current_device() + + # Warm up + warm_up_fn( + model, + booster, + dataloader, + criterion, + optimizer, + lr_scheduler, + num_runs=warm_up_steps, + ) + # Measure params + params = measure_params(model) + if _is_valid(params): + results["params"] = format_num(params) + logger.info(f"Model parameters: {params} ({format_num(params)})") + + # Measure Allocated Memory and Throughput + memory = {} + throughput = {} + torch.cuda.reset_peak_memory_stats(device=model_device) + pre_mem = torch.cuda.memory_allocated(device=model_device) + + start_time = time() + + for epoch in range(epoch_num): + with tqdm( + dataloader, desc=f"Epoch [{epoch + 1}/{epoch_num}]", disable=not DistCoordinator().is_master() + ) as pbar: + for data in pbar: + inputs, labels = data[0].cuda(), data[1].cuda() + outputs = model(inputs, labels=labels) + loss = criterion(outputs) + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + end_time = time() + + all_sample = epoch_num * len(dataloader) + + post_mem = torch.cuda.memory_allocated(device=model_device) + max_mem = torch.cuda.max_memory_allocated(device=model_device) + + memory[f"batch_size_{batch_size}"] = { + "cuda_pre_training_bytes": format_num(pre_mem, bytes=True), + "cuda_max_training_bytes": format_num(max_mem, bytes=True), + "cuda_post_training_bytes": format_num(post_mem, bytes=True), + } + logger.info(fmt({f"Memory results (batch_size={batch_size})": memory[f"batch_size_{batch_size}"]})) + + throughput[f"batch_size_{batch_size}"] = { + "throughput:": "{:.1f}".format(all_sample * DistCoordinator().world_size / (end_time - start_time)) + } + logger.info(fmt({f"Throughput results (batch_size={batch_size})": throughput[f"batch_size_{batch_size}"]})) + + results["throughput"] = throughput + results["memory"] = memory + + return results diff --git a/examples/language/bert/data.py b/examples/language/bert/data.py new file mode 100644 index 0000000000000000000000000000000000000000..ef51f938dc4f8be5811916d692c8234fd811cd84 --- /dev/null +++ b/examples/language/bert/data.py @@ -0,0 +1,123 @@ +import datasets +from transformers import AutoTokenizer, PreTrainedTokenizer + +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase + + +class GLUEDataBuilder: + task_text_field_map = { + "cola": ["sentence"], + "sst2": ["sentence"], + "mrpc": ["sentence1", "sentence2"], + "qqp": ["question1", "question2"], + "stsb": ["sentence1", "sentence2"], + "mnli": ["premise", "hypothesis"], + "qnli": ["question", "sentence"], + "rte": ["sentence1", "sentence2"], + "wnli": ["sentence1", "sentence2"], + "ax": ["premise", "hypothesis"], + } + + glue_task_num_labels = { + "cola": 2, + "sst2": 2, + "mrpc": 2, + "qqp": 2, + "stsb": 1, + "mnli": 3, + "qnli": 2, + "rte": 2, + "wnli": 2, + "ax": 3, + } + + loader_columns = [ + "datasets_idx", + "input_ids", + "token_type_ids", + "attention_mask", + "start_positions", + "end_positions", + "labels", + ] + + def __init__( + self, + model_name_or_path: str, + plugin: DPPluginBase, + task_name: str = "mrpc", + max_seq_length: int = 128, + train_batch_size: int = 32, + eval_batch_size: int = 32, + **kwargs, + ): + super().__init__() + self.model_name_or_path = model_name_or_path + self.task_name = task_name + self.max_seq_length = max_seq_length + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size + self.plugin = plugin + + self.text_fields = self.task_text_field_map[task_name] + self.num_labels = self.glue_task_num_labels[task_name] + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + self.setup() + + def setup(self): + self.dataset = datasets.load_dataset("glue", self.task_name) + + for split in self.dataset.keys(): + self.dataset[split] = self.dataset[split].map( + self.convert_to_features, + batched=True, + remove_columns=["label"], + ) + self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] + self.dataset[split].set_format(type="torch", columns=self.columns) + + self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] + + def prepare_data(self): + datasets.load_dataset("glue", self.task_name) + AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + + def train_dataloader(self): + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) + + def val_dataloader(self): + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def test_dataloader(self): + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def convert_to_features(self, example_batch): + # Either encode single sentence or sentence pairs + if len(self.text_fields) > 1: + texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) + else: + texts_or_text_pairs = example_batch[self.text_fields[0]] + + # Tokenize the text/text pairs + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) + + # Rename label to labels to make it easier to pass to model forward + features["labels"] = example_batch["label"] + + return features diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..563cfa58d5f6a67cf33a1f5fcc40090a89004a0a --- /dev/null +++ b/examples/language/bert/finetune.py @@ -0,0 +1,317 @@ +import argparse +from typing import Callable, List, Union + +import evaluate +import torch +import torch.distributed as dist +import torch.nn as nn +from data import GLUEDataBuilder +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ( + AlbertForSequenceClassification, + AutoConfig, + BertForSequenceClassification, + get_linear_schedule_with_warmup, +) + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 3 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 + +output_transform_fn = lambda x: x +criterion = lambda x: x.loss + + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + + +@torch.no_grad() +def evaluate_model( + model: nn.Module, + criterion, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + booster: Booster, + coordinator: DistCoordinator, +): + metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) + model.eval() + + def evaluate_subset(dataloader: DataLoader): + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + + accum_loss = torch.zeros(1, device=get_current_device()) + for batch in dataloader: + batch = move_to_cuda(batch) + labels = batch["labels"] + if use_pipeline: + pg_mesh = booster.plugin.pg_mesh + pp_group = booster.plugin.pp_group + current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group) + current_rank = dist.get_rank() + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True) + + if is_pp_last_stage: + logits = outputs["outputs"]["logits"] + val_loss = outputs["loss"] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + dist.broadcast_object_list([preds, val_loss], src=current_pp_group_ranks[-1], group=pp_group) + + metric.add_batch(predictions=preds, references=labels) + elif current_rank in current_pp_group_ranks: + object_list = [None, None] + dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) + + metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) + accum_loss.add_(object_list[1].to(get_current_device())) + + else: + batch = move_to_cuda(batch) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + metric.add_batch(predictions=preds, references=labels) + + results = metric.compute() + dist.all_reduce(accum_loss.div_(len(dataloader))) + if coordinator.is_master() and results is not None: + results["loss"] = accum_loss.item() / coordinator.world_size + + return results + + if isinstance(test_dataloader, DataLoader): + return evaluate_subset(test_dataloader) + else: + assert len(test_dataloader) == len(eval_splits) + final_results = {} + for split, sub_loader in zip(eval_splits, test_dataloader): + results = evaluate_subset(sub_loader) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) + return final_results + + +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + _criterion: Callable, + lr_scheduler: LRScheduler, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) + total_step = len(train_dataloader) + + model.train() + optimizer.zero_grad() + train_dataloader_iter = iter(train_dataloader) + with tqdm(range(total_step), desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not print_flag) as pbar: + # Forward pass + for _ in pbar: + if use_pipeline: + outputs = booster.execute_pipeline( + train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) + else: + data = next(train_dataloader_iter) + data = move_to_cuda(data) + outputs = model(**data) + loss = _criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"], + help="plugin to use", + ) + parser.add_argument( + "--model_type", + type=str, + default="bert", + help="bert or albert", + ) + parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") + parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") + args = parser.parse_args() + + if args.model_type == "bert": + model_name = "bert-base-uncased" + elif args.model_type == "albert": + model_name = "albert-xxlarge-v2" + else: + raise RuntimeError + + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}, seed=42) + coordinator = DistCoordinator() + + lr = LEARNING_RATE * coordinator.world_size + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == "hybrid_parallel": + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, + ) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + data_builder = GLUEDataBuilder( + model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE + ) + train_dataloader = data_builder.train_dataloader() + test_dataloader = data_builder.test_dataloader() + + # ==================================== + # Prepare model, optimizer + # ==================================== + # bert pretrained model + + cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) + + if model_name == "bert-base-uncased": + model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() + elif model_name == "albert-xxlarge-v2": + model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg) + else: + raise RuntimeError + + # optimizer + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) + + # lr scheduler + total_steps = len(train_dataloader) * NUM_EPOCHS + num_warmup_steps = int(WARMUP_FRACTION * total_steps) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + ) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, _criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler + ) + + # ============================== + # Train model + # ============================== + for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) + + results = evaluate_model( + model, + _criterion, + test_dataloader, + data_builder.num_labels, + args.task, + data_builder.eval_splits, + booster, + coordinator, + ) + + if coordinator.is_master(): + print(results) + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +if __name__ == "__main__": + main() diff --git a/examples/language/bert/requirements.txt b/examples/language/bert/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..377422c260ad24ace3c8701ccfdbc6a11d82af2f --- /dev/null +++ b/examples/language/bert/requirements.txt @@ -0,0 +1,9 @@ +colossalai +evaluate +datasets +torch +tqdm +transformers +scipy +scikit-learn +ptflops diff --git a/examples/language/bert/run_gemini.sh b/examples/language/bert/run_gemini.sh deleted file mode 100644 index d791334e8c97312f343a0cea6ac70d9ba4f7d3fe..0000000000000000000000000000000000000000 --- a/examples/language/bert/run_gemini.sh +++ /dev/null @@ -1,22 +0,0 @@ -set -x -# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"] -export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} - -# The following options only valid when DISTPLAN="colossalai" -export GPUNUM=${GPUNUM:-1} -export PLACEMENT=${PLACEMENT:-"cpu"} -export BATCH_SIZE=${BATCH_SIZE:-16} - -# bert | albert -export MODEL_TYPE=${MODEL_TYPE:-"bert"} -export TRAIN_STEP=${TRAIN_STEP:-10} - -mkdir -p gemini_logs - -env CUDA_LAUNCH_BLOCKING=1 torchrun --standalone --nproc_per_node=${GPUNUM} ./train_bert_demo.py \ ---model_type=${MODEL_TYPE} \ ---batch_size=${BATCH_SIZE} \ ---placement=${PLACEMENT} \ ---distplan=${DISTPLAN} \ ---train_step=${TRAIN_STEP} \ -2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_${PLACEMENT}.log diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh old mode 100644 new mode 100755 index 42c63fec50c0e3729704a4944f22c19116e76494..394ff831b8550e1a3a74fced3f1b792b1481827b --- a/examples/language/bert/test_ci.sh +++ b/examples/language/bert/test_ci.sh @@ -1,2 +1,8 @@ -set -x -env GPUNUM=1 bash run_gemini.sh +#!/bin/bash +set -xe + +pip install -r requirements.txt + +for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do + torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" +done diff --git a/examples/language/bert/train_bert_demo.py b/examples/language/bert/train_bert_demo.py deleted file mode 100644 index 9a0278b2c711d441405f3f9324fad1bdde66846c..0000000000000000000000000000000000000000 --- a/examples/language/bert/train_bert_demo.py +++ /dev/null @@ -1,331 +0,0 @@ -import os -from functools import partial -from time import time - -import psutil -import torch -from packaging import version -from torch import nn -from torch.nn.parallel import DistributedDataParallel as DDP -from transformers import AlbertConfig, AlbertForSequenceClassification, BertConfig, BertForSequenceClassification - -import colossalai -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper - -CAI_VERSION = colossalai.__version__ - - -def get_tflops(model_numel, batch_size, seq_len, step_time): - return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) - - -def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): - from contextlib import nullcontext - - from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler - if enable_flag: - return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), - on_trace_ready=tensorboard_trace_handler(save_dir), - record_shapes=True, - profile_memory=True) - else: - - class DummyProfiler: - - def __init__(self): - self.step_number = 0 - - def step(self): - self.step_number += 1 - - return nullcontext(DummyProfiler()) - - -def get_time_stamp(): - import time - cur_time = time.strftime("%d-%H:%M", time.localtime()) - return cur_time - - -def get_bert_data(batch_size: int, sequence_length: int, vacob_size: int, n_class: int, device: torch.device): - input = torch.randint( - low=0, - high=vacob_size, - size=(batch_size, sequence_length), - device=device, - dtype=torch.long, - ) - label = torch.randint(low=0, high=n_class, size=(batch_size,), device=device, dtype=torch.long) - return input, label - - -def parse_args(): - parser = colossalai.get_default_parser() - parser.add_argument( - "--distplan", - type=str, - default='CAI_Gemini', - help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", - ) - parser.add_argument( - "--placement", - type=str, - default='cpu', - help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", - ) - parser.add_argument( - "--batch_size", - type=int, - default=8, - help="batch size per DP group of training.", - ) - parser.add_argument( - "--model_type", - type=str, - default="bert", - help="bert or albert", - ) - parser.add_argument( - "--train_step", - type=int, - default=10, - help="training iterations for test", - ) - - args = parser.parse_args() - return args - - -SEQ_LEN = 512 -VOCAB_SIZE = 1000 -NUM_LABELS = 10 - - -# Parameter Sharding Strategies for Tensor Parallelism -def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(-1, param, pg) - - -def get_cpu_mem(): - return psutil.Process().memory_info().rss / 1024**2 - - -def get_gpu_mem(): - return torch.cuda.memory_allocated() / 1024**2 - - -def get_mem_info(prefix=''): - return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' - - -def get_model_size(model: nn.Module): - total_numel = 0 - for module in model.modules(): - for p in module.parameters(recurse=False): - total_numel += p.numel() - return total_numel - - -def model_builder(args): - if args.model_type == "bert": - cfg = BertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS) - return BertForSequenceClassification(cfg) - elif args.model_type == "albert": - cfg = AlbertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS) - return AlbertForSequenceClassification(cfg) - else: - raise RuntimeError - - -def model_size_formatter(numel: int) -> str: - GB_SIZE = 10**9 - MB_SIZE = 10**6 - KB_SIZE = 10**3 - if numel >= GB_SIZE: - return f'{numel / GB_SIZE:.1f}B' - elif numel >= MB_SIZE: - return f'{numel / MB_SIZE:.1f}M' - elif numel >= KB_SIZE: - return f'{numel / KB_SIZE:.1f}K' - else: - return str(numel) - - -def set_cpu_maximum_parallelism(): - conf_str = torch.__config__.parallel_info() - inter_str = conf_str.split("hardware_concurrency() : ")[1] - max_concurrency = inter_str.split('\n')[0] - os.environ["OMP_NUM_THREADS"] = max_concurrency - print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.") - - -def main(): - # version check - # this example is supposed to work for versions greater than 0.2.0 - assert version.parse(CAI_VERSION) >= version.parse("0.2.0") - - set_cpu_maximum_parallelism() - args = parse_args() - - # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]: - if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]: - raise TypeError(f"{args.distplan} is error") - - # batch size per DP degree - BATCH_SIZE = args.batch_size - - NUM_STEPS = args.train_step - - WARMUP_STEPS = 1 - assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps" - assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median" - PROF_FLAG = False # The flag of profiling, False by default - - disable_existing_loggers() - colossalai.launch_from_torch(config={}) - - logger = get_dist_logger() - logger.info(f" {args.distplan}, batch size {BATCH_SIZE}", ranks=[0]) - - torch.manual_seed(123) - if args.distplan.startswith("CAI"): - # all param must use the same process group. - world_size = torch.distributed.get_world_size() - - # build a base-bert model - with ColoInitContext(device=get_current_device(), dtype=torch.half): - model = model_builder(args) - # model = BertForSequenceClassification(BertConfig(vocal_size = VOCAB_SIZE)) - - # asign running configurations - gemini_config = None - if args.distplan.startswith("CAI_ZeRO"): - optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) - elif args.distplan == "CAI_Gemini": - gemini_config = dict(strict_ddp_mode=True, - device=get_current_device(), - placement_policy=args.placement, - pin_memory=True, - hidden_dim=model.config.hidden_size, - search_range_mb=128) - optim_config = dict(gpu_margin_mem_ratio=0.) - else: - raise RuntimeError - - # build a highly optimized gpu/cpu optimizer - optimizer = HybridAdam(model.parameters(), lr=1e-3) - - if args.distplan == "CAI_ZeRO1": - zero_stage = 1 - elif args.distplan == "CAI_ZeRO2": - zero_stage = 2 - elif args.distplan == "CAI_Gemini": - zero_stage = 3 - else: - raise RuntimeError - - # wrap your model and optimizer - model = zero_model_wrapper(model, zero_stage, gemini_config) - optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config) - - logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) - elif args.distplan.startswith("Pytorch"): - model = model_builder(args).cuda() - model = DDP(model) - if args.distplan.endswith("DDP"): - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - elif args.distplan.endswith("ZeRO"): - from torch.distributed.optim import ZeroRedundancyOptimizer - optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3) - else: - raise RuntimeError - - # model is shared after TP - numel = get_model_size(model) - logger.info(f"the size of testing model size is {model_size_formatter(numel)}.") - logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) - - # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu - # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree) - # = batch_per_DP_group * numel * seq_len * 8 - get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) - - torch.cuda.synchronize() - model.train() - tflops_list = [] - - def train_step(): - # we just use randomly generated data here - input_ids, labels = get_bert_data(BATCH_SIZE, - SEQ_LEN, - VOCAB_SIZE, - NUM_LABELS, - device=torch.cuda.current_device()) - optimizer.zero_grad() - - start = time() - outputs = model(input_ids, labels=labels) - loss, logits = outputs[:2] - torch.cuda.synchronize() - fwd_end = time() - fwd_time = fwd_end - start - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0]) - - if args.distplan.startswith("CAI"): - optimizer.backward(loss) - elif args.distplan.startswith("Pytorch"): - loss.backward() - else: - raise RuntimeError - - torch.cuda.synchronize() - bwd_end = time() - bwd_time = bwd_end - fwd_end - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0]) - - optimizer.step() - torch.cuda.synchronize() - optim_time = time() - bwd_end - step_time = time() - start - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0]) - - step_tflops = get_tflops_func(step_time) - logger.info( - f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s", - ranks=[0], - ) - if n >= WARMUP_STEPS: - tflops_list.append(step_tflops) - - demo_profiler = get_profile_context(PROF_FLAG, - WARMUP_STEPS, - NUM_STEPS - WARMUP_STEPS, - save_dir=f"profile/{get_time_stamp()}-demo") - - with demo_profiler as prof: - for n in range(NUM_STEPS): - train_step() - prof.step() - - tflops_list.sort() - median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS - logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") - torch.cuda.synchronize() - - -if __name__ == '__main__': - main() diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md index 47d24a4d69cb6c2edc4041e5f7501cb2e1e3362c..03679e66404a59983fc47ac541569d2b1247d3a7 100644 --- a/examples/language/gpt/README.md +++ b/examples/language/gpt/README.md @@ -65,6 +65,16 @@ Titans provides a customized GPT model, which uses distributed operators as buil In [./titans/README.md], we provide a hybrid parallelism of ZeRO, TP and PP. You can switch parallel strategies using a config file. +### Hybridparallelism + +Hybridparallelism provides a user friendly plugin to set multiple parallelism method for training and inference. In [./hybridparallelism], we provide a n example to finetune gpt2 using Hybridparallelism. + +Quick run +```bash +cd ./hybridparallelism +bash run.sh +``` + ## Performance Testbed: a cluster of 8xA100 (80GB) and 1xAMD EPYC 7543 32-Core Processor (512 GB). GPUs are connected via PCI-e. diff --git a/examples/language/gpt/experiments/auto_offload/model_zoo.py b/examples/language/gpt/experiments/auto_offload/model_zoo.py index 35e44608f8108e88ea3d33b8a1f35417cffea9ce..75968a0b1da9a52b14c321d35f8929f18534374c 100644 --- a/examples/language/gpt/experiments/auto_offload/model_zoo.py +++ b/examples/language/gpt/experiments/auto_offload/model_zoo.py @@ -2,22 +2,20 @@ import torch import torch.nn as nn from transformers import GPT2Config, GPT2LMHeadModel -class GPTLMModel(nn.Module): - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257): +class GPTLMModel(nn.Module): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257): super().__init__() self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) def forward(self, input_ids, attention_mask): # Only return lm_logits @@ -25,7 +23,6 @@ class GPTLMModel(nn.Module): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -36,6 +33,7 @@ class GPTLMLoss(nn.Module): # Flatten the tokens return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + def get_gpt2_components(model_type: str, batch_size: int): vocab_size = 1024 seq_len = 8 @@ -62,4 +60,4 @@ def get_gpt2_components(model_type: str, batch_size: int): kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) return kwargs - return gpt2_model_builder, gpt2_data_gen \ No newline at end of file + return gpt2_model_builder, gpt2_data_gen diff --git a/examples/language/gpt/experiments/auto_offload/requirements.txt b/examples/language/gpt/experiments/auto_offload/requirements.txt index 3ebde8d460aad354648666ab18c8413a213047b3..137a69e80498223cd7581a62e2e27320b77682a0 100644 --- a/examples/language/gpt/experiments/auto_offload/requirements.txt +++ b/examples/language/gpt/experiments/auto_offload/requirements.txt @@ -1,2 +1,2 @@ colossalai >= 0.1.12 -torch >= 1.8.1 \ No newline at end of file +torch >= 1.8.1 diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index 89415c23f93c6dd96f5256235a1050aa7a8e5f13..e811e1acbf7e42a9504d560b28783591bc054153 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -18,14 +18,14 @@ from colossalai.utils import get_current_device def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--model_type', type=str, default="gpt2_medium") - parser.add_argument('--batch_size', type=int, default=64) - parser.add_argument('--solver_type', type=str, default='asyn') - parser.add_argument('--memory_budget', type=float, default=16) + parser.add_argument("--model_type", type=str, default="gpt2_medium") + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--solver_type", type=str, default="asyn") + parser.add_argument("--memory_budget", type=float, default=16) return parser.parse_args() -@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed") def train_gpt(args): memory_budget = args.memory_budget * 1024 * 1024 * 1024 solver_type = args.solver_type @@ -34,10 +34,15 @@ def train_gpt(args): # build model model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size) - label = torch.randint(low=0, high=128, size=( - 64, - 8, - ), device=get_current_device()) + label = torch.randint( + low=0, + high=128, + size=( + 64, + 8, + ), + device=get_current_device(), + ) criterion = GPTLMLoss() start_time = time.time() @@ -80,18 +85,20 @@ def train_gpt(args): exec_time = sum(sorted(time_list)[:5]) / 5 runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 - print(f'solver_type: {solver_type} | model_type: {model_type}') - print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(f"solver_type: {solver_type} | model_type: {model_type}") + print( + f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB " + f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|" + ) print(time_list) def run(rank, world_size, port, args): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") train_gpt(args) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() spawn(run, 1, args=args) diff --git a/examples/language/gpt/experiments/auto_parallel/README.md b/examples/language/gpt/experiments/auto_parallel/README.md index 1c8b1c35109fca737a6269c54d03194e24422524..32688873f8f15acb37dbecc0dead552e2e357f21 100644 --- a/examples/language/gpt/experiments/auto_parallel/README.md +++ b/examples/language/gpt/experiments/auto_parallel/README.md @@ -13,10 +13,10 @@ conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 ``` -### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website +### Install Colossal-AI ```bash -pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org +pip install colossalai==0.2.0 ``` ### Install transformers diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py index e331fc8fcf10639366fc809e0558958c419a487a..f3d35dd9042b06df3ffe8b493009b42bc17c3ecb 100644 --- a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py +++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py @@ -7,8 +7,8 @@ import transformers from gpt_modules import GPT2LMHeadModel, GPTLMLoss from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize -from colossalai.core import global_context as gpc from colossalai.initialize import launch_from_torch +from colossalai.legacy.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger BATCH_SIZE = 16 @@ -29,8 +29,8 @@ def get_gpu_mem(): return torch.cuda.memory_allocated() / 1024**2 -def get_mem_info(prefix=''): - return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' +def get_mem_info(prefix=""): + return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" def get_tflops(model_numel, batch_size, seq_len, step_time): @@ -51,14 +51,14 @@ def main(): logger = get_dist_logger() config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM) if FP16: - model = GPT2LMHeadModel(config=config).half().to('cuda') + model = GPT2LMHeadModel(config=config).half().to("cuda") else: - model = GPT2LMHeadModel(config=config).to('cuda') + model = GPT2LMHeadModel(config=config).to("cuda") global_numel = sum([p.numel() for p in model.parameters()]) meta_input_sample = { - 'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), - 'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), + "input_ids": torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to("meta"), + "attention_mask": torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to("meta"), } gm, solution = autoparallelize(model, meta_input_sample, return_solution=True) @@ -72,7 +72,7 @@ def main(): criterion = GPTLMLoss() optimizer = torch.optim.Adam(gm.parameters(), lr=0.01) - logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + logger.info(get_mem_info(prefix="After init model, "), ranks=[0]) get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH) torch.cuda.synchronize() model.train() @@ -89,10 +89,11 @@ def main(): torch.cuda.synchronize() step_time = time() - start logger.info( - f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', - ranks=[0]) + f"[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}", + ranks=[0], + ) torch.cuda.synchronize() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/gpt/experiments/auto_parallel/gpt_modules.py b/examples/language/gpt/experiments/auto_parallel/gpt_modules.py index 95feaec38c26794760bf01c34d45fe06f61b401a..ad9a197772848233f82a873d43ba20072dfe031b 100644 --- a/examples/language/gpt/experiments/auto_parallel/gpt_modules.py +++ b/examples/language/gpt/experiments/auto_parallel/gpt_modules.py @@ -8,7 +8,6 @@ from transformers.pytorch_utils import Conv1D class GPT2MLP(nn.Module): - def __init__(self, intermediate_size, config): super().__init__() embed_dim = config.hidden_size @@ -30,15 +29,15 @@ class GPT2MLP(nn.Module): # 2. The order of split and view op has been changed in the customized GPT2Attention class, the new # order is same as megatron-lm gpt model. class GPT2Attention(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), - dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), ) self.register_buffer("masked_bias", torch.tensor(-1e4)) @@ -64,7 +63,7 @@ class GPT2Attention(nn.Module): attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / (value.size(-1)**0.5) + attn_weights = attn_weights / (value.size(-1) ** 0.5) # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: @@ -72,7 +71,7 @@ class GPT2Attention(nn.Module): # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) if attention_mask is not None: @@ -93,7 +92,7 @@ class GPT2Attention(nn.Module): def _split_heads(self, tensor, num_heads, attn_head_size): new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): tensor = tensor.permute(0, 2, 1, 3).contiguous() @@ -106,10 +105,9 @@ class GPT2Attention(nn.Module): attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - qkv = self.c_attn(hidden_states) query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) - present = (key, value) + (key, value) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) @@ -117,7 +115,6 @@ class GPT2Attention(nn.Module): class GPT2Block(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -152,7 +149,6 @@ class GPT2Block(nn.Module): class GPT2Model(GPT2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -189,11 +185,9 @@ class GPT2Model(GPT2PreTrainedModel): # GPT2Attention mask. attention_mask = attention_mask.view(batch_size, -1) attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 - encoder_attention_mask = None - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -217,7 +211,6 @@ class GPT2Model(GPT2PreTrainedModel): class GPT2LMHeadModel(GPT2PreTrainedModel): - def __init__(self, config): super().__init__(config) self.transformer = GPT2Model(config) @@ -241,7 +234,6 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() diff --git a/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py b/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py index c31b3fa6d1035a0238c1d56fc678351c0fa2068d..47cc87980556f1aafccf3c345bb19f3724cbb3da 100644 --- a/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py +++ b/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py @@ -4,22 +4,25 @@ from transformers import GPT2Config, GPT2LMHeadModel ## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint - self.config = GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size) + self.config = GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) self.model = GPT2LMHeadModel(self.config) if checkpoint: self.model.gradient_checkpointing_enable() @@ -70,4 +73,4 @@ def model_builder(model_size: str) -> callable: raise TypeError(f"model_builder {model_size}") -__all__ = ['model_builder'] +__all__ = ["model_builder"] diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py index ad69888b8cc80ba7cd1ea52bd853ba61bc6c297b..09bbae9c5b748f203996f71505dd896b19b9e2dc 100644 --- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py +++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py @@ -5,39 +5,32 @@ from functools import partial import torch from model_zoo import model_builder from torch import nn -from tqdm import tqdm from colossalai.fx import ColoTracer -from colossalai.fx.passes.adding_split_node_pass import ( - avgnode_split_pass, - gpipe_dp_split_pass, - split_with_split_nodes_pass, -) +from colossalai.fx.passes.adding_split_node_pass import gpipe_dp_split_pass, split_with_split_nodes_pass from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology +from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine +from colossalai.legacy.pipeline.rpc.utils import rpc_run from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer import HybridAdam -from colossalai.pipeline.middleware.adaptor import get_fx_topology -from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine -from colossalai.pipeline.rpc.utils import rpc_run def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--model_type', type=str, default="gpt2_medium") - parser.add_argument('--world_size', type=int, default=2) - parser.add_argument('--batch_size', type=int, default=16) - parser.add_argument('--dp_degree', type=int, default=1) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--num_microbatches', type=int, default=2) - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29011') - parser.add_argument('--num_worker_threads', type=int, default=128) + parser.add_argument("--model_type", type=str, default="gpt2_medium") + parser.add_argument("--world_size", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--num_microbatches", type=int, default=2) + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29011") + parser.add_argument("--num_worker_threads", type=int, default=128) return parser.parse_args() class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -63,16 +56,16 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): # Create annotated model which is noted where to be splitted. def get_annotated_model(model, data_kwargs, num_stages, num_microbatches): tracer = ColoTracer() - meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + meta_args = {k: v.to("meta") for k, v in data_kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) - interp_meta_args = tuple([v.to('meta') for k, v in data_kwargs.items()]) + interp_meta_args = tuple([v.to("meta") for k, v in data_kwargs.items()]) interp = MetaInfoProp(gm) interp.run(*interp_meta_args) - #annotated_model = avgnode_split_pass(gm, num_stages) - annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01) + # annotated_model = avgnode_split_pass(gm, num_stages) + annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode="block", block_limit=0.01) return annotated_model @@ -83,7 +76,7 @@ def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, n topo = get_fx_topology(top_module) for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): - setattr(submodule, '_topo', topo) + setattr(submodule, "_topo", topo) return split_submodules[pp_rank + 1] @@ -107,8 +100,10 @@ def run_master(args): disable_existing_loggers() logger = get_dist_logger() - logger.info(f"{args.model_type}, batch size {batch_size}, num stage {stage_num}, num microbatch {num_microbatches}", - ranks=[0]) + logger.info( + f"{args.model_type}, batch size {batch_size}, num stage {stage_num}, num microbatch {num_microbatches}", + ranks=[0], + ) torch.manual_seed(123) @@ -117,26 +112,28 @@ def run_master(args): # warm up pipeline fx partition input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE) - warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask} + warmup_data_kwargs = {"input_ids": input_ids, "attention_mask": attn_mask} # create model - logger.info(f'start model_builder') + logger.info(f"start model_builder") model = model_builder(model_type)(checkpoint=False) - logger.info(f'end model_builder') + logger.info(f"end model_builder") # set 1f1b pipeline engine - pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches), - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=1, - criterion=criterion, - metric=None, - checkpoint=False) + pp_engine = FillDrainPipelineEngine( + partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches), + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=1, + criterion=criterion, + metric=None, + checkpoint=False, + ) partition_numels = pp_engine.remote_numels() for rank, numel in partition_numels.items(): - logger.info(f'{rank=} numel in the partition:{numel}') + logger.info(f"{rank=} numel in the partition:{numel}") # build optim pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) @@ -145,7 +142,7 @@ def run_master(args): for n in range(NUM_STEPS): # we just use randomly generated data here input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE) - batch = {'input_ids': input_ids, 'attention_mask': attn_mask} + batch = {"input_ids": input_ids, "attention_mask": attn_mask} start = time.time() outputs = pp_engine.forward_backward(batch=batch, labels=input_ids, forward_only=False) @@ -175,6 +172,6 @@ def run_master(args): logger.info(f"Avg TFLOPS per GPU is {sum(gpu_tflops) / world_size:.3f}") -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() rpc_run(args, run_master) diff --git a/examples/language/gpt/gemini/commons/model_zoo.py b/examples/language/gpt/gemini/commons/model_zoo.py index 65124d9e488403e3473c23b4cd7a31f5e2173871..0f4517549db2d7df9f51013cc16a7a828e34c55b 100644 --- a/examples/language/gpt/gemini/commons/model_zoo.py +++ b/examples/language/gpt/gemini/commons/model_zoo.py @@ -4,22 +4,25 @@ from transformers import GPT2Config, GPT2LMHeadModel ## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint - self.config = GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size) + self.config = GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) self.model = GPT2LMHeadModel(self.config) if checkpoint: self.model.gradient_checkpointing_enable() @@ -82,4 +85,4 @@ def model_builder(model_size: str) -> callable: raise TypeError(f"model_builder {model_size}") -__all__ = ['model_builder'] +__all__ = ["model_builder"] diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py index 7bd098c1927c71f0bc77c26ec7e5ac37688b7c4f..7ed5fdb92b3559347348cb98fe9f534c13a68027 100644 --- a/examples/language/gpt/gemini/commons/utils.py +++ b/examples/language/gpt/gemini/commons/utils.py @@ -6,7 +6,6 @@ from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trac class DummyProfiler: - def __init__(self): self.step_number = 0 @@ -27,11 +26,13 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): if enable_flag: - return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), - on_trace_ready=tensorboard_trace_handler(save_dir), - record_shapes=True, - profile_memory=True) + return profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), + on_trace_ready=tensorboard_trace_handler(save_dir), + record_shapes=True, + profile_memory=True, + ) else: return nullcontext(DummyProfiler()) diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh index ad4e9419c1bdf01bb5e75350b9cfb715908c68ec..5eaa4af4df78bc52a5711b54e0ddfb3d645c7077 100644 --- a/examples/language/gpt/gemini/run_gemini.sh +++ b/examples/language/gpt/gemini/run_gemini.sh @@ -4,28 +4,17 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} # The following options only valid when DISTPLAN="colossalai" export GPUNUM=${GPUNUM:-1} -export TPDEGREE=${TPDEGREE:-1} -export PLACEMENT=${PLACEMENT:-"cpu"} -export USE_SHARD_INIT=${USE_SHARD_INIT:-False} export BATCH_SIZE=${BATCH_SIZE:-16} export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} export TRAIN_STEP=${TRAIN_STEP:-10} # export PYTHONPATH=$PWD:$PYTHONPATH -if [ ${USE_SHARD_INIT} = "True" ]; then - USE_SHARD_INIT="--shardinit" -else - USE_SHARD_INIT="" -fi mkdir -p gemini_logs torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \ ---tp_degree=${TPDEGREE} \ --model_type=${MODEL_TYPE} \ --batch_size=${BATCH_SIZE} \ ---placement=${PLACEMENT} \ -${USE_SHARD_INIT} \ --distplan=${DISTPLAN} \ --train_step=${TRAIN_STEP} \ -2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log +2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}.log diff --git a/examples/language/gpt/gemini/test_ci.sh b/examples/language/gpt/gemini/test_ci.sh index 6079d5ed615bd2ea4d6d18cc168a8371219103c1..6fb08b975d7a9d79566253d118dcdf69d4343989 100644 --- a/examples/language/gpt/gemini/test_ci.sh +++ b/examples/language/gpt/gemini/test_ci.sh @@ -3,32 +3,20 @@ $(cd `dirname $0`;pwd) export TRAIN_STEP=4 for MODEL_TYPE in "gpt2_medium"; do - for DISTPLAN in "colossalai"; do + for DISTPLAN in "CAI_Gemini"; do for BATCH_SIZE in 2; do for GPUNUM in 1 4; do - for TPDEGREE in 1 2; do - if [ ${TPDEGREE} -gt ${GPUNUM} ]; then - continue - fi - for PLACEMENT in "cpu" "auto"; do - MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \ - bash ./run_gemini.sh - done - done + MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \ + bash ./run_gemini.sh done done done - for DISTPLAN in "zero1" "zero2"; do + for DISTPLAN in "CAI_ZeRO2" "CAI_ZeRO1"; do for BATCH_SIZE in 2; do for GPUNUM in 1 4; do - for TPDEGREE in 1; do - if [ ${TPDEGREE} -gt ${GPUNUM} ]; then - continue - fi - MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\ - bash ./run_gemini.sh - done + MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \ + bash ./run_gemini.sh done done done diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index b2a7fa36d02140e6919c2f2c3fe8c1f5596d2943..88b76c654b1d2e5192fb96fbb4d5815665daee13 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -1,4 +1,6 @@ +import argparse import os +from contextlib import nullcontext from functools import partial from time import time @@ -8,44 +10,26 @@ import torch.nn as nn from commons.model_zoo import model_builder from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp from packaging import version -from torch.nn.parallel import DistributedDataParallel as DDP import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper CAI_VERSION = colossalai.__version__ def parse_args(): - parser = colossalai.get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument( "--distplan", type=str, - default='CAI_Gemini', + default="CAI_Gemini", help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", ) - parser.add_argument( - "--tp_degree", - type=int, - default=1, - help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", - ) - parser.add_argument( - "--placement", - type=str, - default='cpu', - help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", - ) - parser.add_argument( - "--shardinit", - action='store_true', - help= - "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", - ) parser.add_argument( "--batch_size", type=int, @@ -69,22 +53,7 @@ def parse_args(): return args -# Parameter Sharding Strategies for Tensor Parallelism -def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(-1, param, pg) - - class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -104,8 +73,8 @@ def get_gpu_mem(): return torch.cuda.memory_allocated() / 1024**2 -def get_mem_info(prefix=''): - return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' +def get_mem_info(prefix=""): + return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" def get_model_size(model: nn.Module): @@ -121,11 +90,11 @@ def model_size_formatter(numel: int) -> str: MB_SIZE = 10**6 KB_SIZE = 10**3 if numel >= GB_SIZE: - return f'{numel / GB_SIZE:.1f}B' + return f"{numel / GB_SIZE:.1f}B" elif numel >= MB_SIZE: - return f'{numel / MB_SIZE:.1f}M' + return f"{numel / MB_SIZE:.1f}M" elif numel >= KB_SIZE: - return f'{numel / KB_SIZE:.1f}K' + return f"{numel / KB_SIZE:.1f}K" else: return str(numel) @@ -133,52 +102,11 @@ def model_size_formatter(numel: int) -> str: def set_cpu_maximum_parallelism(): conf_str = torch.__config__.parallel_info() inter_str = conf_str.split("hardware_concurrency() : ")[1] - max_concurrency = inter_str.split('\n')[0] + max_concurrency = inter_str.split("\n")[0] os.environ["OMP_NUM_THREADS"] = max_concurrency print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.") -# Tensor Parallel -def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): - """tensor_parallelize - Sharding the Model Parameters. - - Args: - model (torch.nn.Module): a torch module to be sharded - """ - for mn, module in model.named_modules(): - for pn, param in module.named_parameters(recurse=False): - # NOTE() a param maybe shared by two modules - if hasattr(param, 'visited'): - continue - - # if shard init, then convert param to replica and use the dp-only ProcessGroup - param: ColoParameter = param - param.set_dist_spec(ReplicaSpec()) - param.set_process_group(pg) - - # shard it w.r.t tp pattern - if 'mlp.c_fc' in mn: - if 'weight' in pn or 'bias' in pn: - split_param_col_tp1d(param, pg) # colmn slice - # keep the shape of the output from c_fc - param.compute_spec.set_output_replicate(False) - else: - param.set_dist_spec(ReplicaSpec()) - elif 'mlp.c_proj' in mn: - if 'weight' in pn: - split_param_row_tp1d(param, pg) # row slice - else: - param.set_dist_spec(ReplicaSpec()) - elif 'wte' in mn or 'wpe' in mn: - split_param_col_tp1d(param, pg) # colmn slice - elif 'c_attn' in mn or 'c_proj' in mn: - split_param_col_tp1d(param, pg) # colmn slice - else: - param.set_dist_spec(ReplicaSpec()) - param.visited = True - - def main(): # version check # this example is supposed to work for versions greater than 0.2.0 @@ -201,7 +129,7 @@ def main(): WARMUP_STEPS = 1 assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps" assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median" - PROF_FLAG = False # The flag of profiling, False by default + PROF_FLAG = False # The flag of profiling, False by default disable_existing_loggers() colossalai.launch_from_torch(config={}) @@ -211,48 +139,14 @@ def main(): # build criterion criterion = GPTLMLoss() - torch.manual_seed(123) if args.distplan.startswith("CAI"): - # all param must use the same process group. - world_size = torch.distributed.get_world_size() - shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None - default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None - - if args.shardinit and args.distplan != "CAI_Gemini": - raise RuntimeError("You can only use shardinit with CAI_Gemini") - + ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext() # build GPT model - with ColoInitContext(device=get_current_device(), - dtype=torch.half, - default_dist_spec=default_dist_spec, - default_pg=shard_pg): + with ctx: model = model_builder(args.model_type)(checkpoint=True) - tp_pg = ProcessGroup(tp_degree=args.tp_degree) - # Tensor Parallelism (TP) - # You should notice that v0.1.10 is not compatible with TP degree > 1 - if args.tp_degree > 1: - tensor_parallelize(model, tp_pg) - - # asign running configurations - gemini_config = None - if args.distplan.startswith("CAI_ZeRO"): - optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) - elif args.distplan == "CAI_Gemini": - gemini_config = dict(strict_ddp_mode=args.tp_degree == 1, - device=get_current_device(), - placement_policy=args.placement, - pin_memory=True, - hidden_dim=model.config.n_embd, - search_range_mb=128) - optim_config = dict(gpu_margin_mem_ratio=0.) - else: - raise RuntimeError - - # build a highly optimized gpu/cpu optimizer - optimizer = HybridAdam(model.parameters(), lr=1e-3) - + # assign running configurations if args.distplan == "CAI_ZeRO1": zero_stage = 1 elif args.distplan == "CAI_ZeRO2": @@ -262,27 +156,41 @@ def main(): else: raise RuntimeError - # wrap your model and optimizer - model = zero_model_wrapper(model, zero_stage, gemini_config) - optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config) + plugin = None + if args.distplan.startswith("CAI_ZeRO"): + plugin = LowLevelZeroPlugin( + stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True + ) + elif args.distplan == "CAI_Gemini": + plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd) + else: + raise RuntimeError - logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) + # build a highly optimized gpu/cpu optimizer + optimizer = HybridAdam(model.parameters(), lr=1e-3) + + logger.info(get_mem_info(prefix="After init optim, "), ranks=[0]) elif args.distplan.startswith("Pytorch"): assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples." model = model_builder(args.model_type)(checkpoint=True).cuda() - model = DDP(model) + plugin = TorchDDPPlugin() if args.distplan.endswith("DDP"): optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) elif args.distplan.endswith("ZeRO"): from torch.distributed.optim import ZeroRedundancyOptimizer + optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3) + else: raise RuntimeError + # wrap your model and optimizer + booster = Booster(plugin=plugin) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) # model is shared after TP numel = get_model_size(model) logger.info(f"the size of testing model size is {model_size_formatter(numel)}.") - logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + logger.info(get_mem_info(prefix="After init model, "), ranks=[0]) # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree) @@ -304,25 +212,19 @@ def main(): torch.cuda.synchronize() fwd_end = time() fwd_time = fwd_end - start - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0]) - - if args.distplan.startswith("CAI"): - optimizer.backward(loss) - elif args.distplan.startswith("Pytorch"): - loss.backward() - else: - raise RuntimeError + logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Forward "), ranks=[0]) + booster.backward(loss, optimizer) torch.cuda.synchronize() bwd_end = time() bwd_time = bwd_end - fwd_end - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0]) + logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Backward "), ranks=[0]) optimizer.step() torch.cuda.synchronize() optim_time = time() - bwd_end step_time = time() - start - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0]) + logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Optimizer step "), ranks=[0]) step_tflops = get_tflops_func(step_time) logger.info( @@ -332,10 +234,9 @@ def main(): if n >= WARMUP_STEPS: tflops_list.append(step_tflops) - demo_profiler = get_profile_context(PROF_FLAG, - WARMUP_STEPS, - NUM_STEPS - WARMUP_STEPS, - save_dir=f"profile/{get_time_stamp()}-demo") + demo_profiler = get_profile_context( + PROF_FLAG, WARMUP_STEPS, NUM_STEPS - WARMUP_STEPS, save_dir=f"profile/{get_time_stamp()}-demo" + ) with demo_profiler as prof: for n in range(NUM_STEPS): @@ -348,5 +249,5 @@ def main(): torch.cuda.synchronize() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/gpt/hybridparallelism/data.py b/examples/language/gpt/hybridparallelism/data.py new file mode 100644 index 0000000000000000000000000000000000000000..ef51f938dc4f8be5811916d692c8234fd811cd84 --- /dev/null +++ b/examples/language/gpt/hybridparallelism/data.py @@ -0,0 +1,123 @@ +import datasets +from transformers import AutoTokenizer, PreTrainedTokenizer + +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase + + +class GLUEDataBuilder: + task_text_field_map = { + "cola": ["sentence"], + "sst2": ["sentence"], + "mrpc": ["sentence1", "sentence2"], + "qqp": ["question1", "question2"], + "stsb": ["sentence1", "sentence2"], + "mnli": ["premise", "hypothesis"], + "qnli": ["question", "sentence"], + "rte": ["sentence1", "sentence2"], + "wnli": ["sentence1", "sentence2"], + "ax": ["premise", "hypothesis"], + } + + glue_task_num_labels = { + "cola": 2, + "sst2": 2, + "mrpc": 2, + "qqp": 2, + "stsb": 1, + "mnli": 3, + "qnli": 2, + "rte": 2, + "wnli": 2, + "ax": 3, + } + + loader_columns = [ + "datasets_idx", + "input_ids", + "token_type_ids", + "attention_mask", + "start_positions", + "end_positions", + "labels", + ] + + def __init__( + self, + model_name_or_path: str, + plugin: DPPluginBase, + task_name: str = "mrpc", + max_seq_length: int = 128, + train_batch_size: int = 32, + eval_batch_size: int = 32, + **kwargs, + ): + super().__init__() + self.model_name_or_path = model_name_or_path + self.task_name = task_name + self.max_seq_length = max_seq_length + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size + self.plugin = plugin + + self.text_fields = self.task_text_field_map[task_name] + self.num_labels = self.glue_task_num_labels[task_name] + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + self.setup() + + def setup(self): + self.dataset = datasets.load_dataset("glue", self.task_name) + + for split in self.dataset.keys(): + self.dataset[split] = self.dataset[split].map( + self.convert_to_features, + batched=True, + remove_columns=["label"], + ) + self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] + self.dataset[split].set_format(type="torch", columns=self.columns) + + self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] + + def prepare_data(self): + datasets.load_dataset("glue", self.task_name) + AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + + def train_dataloader(self): + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) + + def val_dataloader(self): + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def test_dataloader(self): + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def convert_to_features(self, example_batch): + # Either encode single sentence or sentence pairs + if len(self.text_fields) > 1: + texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) + else: + texts_or_text_pairs = example_batch[self.text_fields[0]] + + # Tokenize the text/text pairs + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) + + # Rename label to labels to make it easier to pass to model forward + features["labels"] = example_batch["label"] + + return features diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..62804eff8ea57923094edcde7efcd5a0cd38a275 --- /dev/null +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -0,0 +1,311 @@ +import argparse +from typing import Callable, List, Union + +import evaluate +import torch +import torch.distributed as dist +import torch.nn as nn +from data import GLUEDataBuilder +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 3 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 + +output_transform_fn = lambda x: x +criterion = lambda x: x.loss + + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + + +@torch.no_grad() +def evaluate_model( + model: nn.Module, + criterion, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + booster: Booster, + coordinator: DistCoordinator, +): + metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) + model.eval() + + def evaluate_subset(dataloader: DataLoader): + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + + accum_loss = torch.zeros(1, device=get_current_device()) + for batch in dataloader: + batch = move_to_cuda(batch) + labels = batch["labels"] + if use_pipeline: + pg_mesh = booster.plugin.pg_mesh + pp_group = booster.plugin.pp_group + current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group) + current_rank = dist.get_rank() + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True) + + if is_pp_last_stage: + logits = outputs["outputs"]["logits"] + val_loss = outputs["loss"] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + dist.broadcast_object_list([preds, val_loss], src=current_pp_group_ranks[-1], group=pp_group) + + metric.add_batch(predictions=preds, references=labels) + elif current_rank in current_pp_group_ranks: + object_list = [None, None] + dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) + + metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) + accum_loss.add_(object_list[1].to(get_current_device())) + + else: + batch = move_to_cuda(batch) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + metric.add_batch(predictions=preds, references=labels) + + results = metric.compute() + dist.all_reduce(accum_loss.div_(len(dataloader))) + if coordinator.is_master() and results is not None: + results["loss"] = accum_loss.item() / coordinator.world_size + + return results + + if isinstance(test_dataloader, DataLoader): + return evaluate_subset(test_dataloader) + else: + assert len(test_dataloader) == len(eval_splits) + final_results = {} + for split, sub_loader in zip(eval_splits, test_dataloader): + results = evaluate_subset(sub_loader) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) + return final_results + + +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + _criterion: Callable, + lr_scheduler: LRScheduler, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + total_step = len(train_dataloader) + + model.train() + optimizer.zero_grad() + train_dataloader_iter = iter(train_dataloader) + with tqdm( + range(total_step), + desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", + disable=not (coordinator.is_master() or is_pp_last_stage), + ) as pbar: + # Forward pass + for _ in pbar: + if use_pipeline: + outputs = booster.execute_pipeline( + train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) + else: + data = next(train_dataloader_iter) + data = move_to_cuda(data) + outputs = model(**data) + loss = _criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"], + help="plugin to use", + ) + parser.add_argument( + "--model_type", + type=str, + default="gpt2", + help="only gpt2 now", + ) + parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") + parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") + args = parser.parse_args() + + if args.model_type == "gpt2": + model_name = "gpt2" + else: + raise RuntimeError + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}, seed=42) + coordinator = DistCoordinator() + + # local_batch_size = BATCH_SIZE // coordinator.world_size + lr = LEARNING_RATE * coordinator.world_size + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == "hybrid_parallel": + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, + ) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + data_builder = GLUEDataBuilder( + model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE + ) + train_dataloader = data_builder.train_dataloader() + test_dataloader = data_builder.test_dataloader() + + # ==================================== + # Prepare model, optimizer + # ==================================== + # gpt2 pretrained model + + cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) + + if model_name == "gpt2": + model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() + else: + raise RuntimeError + + # optimizer + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) + + # lr scheduler + total_steps = len(train_dataloader) * NUM_EPOCHS + num_warmup_steps = int(WARMUP_FRACTION * total_steps) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + ) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, _criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler + ) + + # ============================== + # Train model + # ============================== + for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) + + results = evaluate_model( + model, + _criterion, + test_dataloader, + data_builder.num_labels, + args.task, + data_builder.eval_splits, + booster, + coordinator, + ) + + if coordinator.is_master(): + print(results) + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +if __name__ == "__main__": + main() diff --git a/examples/language/gpt/hybridparallelism/run.sh b/examples/language/gpt/hybridparallelism/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..679cbbf9b1e2eac23448ebc6f55fe61ef5f827f7 --- /dev/null +++ b/examples/language/gpt/hybridparallelism/run.sh @@ -0,0 +1,5 @@ +# load via internet +torchrun --standalone --nproc_per_node 4 --master_port 29800 finetune.py --target_f1 0.6 --plugin hybrid_parallel --model_type "gpt2" + +# load from local +# torchrun --standalone --nproc_per_node 4 --master_port 29800 finetune.py --target_f1 0.6 --plugin hybrid_parallel --model_type "gpt2" --pretrained_path "your/path/to/pretrained_model" diff --git a/examples/language/gpt/requirements.txt b/examples/language/gpt/requirements.txt index ef58bb76bfc86d9eda33a76ee7687b6204d65f9c..1a173f228aeed461ecfaf5f030fcb342c47716ef 100644 --- a/examples/language/gpt/requirements.txt +++ b/examples/language/gpt/requirements.txt @@ -1,2 +1,7 @@ transformers >= 4.23 colossalai +evaluate +tqdm +scipy +scikit-learn +numpy diff --git a/examples/language/gpt/test_ci.sh b/examples/language/gpt/test_ci.sh index d67c17229e711ba0cafb0837c260f813ce595537..db742220d97ec02f2ef80c06a640556a1f9f0d50 100644 --- a/examples/language/gpt/test_ci.sh +++ b/examples/language/gpt/test_ci.sh @@ -1,2 +1,5 @@ set -x +pip install -r requirements.txt + cd gemini && bash test_ci.sh +# cd ../hybridparallelism && bash run.sh diff --git a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py index 7bf53303948a641a5e74504e9979a336699f5aaf..bc3dcb85cf1ab89286429379b0639b636bad1d8c 100644 --- a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py +++ b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py @@ -11,8 +11,10 @@ HIDDEN_SIZE = 768 TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE) # if you do no want zero, just comment out this dictionary -zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()), - optimizer_config=dict(initial_scale=2**5)) +zero = dict( + model_config=dict(tensor_placement_policy="cuda", shard_strategy=TensorShardStrategy()), + optimizer_config=dict(initial_scale=2**5), +) optimizer = dict( type=HybridAdam, @@ -27,5 +29,5 @@ model = dict(type=GPT2_small_pipeline_hybrid, checkpoint=True, num_chunks=1) # for the current model implementation, mode can only be 1D or None parallel = dict( pipeline=1, - tensor=dict(size=2, mode='1d'), + tensor=dict(size=2, mode="1d"), ) diff --git a/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py index 9f9816b3004f8108dec385da6eff94499ce730e4..7413764dad81032e7cd66d04298be58668928861 100644 --- a/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py +++ b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py @@ -11,8 +11,10 @@ HIDDEN_SIZE = 12288 TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE) # if you do no want zero, just comment out this dictionary -zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()), - optimizer_config=dict(initial_scale=2**16)) +zero = dict( + model_config=dict(tensor_placement_policy="cuda", shard_strategy=TensorShardStrategy()), + optimizer_config=dict(initial_scale=2**16), +) optimizer = dict( type=HybridAdam, @@ -27,5 +29,5 @@ model = dict(type=GPT3_pipeline_hybrid, checkpoint=True, num_chunks=1) # for the current model implementation, mode can only be 1D or None parallel = dict( pipeline=1, - tensor=dict(size=2, mode='1d'), # for the current model implementation, mode can only be 1D or None + tensor=dict(size=2, mode="1d"), # for the current model implementation, mode can only be 1D or None ) diff --git a/examples/language/gpt/titans/dataset/webtext.py b/examples/language/gpt/titans/dataset/webtext.py index 64f5944a97f927a818b092613179fb064466a730..e61f73fd9ebae8b2517cc9f4360638cce3096f8f 100644 --- a/examples/language/gpt/titans/dataset/webtext.py +++ b/examples/language/gpt/titans/dataset/webtext.py @@ -6,17 +6,16 @@ import torch from torch.utils.data import Dataset from transformers import GPT2Tokenizer -from colossalai.registry import DATASETS +from colossalai.legacy.registry import DATASETS @DATASETS.register_module class WebtextDataset(Dataset): - def __init__(self, path: Optional[str] = None, seq_len=1024) -> None: super().__init__() if path is not None: root = os.path.dirname(path) - encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') + encoded_data_cache_path = os.path.join(root, f"gpt_webtext_{seq_len}.pt") if os.path.isfile(encoded_data_cache_path): seq_len_, data, attention_mask = torch.load(encoded_data_cache_path) if seq_len_ == seq_len: @@ -26,12 +25,12 @@ class WebtextDataset(Dataset): raw_data = [] with open(path) as f: for line in f.readlines(): - raw_data.append(json.loads(line)['text']) - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + raw_data.append(json.loads(line)["text"]) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.unk_token - encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') - self.data = encoded_data['input_ids'] - self.attention_mask = encoded_data['attention_mask'] + encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors="pt") + self.data = encoded_data["input_ids"] + self.attention_mask = encoded_data["attention_mask"] else: self.data = torch.randint(0, 50257, (10240, seq_len)) self.attention_mask = torch.ones_like(self.data) @@ -40,4 +39,4 @@ class WebtextDataset(Dataset): return len(self.data) def __getitem__(self, index): - return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index] + return {"input_ids": self.data[index], "attention_mask": self.attention_mask[index]}, self.data[index] diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py index 6369b9f8c5a136b534a316718cf195888e6d57cc..b2e3f71a53876524624335867ac7adbb097a29a3 100644 --- a/examples/language/gpt/titans/model/embed.py +++ b/examples/language/gpt/titans/model/embed.py @@ -1,18 +1,17 @@ import torch import torch.nn.init as init from torch import Tensor -from torch import distributed as dist from torch import nn as nn from torch.nn import functional as F from torch.nn.parameter import Parameter -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.nn.layer.base_layer import ParallelLayer -from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input -from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row -from colossalai.nn.layer.utils import divide -from colossalai.registry import LAYERS, LOSSES, MODELS +from colossalai.legacy.context import ParallelMode, seed +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.base_layer import ParallelLayer +from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input +from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row +from colossalai.legacy.nn.layer.utils import divide +from colossalai.legacy.registry import LAYERS, LOSSES from colossalai.utils import get_current_device @@ -30,13 +29,9 @@ class VocabParallelEmbedding(torch.nn.Module): will ignore this embedding """ - def __init__(self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - num_tokentypes=0, - dtype=torch.float): + def __init__( + self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes=0, dtype=torch.float + ): super(VocabParallelEmbedding, self).__init__() self.hidden_size = hidden_size @@ -44,11 +39,11 @@ class VocabParallelEmbedding(torch.nn.Module): # Word embeddings (parallel). self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype) - self._word_embeddings_key = 'word_embeddings' + self._word_embeddings_key = "word_embeddings" # Position embedding (serial). self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype) - self._position_embeddings_key = 'position_embeddings' + self._position_embeddings_key = "position_embeddings" # Initialize the position embeddings. # self.init_method(self.position_embeddings.weight) @@ -56,7 +51,7 @@ class VocabParallelEmbedding(torch.nn.Module): # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. - self._tokentype_embeddings_key = 'tokentype_embeddings' + self._tokentype_embeddings_key = "tokentype_embeddings" if self.num_tokentypes > 0: self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype) # Initialize the token-type embeddings. @@ -83,9 +78,9 @@ class VocabParallelEmbedding(torch.nn.Module): This allows us to load the model normally and then add this embedding. """ if self.tokentype_embeddings is not None: - raise Exception('tokentype embeddings is already initialized') + raise Exception("tokentype embeddings is already initialized") if torch.distributed.get_rank() == 0: - print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True) + print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True) self.num_tokentypes = num_tokentypes self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. @@ -112,19 +107,16 @@ class VocabParallelEmbedding(torch.nn.Module): embeddings = self.embedding_dropout(embeddings) return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): """For easy load.""" state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars) if self.num_tokentypes > 0: - state_dict_[self._tokentype_embeddings_key] \ - = self.tokentype_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict( + destination, prefix, keep_vars + ) return state_dict_ @@ -138,9 +130,8 @@ class VocabParallelEmbedding(torch.nn.Module): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] \ - = state_dict[key] + if "word_embeddings" in key: + state_dict_[key.split("word_embeddings.")[1]] = state_dict[key] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. @@ -150,9 +141,8 @@ class VocabParallelEmbedding(torch.nn.Module): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] + if "position_embeddings" in key: + state_dict_[key.split("position_embeddings.")[1]] = state_dict[key] self.position_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. @@ -163,15 +153,14 @@ class VocabParallelEmbedding(torch.nn.Module): else: # for backward compatibility. for key in state_dict.keys(): - if 'tokentype_embeddings' in key: - state_dict_[key.split('tokentype_embeddings.')[1]] \ - = state_dict[key] + if "tokentype_embeddings" in key: + state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key] if len(state_dict_.keys()) > 0: self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) else: - print('***WARNING*** expected tokentype embeddings in the ' - 'checkpoint but could not find it', - flush=True) + print( + "***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True + ) class VocabParallelEmbedding1D(torch.nn.Module): @@ -193,37 +182,41 @@ class VocabParallelEmbedding1D(torch.nn.Module): # Set the details for compatibility. self.padding_idx = None self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None self.tensor_model_parallel_size = gpc.tensor_parallel_size # Divide the weight matrix along the vocabulary dimension. - self.vocab_start_index, self.vocab_end_index = \ - VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D), - self.tensor_model_parallel_size) - self.num_embeddings_per_partition = self.vocab_end_index - \ - self.vocab_start_index + self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D), self.tensor_model_parallel_size + ) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index # Allocate weights and initialize. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs)) init.uniform_(self.weight, -1, 1) def forward(self, input_): if self.tensor_model_parallel_size > 1: # Build the mask. - input_mask = (input_ < self.vocab_start_index) | \ - (input_ >= self.vocab_end_index) + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) # Mask the input. masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 else: masked_input = input_ # Get the embeddings. - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type, - self.scale_grad_by_freq, self.sparse) + output_parallel = F.embedding( + masked_input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) # Mask the output embedding. if self.tensor_model_parallel_size > 1: output_parallel[input_mask, :] = 0.0 @@ -234,7 +227,6 @@ class VocabParallelEmbedding1D(torch.nn.Module): @LOSSES.register_module class vocab_parallel_cross_entropy(nn.Module): - def __init__(self): super().__init__() @@ -242,20 +234,19 @@ class vocab_parallel_cross_entropy(nn.Module): """Helper function for the cross entropy.""" vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous() target = target[..., 1:].contiguous() - return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)), - target.view(-1)) + return _VocabParallelCrossEntropy.apply( + vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)), target.view(-1) + ) class _VocabParallelCrossEntropy(torch.autograd.Function): - @staticmethod def forward(ctx, vocab_parallel_logits, target): - # Maximum value along vocab dimension across all GPUs. logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, - op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.PARALLEL_1D)) + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_1D) + ) # Subtract the maximum value. vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) @@ -282,17 +273,17 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): predicted_logits = predicted_logits_1d.view_as(target) predicted_logits[target_mask] = 0.0 # All reduce is needed to get the chunks from other GPUs. - torch.distributed.all_reduce(predicted_logits, - op=torch.distributed.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PARALLEL_1D)) + torch.distributed.all_reduce( + predicted_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D) + ) # Sum of exponential of logits along vocab dimension across all GPUs. exp_logits = vocab_parallel_logits torch.exp(vocab_parallel_logits, out=exp_logits) sum_exp_logits = exp_logits.sum(dim=-1) - torch.distributed.all_reduce(sum_exp_logits, - op=torch.distributed.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PARALLEL_1D)) + torch.distributed.all_reduce( + sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D) + ) # Loss = log(sum(exp(logits))) - predicted-logit. loss = torch.log(sum_exp_logits) - predicted_logits @@ -304,8 +295,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): - - # Retreive tensors from the forward path. + # Retrieve tensors from the forward path. softmax, target_mask, masked_target_1d = ctx.saved_tensors # All the inputs have softmax as their gradient. @@ -316,7 +306,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(grad_output.unsqueeze(dim=-1)) @@ -326,8 +316,8 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): class VocabUtility: """Split the vocabulary into `world_size` chunks amd return the - first and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last)""" + first and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last)""" @staticmethod def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size): @@ -393,11 +383,11 @@ class HiddenParallelEmbedding(torch.nn.Module): # Word embeddings (parallel). self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx) - self._word_embeddings_key = 'word_embeddings' + self._word_embeddings_key = "word_embeddings" # Position embedding (serial). self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) - self._position_embeddings_key = 'position_embeddings' + self._position_embeddings_key = "position_embeddings" # Initialize the position embeddings. # self.init_method(self.position_embeddings.weight) @@ -405,7 +395,7 @@ class HiddenParallelEmbedding(torch.nn.Module): # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. - self._tokentype_embeddings_key = 'tokentype_embeddings' + self._tokentype_embeddings_key = "tokentype_embeddings" if self.num_tokentypes > 0: self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. @@ -432,9 +422,9 @@ class HiddenParallelEmbedding(torch.nn.Module): This allows us to load the model normally and then add this embedding. """ if self.tokentype_embeddings is not None: - raise Exception('tokentype embeddings is already initialized') + raise Exception("tokentype embeddings is already initialized") if torch.distributed.get_rank() == 0: - print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True) + print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True) self.num_tokentypes = num_tokentypes self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. @@ -460,19 +450,16 @@ class HiddenParallelEmbedding(torch.nn.Module): embeddings = self.embedding_dropout(embeddings) return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): """For easy load.""" state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars) if self.num_tokentypes > 0: - state_dict_[self._tokentype_embeddings_key] \ - = self.tokentype_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict( + destination, prefix, keep_vars + ) return state_dict_ @@ -486,9 +473,8 @@ class HiddenParallelEmbedding(torch.nn.Module): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] \ - = state_dict[key] + if "word_embeddings" in key: + state_dict_[key.split("word_embeddings.")[1]] = state_dict[key] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. @@ -498,9 +484,8 @@ class HiddenParallelEmbedding(torch.nn.Module): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] + if "position_embeddings" in key: + state_dict_[key.split("position_embeddings.")[1]] = state_dict[key] self.position_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. @@ -511,15 +496,14 @@ class HiddenParallelEmbedding(torch.nn.Module): else: # for backward compatibility. for key in state_dict.keys(): - if 'tokentype_embeddings' in key: - state_dict_[key.split('tokentype_embeddings.')[1]] \ - = state_dict[key] + if "tokentype_embeddings" in key: + state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key] if len(state_dict_.keys()) > 0: self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) else: - print('***WARNING*** expected tokentype embeddings in the ' - 'checkpoint but could not find it', - flush=True) + print( + "***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True + ) class HiddenParallelEmbedding1D(torch.nn.Module): @@ -542,21 +526,21 @@ class HiddenParallelEmbedding1D(torch.nn.Module): # Set the details for compatibility. self.padding_idx = padding_idx self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None # Allocate weights and initialize. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs)) init.uniform_(self.weight, -1, 1) def forward(self, input_): - # Get the embeddings. - output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type, - self.scale_grad_by_freq, self.sparse) + output_parallel = F.embedding( + input_, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse + ) # Reduce across all the model parallel GPUs. output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) @@ -584,11 +568,9 @@ class HiddenParallelGPTLMHead1D(ParallelLayer): # self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx) # (hidden_size/q, vocab_size) self.synced_embed = False - self.head = Linear1D_Row(in_features=embed_dim, - out_features=vocab_size, - bias=False, - dtype=dtype, - parallel_input=False) + self.head = Linear1D_Row( + in_features=embed_dim, out_features=vocab_size, bias=False, dtype=dtype, parallel_input=False + ) def forward(self, x: Tensor) -> Tensor: if self.synced_embed: diff --git a/examples/language/gpt/titans/model/gpt1d.py b/examples/language/gpt/titans/model/gpt1d.py index 2edd03606b7da3c7ba5418ec2ace2f1841576670..f8e2f42e11cb7a1211d56d0ac1d7e43e808349f6 100644 --- a/examples/language/gpt/titans/model/gpt1d.py +++ b/examples/language/gpt/titans/model/gpt1d.py @@ -9,27 +9,30 @@ from torch import nn as nn from colossalai import kernel from colossalai import nn as col_nn -from colossalai.core import global_context as gpc from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType -from colossalai.nn.layer import Linear1D_Col, Linear1D_Row -from colossalai.nn.layer.base_layer import ParallelLayer -from colossalai.nn.layer.utils import ACT2FN, divide +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer import Linear1D_Col, Linear1D_Row +from colossalai.legacy.nn.layer.base_layer import ParallelLayer +from colossalai.legacy.nn.layer.utils import ACT2FN, divide +from colossalai.legacy.utils.activation_checkpoint import checkpoint from colossalai.utils import checkpoint -from colossalai.utils.activation_checkpoint import checkpoint __all__ = [ - 'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D' + "GPTMLP1D", + "GPTSelfAttention1D", + "GPTTransformerLayer1D", + "FusedGPTSelfAttention1D", + "FusedGPTTransformerLayer1D", ] class GPTMLP1D(ParallelLayer): - def __init__( self, in_features: int, mlp_ratio: int, - act_func: str = 'gelu', - dropout_prob: float = 0., + act_func: str = "gelu", + dropout_prob: float = 0.0, dtype=None, checkpoint: bool = False, skip_bias_add: bool = False, @@ -82,7 +85,6 @@ class GPTMLP1D(ParallelLayer): class GenericGPTSelfAttention1D(ParallelLayer): - def __init__( self, hidden_size: int, @@ -118,8 +120,10 @@ class GenericGPTSelfAttention1D(ParallelLayer): def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor: query_key_value = self.query_key_value(hidden_states) - new_qkv_shape = query_key_value.shape[:-1] + \ - (self.num_attention_heads_per_partition, 3 * self.attention_head_size) + new_qkv_shape = query_key_value.shape[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.attention_head_size, + ) query_key_value = query_key_value.view(new_qkv_shape) query_key_value = query_key_value.permute((0, 2, 1, 3)) query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1) @@ -152,28 +156,32 @@ class GenericGPTSelfAttention1D(ParallelLayer): class GPTSelfAttention1D(GenericGPTSelfAttention1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - attention_dropout_prob: float, - hidden_dropout_prob: float, - dtype=None, - checkpoint: bool = False, - max_position_embeddings=1024): - super().__init__(hidden_size, - num_attention_heads, - attention_dropout_prob, - hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings) + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + checkpoint: bool = False, + max_position_embeddings=1024, + ): + super().__init__( + hidden_size, + num_attention_heads, + attention_dropout_prob, + hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + ) self.softmax = nn.Softmax(dim=-1) max_positions = max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), - dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), ) self.register_buffer("masked_bias", torch.tensor(-1e4)) @@ -181,7 +189,7 @@ class GPTSelfAttention1D(GenericGPTSelfAttention1D): attention_scores = attention_scores / math.sqrt(self.attention_head_size) # causal mask query_length, key_length = query_layer.size(-2), key_layer.size(-2) - causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool() + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores)) if attention_mask is not None: # Apply the attention mask @@ -191,50 +199,56 @@ class GPTSelfAttention1D(GenericGPTSelfAttention1D): class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - attention_dropout_prob: float, - hidden_dropout_prob: float, - dtype=None, - checkpoint: bool = False, - max_position_embeddings=1024): - super().__init__(hidden_size, - num_attention_heads, - attention_dropout_prob, - hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings) - self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True, - input_in_bf16=False, - attn_mask_type=AttnMaskType.causal, - scaled_masked_softmax_fusion=True, - mask_func=None, - softmax_in_fp32=True, - scale=math.sqrt(self.attention_head_size)) + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + checkpoint: bool = False, + max_position_embeddings=1024, + ): + super().__init__( + hidden_size, + num_attention_heads, + attention_dropout_prob, + hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + ) + self.softmax = kernel.FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=True, + mask_func=None, + softmax_in_fp32=True, + scale=math.sqrt(self.attention_head_size), + ) def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer): return self.softmax(attention_scores, attention_mask) class GenericGPTTransformerLayer1D(ParallelLayer): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - act_func: str = 'gelu', - mlp_ratio: float = 4.0, - attention_dropout_prob: float = 0., - hidden_dropout_prob: float = 0., - dtype=None, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - attention=None, - layer_norm=None): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + act_func: str = "gelu", + mlp_ratio: float = 4.0, + attention_dropout_prob: float = 0.0, + hidden_dropout_prob: float = 0.0, + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + attention=None, + layer_norm=None, + ): super().__init__() self.checkpoint = checkpoint self.dtype = dtype @@ -288,62 +302,68 @@ class GenericGPTTransformerLayer1D(ParallelLayer): class GPTTransformerLayer1D(GenericGPTTransformerLayer1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - act_func: str = 'gelu', - mlp_ratio: float = 4, - attention_dropout_prob: float = 0, - hidden_dropout_prob: float = 0, - dtype=None, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 0.00001, - apply_post_layer_norm: bool = False): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + act_func: str = "gelu", + mlp_ratio: float = 4, + attention_dropout_prob: float = 0, + hidden_dropout_prob: float = 0, + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 0.00001, + apply_post_layer_norm: bool = False, + ): attention = GPTSelfAttention1D layer_norm = nn.LayerNorm - super().__init__(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attention_dropout_prob, - hidden_dropout_prob=hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm, - attention=attention, - layer_norm=layer_norm) + super().__init__( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attention_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + attention=attention, + layer_norm=layer_norm, + ) class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - act_func: str = 'gelu', - mlp_ratio: float = 4, - attention_dropout_prob: float = 0, - hidden_dropout_prob: float = 0, - dtype=None, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 0.00001, - apply_post_layer_norm: bool = False): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + act_func: str = "gelu", + mlp_ratio: float = 4, + attention_dropout_prob: float = 0, + hidden_dropout_prob: float = 0, + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 0.00001, + apply_post_layer_norm: bool = False, + ): attention = FusedGPTSelfAttention1D layer_norm = kernel.LayerNorm - super().__init__(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attention_dropout_prob, - hidden_dropout_prob=hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm, - attention=attention, - layer_norm=layer_norm) + super().__init__( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attention_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + attention=attention, + layer_norm=layer_norm, + ) diff --git a/examples/language/gpt/titans/model/pipeline_gpt1d.py b/examples/language/gpt/titans/model/pipeline_gpt1d.py index 30180285bc70fefaec5610a4b2d284a946a25d2c..83158cb44e0c799ac0d6b1804c0d81ea25bde182 100644 --- a/examples/language/gpt/titans/model/pipeline_gpt1d.py +++ b/examples/language/gpt/titans/model/pipeline_gpt1d.py @@ -7,27 +7,26 @@ import torch.nn as nn from colossalai import kernel from colossalai import nn as col_nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.legacy.pipeline.utils import partition_uniform from colossalai.logging import get_dist_logger -from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper -from colossalai.pipeline.utils import partition_uniform from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D __all__ = [ - 'GPT2_small_pipeline_1D', - 'GPT2_exlarge_pipeline_1D', - 'GPT3_pipeline_1D', - 'GPT2_exlarge_pipeline_hybrid', - 'GPT2_small_pipeline_hybrid', - 'GPT3_pipeline_hybrid', + "GPT2_small_pipeline_1D", + "GPT2_exlarge_pipeline_1D", + "GPT3_pipeline_1D", + "GPT2_exlarge_pipeline_hybrid", + "GPT2_small_pipeline_hybrid", + "GPT3_pipeline_hybrid", ] class GenericPipelineGPT(nn.Module): - def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None: super().__init__() self.embedding = embedding @@ -44,7 +43,7 @@ class GenericPipelineGPT(nn.Module): batch_size = hidden_states.shape[0] attention_mask = attention_mask.view(batch_size, -1) attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 for block in self.blocks: hidden_states, attention_mask = block(hidden_states, attention_mask) @@ -54,25 +53,26 @@ class GenericPipelineGPT(nn.Module): class PipelineGPT1D(GenericPipelineGPT): - - def __init__(self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: int = 50304, - embed_drop_rate: float = 0., - act_func: str = 'gelu', - mlp_ratio: int = 4.0, - attn_drop_rate: float = 0., - drop_rate: float = 0., - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - first: bool = False, - last: bool = False, - embed_split_hidden=False): + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0.0, + act_func: str = "gelu", + mlp_ratio: int = 4.0, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False, + ): embedding = None norm = None head = None @@ -83,19 +83,24 @@ class PipelineGPT1D(GenericPipelineGPT): head_cls = HiddenParallelGPTLMHead1D if first: embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype) - blocks = nn.ModuleList([ - GPTTransformerLayer1D(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attn_drop_rate, - hidden_dropout_prob=drop_rate, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers) - ]) + blocks = nn.ModuleList( + [ + GPTTransformerLayer1D( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attn_drop_rate, + hidden_dropout_prob=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + ) + for _ in range(num_layers) + ] + ) if last: norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype) @@ -103,25 +108,26 @@ class PipelineGPT1D(GenericPipelineGPT): class FusedPipelineGPT1D(GenericPipelineGPT): - - def __init__(self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: int = 50304, - embed_drop_rate: float = 0., - act_func: str = 'gelu', - mlp_ratio: int = 4.0, - attn_drop_rate: float = 0., - drop_rate: float = 0., - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - first: bool = False, - last: bool = False, - embed_split_hidden=False): + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0.0, + act_func: str = "gelu", + mlp_ratio: int = 4.0, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False, + ): embedding = None norm = None head = None @@ -132,19 +138,24 @@ class FusedPipelineGPT1D(GenericPipelineGPT): head_cls = HiddenParallelGPTLMHead1D if first: embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype) - blocks = nn.ModuleList([ - FusedGPTTransformerLayer1D(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attn_drop_rate, - hidden_dropout_prob=drop_rate, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers) - ]) + blocks = nn.ModuleList( + [ + FusedGPTTransformerLayer1D( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attn_drop_rate, + hidden_dropout_prob=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + ) + for _ in range(num_layers) + ] + ) if last: norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon) head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype) @@ -153,7 +164,7 @@ class FusedPipelineGPT1D(GenericPipelineGPT): def forward(self, hidden_states=None, input_ids=None, attention_mask=None): if self.embedding is not None: hidden_states = self.embedding(input_ids=input_ids) - attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility for block in self.blocks: hidden_states, attention_mask = block(hidden_states, attention_mask) if self.norm is not None: @@ -162,44 +173,48 @@ class FusedPipelineGPT1D(GenericPipelineGPT): class PipelineGPTHybrid(GenericPipelineGPT): - - def __init__(self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: int = 50304, - embed_drop_rate: float = 0., - act_func: str = 'gelu', - mlp_ratio: int = 4, - attn_drop_rate: float = 0., - drop_rate: float = 0., - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - first: bool = False, - last: bool = False, - embed_split_hidden=False): + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0.0, + act_func: str = "gelu", + mlp_ratio: int = 4, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False, + ): embedding = None norm = None head = None if first: - embedding = col_gpt.GPTEmbedding(hidden_size, - vocab_size, - max_position_embeddings, - dropout=embed_drop_rate, - dtype=dtype) - blocks = nn.ModuleList([ - col_gpt.GPTBlock(hidden_size, - num_attention_heads, - mlp_ratio=mlp_ratio, - attention_dropout=attn_drop_rate, - dropout=drop_rate, - dtype=dtype, - checkpoint=checkpoint, - activation=nn.functional.gelu) for _ in range(num_layers) - ]) + embedding = col_gpt.GPTEmbedding( + hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype + ) + blocks = nn.ModuleList( + [ + col_gpt.GPTBlock( + hidden_size, + num_attention_heads, + mlp_ratio=mlp_ratio, + attention_dropout=attn_drop_rate, + dropout=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + activation=nn.functional.gelu, + ) + for _ in range(num_layers) + ] + ) if last: norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) # head = col_gpt.GPTLMHead(vocab_size=vocab_size, @@ -215,7 +230,7 @@ def _filter_kwargs(func, kwargs): return {k: v for k, v in kwargs.items() if k in sig.parameters} -def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs): +def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device("cuda"), **kwargs): logger = get_dist_logger() if gpc.is_initialized(ParallelMode.PIPELINE): @@ -233,10 +248,10 @@ def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=to parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] models = [] for start, end in parts: - kwargs['num_layers'] = end - start - kwargs['first'] = start == 0 - kwargs['last'] = end == num_layers - logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + kwargs["num_layers"] = end - start + kwargs["first"] = start == 0 + kwargs["last"] = end == num_layers + logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers") chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device) if wrapper is not None: @@ -253,70 +268,82 @@ def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=to numel = 0 for _, param in model.named_parameters(recurse=True): numel += param.numel() - logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB') + logger.info(f"Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB") return model -def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs): +def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device("cuda"), fused=False, **kwargs): model = FusedPipelineGPT1D if fused else PipelineGPT1D return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs) -def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): +def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs) def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): - cfg = dict(hidden_size=768, - num_attention_heads=12, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=768, + num_attention_heads=12, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg) def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): - cfg = dict(hidden_size=1600, - num_attention_heads=32, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=1600, + num_attention_heads=32, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg) def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): - cfg = dict(hidden_size=12288, - num_attention_heads=96, - checkpoint=checkpoint, - max_position_embeddings=2048, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=12288, + num_attention_heads=96, + checkpoint=checkpoint, + max_position_embeddings=2048, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg) def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): - cfg = dict(hidden_size=1600, - num_attention_heads=32, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=1600, + num_attention_heads=32, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg) def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): - cfg = dict(hidden_size=768, - num_attention_heads=12, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=768, + num_attention_heads=12, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg) def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): - cfg = dict(hidden_size=12288, - num_attention_heads=96, - checkpoint=checkpoint, - max_position_embeddings=2048, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=12288, + num_attention_heads=96, + checkpoint=checkpoint, + max_position_embeddings=2048, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg) diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py index 66225d6c80447a147c70f2937ca2d1ad9e9aa5c4..565cf1e016ccf28a9948b2997726c2f6b2f32b57 100644 --- a/examples/language/gpt/titans/train_gpt.py +++ b/examples/language/gpt/titans/train_gpt.py @@ -1,3 +1,4 @@ +import argparse import contextlib import os @@ -8,14 +9,14 @@ from titans.model.gpt import GPTLMLoss import colossalai import colossalai.utils as utils -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.trainer import Trainer, hooks +from colossalai.legacy.zero.init_ctx import ZeroInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn import LinearWarmupLR -from colossalai.trainer import Trainer, hooks -from colossalai.utils import colo_set_process_memory_fraction, is_using_pp +from colossalai.utils import is_using_pp from colossalai.utils.timer import MultiTimer -from colossalai.zero.init_ctx import ZeroInitContext def calc_local_model_size(model: torch.nn.Module): @@ -29,9 +30,9 @@ VOCAB_SIZE = 50257 def main(): - parser = colossalai.get_default_parser() - parser.add_argument('--from_torch', default=False, action='store_true') - parser.add_argument('--use_dummy_dataset', default=False, action='store_true') + parser = argparse.ArgumentParser() + parser.add_argument("--from_torch", default=False, action="store_true") + parser.add_argument("--use_dummy_dataset", default=False, action="store_true") args = parser.parse_args() disable_existing_loggers() if args.from_torch: @@ -40,28 +41,27 @@ def main(): colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42) logger = get_dist_logger() - data_path = None if args.use_dummy_dataset else os.environ['DATA'] - logger.info(f'Build data loader from path {data_path}', ranks=[0]) + data_path = None if args.use_dummy_dataset else os.environ["DATA"] + logger.info(f"Build data loader from path {data_path}", ranks=[0]) train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN) - train_dataloader = utils.get_dataloader(train_ds, - seed=42, - batch_size=gpc.config.BATCH_SIZE, - pin_memory=True, - shuffle=True, - drop_last=True) - - logger.info('Build model', ranks=[0]) + train_dataloader = utils.get_dataloader( + train_ds, seed=42, batch_size=gpc.config.BATCH_SIZE, pin_memory=True, shuffle=True, drop_last=True + ) + + logger.info("Build model", ranks=[0]) use_pipeline = is_using_pp() - use_interleaved = hasattr(gpc.config.model, 'num_chunks') - use_zero3 = hasattr(gpc.config, 'zero') + use_interleaved = hasattr(gpc.config.model, "num_chunks") + use_zero3 = hasattr(gpc.config, "zero") ctx = contextlib.nullcontext() if use_zero3: - ctx = ZeroInitContext(target_device=torch.cuda.current_device(), - shard_strategy=gpc.config.zero.model_config.shard_strategy, - shard_param=True) + ctx = ZeroInitContext( + target_device=torch.cuda.current_device(), + shard_strategy=gpc.config.zero.model_config.shard_strategy, + shard_param=True, + ) with ctx: - model = gpc.config.model.pop('type')(**gpc.config.model) + model = gpc.config.model.pop("type")(**gpc.config.model) if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList): model = nn.ModuleList([model]) @@ -70,25 +70,31 @@ def main(): else: numel = calc_local_model_size(model) - tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \ - * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4) - - criterion = getattr(gpc.config, 'loss_fn', None) + tflop = ( + numel + * gpc.config.BATCH_SIZE + * gpc.config.SEQ_LEN + * gpc.get_world_size(ParallelMode.MODEL) + * gpc.get_world_size(ParallelMode.DATA) + * 8 + / (1024**4) + ) + + criterion = getattr(gpc.config, "loss_fn", None) if criterion is not None: criterion = criterion.type() else: criterion = GPTLMLoss() - logger.info('Build optimizer', ranks=[0]) - optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer) + logger.info("Build optimizer", ranks=[0]) + optimizer = gpc.config.optimizer.pop("type")(model.parameters(), **gpc.config.optimizer) lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5) - engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader=train_dataloader, - lr_scheduler=lr_scheduler) - global_batch_size = gpc.config.BATCH_SIZE * \ - gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1) - logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0]) + engine, train_dataloader, _, lr_scheduler = colossalai.initialize( + model, optimizer, criterion, train_dataloader=train_dataloader, lr_scheduler=lr_scheduler + ) + global_batch_size = ( + gpc.config.BATCH_SIZE * gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1) + ) + logger.info(f"Init done, global batch size = {global_batch_size}", ranks=[0]) timier = MultiTimer() trainer = Trainer(engine=engine, logger=logger, timer=timier) hook_list = [ @@ -98,16 +104,18 @@ def main(): hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop), hooks.LogMetricByStepHook(), hooks.LogMemoryByEpochHook(logger), - # hooks.LogMemoryByEpochHook(logger), - # hooks.LogTimingByEpochHook(timer, logger), + # hooks.LogMemoryByEpochHook(logger), + # hooks.LogTimingByEpochHook(timer, logger), ] - trainer.fit(train_dataloader=train_dataloader, - epochs=gpc.config.NUM_EPOCHS, - test_interval=1, - hooks=hook_list, - display_progress=True, - return_output_label=False) + trainer.fit( + train_dataloader=train_dataloader, + epochs=gpc.config.NUM_EPOCHS, + test_interval=1, + hooks=hook_list, + display_progress=True, + return_output_label=False, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..83ef99b57d420e8edf425744c5a98bbc1a293d61 --- /dev/null +++ b/examples/language/llama2/README.md @@ -0,0 +1,234 @@ +# Pretraining LLaMA-1/2: best practices for building LLaMA-1/2-like base models + +### LLaMA2 +

              + +

              + +- 70 billion parameter LLaMA2 model training accelerated by 195% +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) +[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training) + +### LLaMA1 +

              + +

              + +- 65-billion-parameter large model pretraining accelerated by 38% +[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama) +[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining) + +## Dataset + +Different from the original LLaMA, we use [RedPajama](https://www.together.xyz/blog/redpajama) dataset, which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB to download compressed. + +A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, which is available at [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample). + +RedPajama-Data-1T consists of seven data slices: + +| | RedPajama | LLaMA | +|---------------|--------------|---------------| +| CommonCrawl | 878 billion | 852 billion | +| C4 | 175 billion | 190 billion | +| Github | 59 billion | 100 billion | +| Books | 26 billion | 25 billion | +| ArXiv | 28 billion | 33 billion | +| Wikipedia | 24 billion | 25 billion | +| StackExchange | 20 billion | 27 billion | +| Total | 1.2 trillion | 1.25 trillion | + +## Training + +We follow the hyperparameter settings from the original LLaMA paper. We use AdamW with $beta1=0.9$ and $beta2=0.95$. We use a cosine learning rate schedule, such that the final learning rate is equal to 10% of the maximal learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. We use 2,000 warmup steps. + +| params | learning rate | batch size | +|--------|---------------|------------| +| 6.7B | 3.0e-4 | 4M | +| 13.0B | 3.0e-4 | 4M | +| 32.5B | 1.5e-4 | 4M | +| 65.2B | 1.5e-4 | 4M | + +## Usage + +### 1. Installation + +Please install the latest ColossalAI from source. + +```bash +CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI +``` + +Then install other dependencies. + +```bash +pip install -r requirements.txt +``` + +Additionally, we recommend you to use torch 1.13.1. We've tested our code on torch 1.13.1 and found it's compatible with our code and flash attention. + +### 2. Download the dataset + +The dataset can be automatically downloaded by using `huggingface/datasets`. You can specify the dataset path by `-d` or `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. + +### 3. Command line arguments + +Yon can use colossalai run to launch multi-nodes training: +```bash +colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ +pretrain.py --OTHER_CONFIGURATIONS +``` + +Here is a sample hostfile: + +```text +hostname1 +hostname2 +hostname3 +hostname4 +``` + +Make sure master node can access all nodes (including itself) by ssh without password. + +Here is details about CLI arguments: + +- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1, `7b`, `13b`, and `70b` are supported for LLaMA-2. +- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). +- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It support any dataset from `datasets` with the same data format as RedPajama. +- Number of epochs: `-e`, `--num_epochs`. The default value is 1. +- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2. +- Learning rate: `--lr`. The default value is 3e-4. +- Weight decay: `-w`, `--weight_decay`. The default value is 0.1. +- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000. +- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size. +- Max length: `-l`, `--max_length`. The default value is 4096. +- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported. +- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000. +- Checkpoint directory: `-o`, `--save_dir`. The directoty path to save checkpoints. The default value is `checkpoint`. +- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`. +- Gradient clipping: `--gradient_clipping`. The default value is 1.0. +- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`. +- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention. + + +### 4. Shell Script Examples + +For your convenience, we provide some shell scripts to run benchmark with various configurations. + +You can find them in `scripts/benchmark_7B` and `scripts/benchmark_70B` directory. The main command should be in the format of: +```bash +colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ +benchmark.py --OTHER_CONFIGURATIONS +``` +Here we will show an example of how to run training +llama pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`. + +#### a. Running environment +This experiment was performed on 4 computing nodes with 32 A800 GPUs in total for LLaMA-1 65B. The nodes are +connected with RDMA and GPUs within one node are fully connected with NVLink. + +#### b. Running command + +```bash +cd scripts/benchmark_7B +``` + +First, put your host file (`hosts.txt`) in this directory with your real host ip or host name. + +Here is a sample `hosts.txt`: +```text +hostname1 +hostname2 +hostname3 +hostname4 +``` + +Then add environment variables to script if needed. + +Finally, run the following command to start training: + +```bash +bash gemini.sh +``` + +If you encounter out-of-memory(OOM) error during training with script `gemini.sh`, changing to script `gemini_auto.sh` might be a solution, since gemini_auto will set a upper limit on GPU memory usage through offloading part of the model parameters and optimizer states back to CPU memory. But there's a trade-off: `gemini_auto.sh` will be a bit slower, since more data are transmitted between CPU and GPU. + +#### c. Results +If you run the above command successfully, you will get the following results: +`max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`. + + +## Reference +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +``` + +```bibtex +@software{openlm2023openllama, + author = {Geng, Xinyang and Liu, Hao}, + title = {OpenLLaMA: An Open Reproduction of LLaMA}, + month = May, + year = 2023, + url = {https://github.com/openlm-research/open_llama} +} +``` + +```bibtex +@software{together2023redpajama, + author = {Together Computer}, + title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset}, + month = April, + year = 2023, + url = {https://github.com/togethercomputer/RedPajama-Data} +} +``` + +```bibtex +@article{touvron2023llama, + title={Llama: Open and efficient foundation language models}, + author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others}, + journal={arXiv preprint arXiv:2302.13971}, + year={2023} +} +``` + + +# Fine-tune Llama2 + +We also provide a example to fine-tune llama2 in `finetune.py`, + +Make sure master node can access all nodes (including itself) by ssh without password. + +Here is details about CLI arguments: + +- Pretrained checkpoint path: `--model_path`, the path of your model checkpoint, it can be your local directory or a Hugging Face tag. +- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). +- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It support any dataset from `datasets` with the same data format as `yizhongw/self_instruct`. +- task name: `--task_name`, the task to fine-tune, it's also related to the target of loading dataset, The default value is `super_natural_instructions`. +- Number of epochs: `-e`, `--num_epochs`. The default value is 1. +- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2. +- Learning rate: `--lr`. The default value is 3e-4. +- Weight decay: `-w`, `--weight_decay`. The default value is 0.1. +- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size. +- Max length: `-l`, `--max_length`. The default value is 4096. +- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported. +- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000. +- Checkpoint directory: `-o`, `--save_dir`. The directoty path to save checkpoints. The default value is `checkpoint`. +- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`. +- Gradient clipping: `--gradient_clipping`. The default value is 1.0. +- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`. +- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention. + + +```shell +torchrun --standalone --nproc_per_node 8 finetune.py \ + --plugin "hybrid_parallel" \ + --dataset "yizhongw/self_instruct" \ + --model_path "/path/llama" \ + --task_name "super_natural_instructions" \ + --save_dir "/path/output" +``` diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2356b18b70d93d36f8c3de6343accccd5ba31d --- /dev/null +++ b/examples/language/llama2/attn.py @@ -0,0 +1,84 @@ +from types import MethodType +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv + +SUPPORT_XFORMERS = False +SUPPORT_FLASH2 = False +try: + import xformers.ops as xops + + SUPPORT_XFORMERS = True +except ImportError: + pass + +try: + from flash_attn import flash_attn_func + + SUPPORT_FLASH2 = True +except ImportError: + pass + +SUPPORT_FLASH = SUPPORT_XFORMERS or SUPPORT_FLASH2 + + +def llama_flash_attention( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # q, k, v is [B, H, S, K] and xformers need [B, S, H, K]. returns [B, S, H, K] + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + if SUPPORT_FLASH2: + attn_output = flash_attn_func(query_states, key_states, value_states, causal=True) + else: + attn_output = xops.memory_efficient_attention( + query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask() + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def replace_xformers(model: nn.Module): + for module in model.modules(): + if isinstance(module, LlamaAttention): + module.forward = MethodType(llama_flash_attention, module) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..ce13ebbf617d4a9f90308b04ca910312dc63b1dd --- /dev/null +++ b/examples/language/llama2/benchmark.py @@ -0,0 +1,223 @@ +import argparse +import resource +from contextlib import nullcontext + +import torch +from attn import SUPPORT_FLASH, replace_xformers +from data_utils import RandomDataset +from model_utils import format_numel_str, get_model_numel +from performance_evaluator import PerformanceEvaluator +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision +from tqdm import tqdm +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Constants +# ============================== + +MODEL_CONFIGS = { + "7b": LlamaConfig(max_position_embeddings=4096), + "13b": LlamaConfig( + hidden_size=5120, + intermediate_size=13824, + num_hidden_layers=40, + num_attention_heads=40, + max_position_embeddings=4096, + ), + "70b": LlamaConfig( + hidden_size=8192, + intermediate_size=28672, + num_hidden_layers=80, + num_attention_heads=64, + max_position_embeddings=4096, + num_key_value_heads=8, + ), +} + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration") + parser.add_argument( + "-p", + "--plugin", + choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"], + default="gemini", + help="Choose which plugin to use", + ) + parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size") + parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") + parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument( + "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto" + ) + parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb") + parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers") + parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini") + parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") + parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") + parser.add_argument("--mbs", type=int, default=1) + parser.add_argument("--zero", type=int, default=0) + args = parser.parse_args() + + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + def empty_init(): + pass + + # ============================== + # Initialize Booster + # ============================== + use_empty_init = True + if args.plugin == "gemini": + plugin = GeminiPlugin( + precision="bf16", + shard_param_frac=args.shard_param_frac, + offload_optim_frac=args.offload_optim_frac, + offload_param_frac=args.offload_param_frac, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio) + elif args.plugin == "fsdp": + if use_empty_init: + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), + param_init_fn=empty_init(), + ) + else: + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ) + ) + elif args.plugin == "fsdp_cpu": + if use_empty_init: + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), + cpu_offload=CPUOffload(offload_params=True), + param_init_fn=empty_init(), + ) + else: + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), + cpu_offload=CPUOffload(offload_params=True), + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + zero_stage=args.zero, + enable_fused_normalization=True, + num_microbatches=args.mbs, + precision="bf16", + ) + elif args.plugin == "3d_cpu": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + zero_stage=args.zero, + cpu_offload=True, + enable_fused_normalization=True, + num_microbatches=args.mbs, + initial_scale=2**8, + precision="bf16", + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ============================== + # Initialize Dataset and Dataloader + # ============================== + dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size + + config = MODEL_CONFIGS[args.config] + dataset = RandomDataset( + num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) + + # ============================== + # Initialize Model and Optimizer + # ============================== + init_ctx = ( + LazyInitContext(default_device=get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() + ) + + with init_ctx: + model = LlamaForCausalLM(config) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + + if args.xformers: + assert SUPPORT_FLASH, "Use flash attention while xfomers is not installed" + replace_xformers(model) + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + performance_evaluator = PerformanceEvaluator( + model_numel, args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size + ) + + optimizer = HybridAdam(model.parameters()) + torch.set_default_dtype(torch.bfloat16) + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + torch.set_default_dtype(torch.float) + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) + + if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: + data_iter = iter(dataloader) + for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): + performance_evaluator.on_step_start(step) + booster.execute_pipeline( + data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False + ) + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) + else: + for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): + performance_evaluator.on_step_start(step) + outputs = model(**batch) + loss = outputs[0] + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(**batch) + + performance_evaluator.on_fit_end() + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + main() diff --git a/examples/language/llama2/data_utils.py b/examples/language/llama2/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a438833e1680ba660a506e0b5ab15ab22304fe76 --- /dev/null +++ b/examples/language/llama2/data_utils.py @@ -0,0 +1,122 @@ +import json +import random +from typing import Iterator, Optional + +import numpy as np +import torch +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group +from torch.utils.data import DataLoader, Dataset, DistributedSampler + +from colossalai.utils import get_current_device + + +class StatefulDistributedSampler(DistributedSampler): + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + self.start_index: int = 0 + + def __iter__(self) -> Iterator: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index :] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def set_start_index(self, start_index: int) -> None: + self.start_index = start_index + + +def prepare_dataloader( + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + process_group: Optional[ProcessGroup] = None, + **kwargs, +): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `StatefulDistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + process_group = process_group or _get_default_group() + sampler = StatefulDistributedSampler( + dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle + ) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + + +def load_json(file_path: str): + with open(file_path, "r") as f: + return json.load(f) + + +def save_json(data, file_path: str): + with open(file_path, "w") as f: + json.dump(data, f, indent=4) + + +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..33aa1d33e6ba30c32c1223ca50e38cdaf69b28e2 --- /dev/null +++ b/examples/language/llama2/finetune.py @@ -0,0 +1,313 @@ +import argparse +import math +import os +import resource +from contextlib import nullcontext +from functools import partial +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from attn import SUPPORT_XFORMERS, replace_xformers +from data_utils import load_json, prepare_dataloader, save_json +from datasets import load_dataset +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM +from transformers.models.llama.tokenization_llama import LlamaTokenizer + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + + +def get_model_numel(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f"{numel / B:.2f} B" + elif numel >= M: + return f"{numel / M:.2f} M" + elif numel >= K: + return f"{numel / K:.2f} K" + else: + return f"{numel}" + + +def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): + texts = [sample["prompt"] + sample["completion"] for sample in batch] + data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) + data = {k: v.cuda() for k, v in data.items()} + data["labels"] = data["input_ids"].clone() + return data + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) + return tensor + + +def save( + booster: Booster, + model: nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, + save_dir: str, +): + save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "model"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + running_states = { + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + + +def load( + booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str +) -> Tuple[int, int, int]: + booster.load_model(model, os.path.join(load_dir, "model")) + booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) + running_states = load_json(os.path.join(load_dir, "running_states.json")) + return running_states["epoch"], running_states["step"], running_states["sample_start_index"] + + +def _criterion(outputs, inputs): + return outputs.loss + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, help="pretrained checkpoint path, used with mode==finetune") + parser.add_argument( + "-p", + "--plugin", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], + default="gemini", + help="Choose which plugin to use", + ) + parser.add_argument("-d", "--dataset", type=str, default="yizhongw/self_instruct", help="Data set path") + parser.add_argument("--task_name", type=str, default="super_natural_instructions", help="task to run") + parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") + parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") + parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") + parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") + parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") + args = parser.parse_args() + + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "gemini": + plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip + ) + elif args.plugin == "hybrid_parallel": + # modify the param accordingly, default configuration is for llama2-7b + plugin = HybridParallelPlugin( + tp_size=4, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_jit_fused=False, + zero_stage=0, + precision="fp32", + initial_scale=1, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) + + # ============================== + # Initialize Tensorboard + # ============================== + if print_flag: + os.makedirs(args.tensorboard_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + + # ============================== + # Initialize Model, Optimizer and LR Scheduler + # ============================== + + config = LlamaConfig.from_pretrained(args.model_path) + # use lazy init when using GeminiPlugin + init_ctx = ( + LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + ) + + with init_ctx: + model = LlamaForCausalLM(config) + + # ============================== + # Initialize Tokenizer, Dataset and Dataloader + # ============================== + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 + tokenizer.pad_token = tokenizer.unk_token + + dataset = load_dataset(args.dataset, args.task_name) + train_ds = dataset["train"] + dataloader = prepare_dataloader( + train_ds, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=partial(tokenize_batch_for_finetune, tokenizer=tokenizer, max_length=args.max_length), + ) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + if args.flash_attention: + assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed" + replace_xformers(model) + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + + optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) + total_step = args.num_epochs * len(dataloader) + lr_scheduler = CosineAnnealingWarmupLR( + optimizer, total_steps=total_step, warmup_steps=math.ceil(total_step * 0.03), eta_min=0.1 * args.lr + ) + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler + ) + torch.set_default_dtype(torch.float) + + booster.load_model(model, args.model_path) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) + + # load checkpoint if specified + start_epoch = 0 + start_step = 0 + sampler_start_idx = 0 + if args.load is not None: + coordinator.print_on_master("Loading checkpoint") + start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) + coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") + + num_steps_per_epoch = len(dataloader) + + # if resume training, set the sampler start index to the correct value + dataloader.sampler.set_start_index(sampler_start_idx) + for epoch in range(start_epoch, args.num_epochs): + dataloader.sampler.set_epoch(epoch) + step_nums = num_steps_per_epoch - start_step + dataloader_iter = iter(dataloader) + + with tqdm( + range(step_nums), + desc=f"Epoch {epoch}", + disable=not print_flag, + total=num_steps_per_epoch, + initial=start_step, + ) as pbar: + for step in pbar: + if use_pipeline: + outputs = booster.execute_pipeline( + dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) + loss = outputs["loss"] + else: + batch = next(dataloader_iter) + outputs = model(**batch) + loss = outputs[0] + booster.backward(loss, optimizer) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if not use_pipeline: + all_reduce_mean(loss) + if print_flag: + pbar.set_postfix({"loss": loss.item()}) + writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) + + if args.save_interval > 0 and (step + 1) % args.save_interval == 0: + coordinator.print_on_master(f"Saving checkpoint") + save( + booster, + model, + optimizer, + lr_scheduler, + epoch, + step + 1, + args.batch_size, + coordinator, + args.save_dir, + ) + coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") + # the continue epochs are not resumed, so we need to reset the sampler start index and start step + dataloader.sampler.set_start_index(0) + start_step = 0 + + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + main() diff --git a/examples/language/llama2/model_utils.py b/examples/language/llama2/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..63569bc61143b9abbba424ea312359c8ce85bbca --- /dev/null +++ b/examples/language/llama2/model_utils.py @@ -0,0 +1,32 @@ +from contextlib import contextmanager + +import torch +import torch.nn as nn + + +@contextmanager +def low_precision_init(target_dtype: torch.dtype = torch.float16): + dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(target_dtype) + yield + finally: + torch.set_default_dtype(dtype) + + +def get_model_numel(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f"{numel / B:.2f} B" + elif numel >= M: + return f"{numel / M:.2f} M" + elif numel >= K: + return f"{numel / K:.2f} K" + else: + return f"{numel}" diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..a57c1e0e9ae3e663861ac9c11344763019ec90d5 --- /dev/null +++ b/examples/language/llama2/performance_evaluator.py @@ -0,0 +1,105 @@ +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +from torch import Tensor + +from colossalai.cluster import DistCoordinator + + +def divide(x: float, y: float) -> float: + if y == 0: + return float("inf") + elif y == float("inf"): + return float("nan") + return x / y + + +@torch.no_grad() +def all_reduce_mean(x: float, world_size: int) -> float: + if world_size == 1: + return x + tensor = torch.tensor([x], device=torch.cuda.current_device()) + dist.all_reduce(tensor) + tensor = tensor / world_size + return tensor.item() + + +class Timer: + def __init__(self) -> None: + self.start_time: Optional[float] = None + self.duration: float = 0.0 + + def start(self) -> None: + self.start_time = time() + + def end(self) -> None: + assert self.start_time is not None + self.duration += time() - self.start_time + self.start_time = None + + def reset(self) -> None: + self.duration = 0.0 + + +class PerformanceEvaluator: + """ + Callback for valuate the performance of the model. + Args: + actor_num_params: The number of parameters of the actor model. + critic_num_params: The number of parameters of the critic model. + initial_model_num_params: The number of parameters of the initial model. + reward_model_num_params: The number of parameters of the reward model. + enable_grad_checkpoint: Whether to enable gradient checkpointing. + ignore_episodes: The number of episodes to ignore when calculating the performance. + """ + + def __init__( + self, + model_numel: int, + enable_grad_checkpoint: bool = False, + ignore_steps: int = 0, + dp_world_size: Optional[int] = None, + ) -> None: + self.model_numel = model_numel + self.enable_grad_checkpoint = enable_grad_checkpoint + self.ignore_steps = ignore_steps + + self.coordinator = DistCoordinator() + self.dp_world_size = dp_world_size or self.coordinator.world_size + self.disable: bool = False + self.timer = Timer() + self.num_samples: int = 0 + self.flop: int = 0 + + def on_step_start(self, step: int) -> None: + self.disable = self.ignore_steps > 0 and step < self.ignore_steps + if self.disable: + return + torch.cuda.synchronize() + self.timer.start() + + def on_step_end(self, input_ids: Tensor, **kwargs) -> None: + if self.disable: + return + torch.cuda.synchronize() + self.timer.end() + + batch_size, seq_len = input_ids.shape + + self.num_samples += batch_size + self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)) + + def on_fit_end(self) -> None: + avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size) + avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12) + mp_world_size = self.coordinator.world_size // self.dp_world_size + avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size + self.coordinator.print_on_master( + f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, " + f"avg_throughput: {avg_throughput}" + ) + self.coordinator.print_on_master( + f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}" + ) diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..6cc73b6265a4f56bfa659f94173b3960ad0beeb2 --- /dev/null +++ b/examples/language/llama2/pretrain.py @@ -0,0 +1,329 @@ +import argparse +import os +import resource +from contextlib import nullcontext +from functools import partial +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from attn import SUPPORT_XFORMERS, replace_xformers +from data_utils import load_json, prepare_dataloader, save_json +from datasets import load_dataset +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM +from transformers.models.llama.tokenization_llama import LlamaTokenizer + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +MODEL_CONFIGS = { + "7b": LlamaConfig(max_position_embeddings=4096), + "13b": LlamaConfig( + hidden_size=5120, + intermediate_size=13824, + num_hidden_layers=40, + num_attention_heads=40, + max_position_embeddings=4096, + ), + "70b": LlamaConfig( + hidden_size=8192, + intermediate_size=28672, + num_hidden_layers=80, + num_attention_heads=64, + max_position_embeddings=4096, + num_key_value_heads=8, + ), +} + + +def get_model_numel(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f"{numel / B:.2f} B" + elif numel >= M: + return f"{numel / M:.2f} M" + elif numel >= K: + return f"{numel / K:.2f} K" + else: + return f"{numel}" + + +def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): + texts = [sample["text"] for sample in batch] + data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) + data = {k: v.cuda() for k, v in data.items()} + data["labels"] = data["input_ids"].clone() + return data + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) + return tensor + + +def save( + booster: Booster, + model: nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, + save_dir: str, +): + save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "model"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + running_states = { + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + + +def load( + booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str +) -> Tuple[int, int, int]: + booster.load_model(model, os.path.join(load_dir, "model")) + booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) + running_states = load_json(os.path.join(load_dir, "running_states.json")) + return running_states["epoch"], running_states["step"], running_states["sample_start_index"] + + +def _criterion(outputs, inputs): + return outputs.loss + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration") + parser.add_argument( + "-p", + "--plugin", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], + default="gemini", + help="Choose which plugin to use", + ) + parser.add_argument( + "-d", "--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample", help="Data set path" + ) + parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") + parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") + parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") + parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") + parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") + args = parser.parse_args() + + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "gemini": + plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip + ) + elif args.plugin == "hybrid_parallel": + # modify the param accordingly, default configuration is for llama2-7b + plugin = HybridParallelPlugin( + tp_size=4, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_jit_fused=False, + zero_stage=0, + precision="fp32", + initial_scale=1, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) + + # ============================== + # Initialize Tensorboard + # ============================== + if print_flag: + os.makedirs(args.tensorboard_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + + # ============================== + # Initialize Tokenizer, Dataset and Dataloader + # ============================== + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 + tokenizer.pad_token = tokenizer.unk_token + + dataset = load_dataset(args.dataset) + train_ds = dataset["train"] + dataloader = prepare_dataloader( + train_ds, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=partial(tokenize_batch_for_pretrain, tokenizer=tokenizer, max_length=args.max_length), + ) + + # ============================== + # Initialize Model, Optimizer and LR Scheduler + # ============================== + config = MODEL_CONFIGS[args.config] + # use lazy init when using GeminiPlugin + init_ctx = ( + LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + ) + + with init_ctx: + model = LlamaForCausalLM(config) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + if args.flash_attention: + assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed" + replace_xformers(model) + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + + optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) + lr_scheduler = CosineAnnealingWarmupLR( + optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr + ) + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler + ) + torch.set_default_dtype(torch.float) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) + + # load checkpoint if specified + start_epoch = 0 + start_step = 0 + sampler_start_idx = 0 + if args.load is not None: + coordinator.print_on_master("Loading checkpoint") + start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) + coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") + + num_steps_per_epoch = len(dataloader) + + # if resume training, set the sampler start index to the correct value + dataloader.sampler.set_start_index(sampler_start_idx) + for epoch in range(start_epoch, args.num_epochs): + dataloader.sampler.set_epoch(epoch) + step_nums = num_steps_per_epoch - start_step + dataloader_iter = iter(dataloader) + + with tqdm( + range(step_nums), + desc=f"Epoch {epoch}", + disable=not print_flag, + total=num_steps_per_epoch, + initial=start_step, + ) as pbar: + for step in pbar: + if use_pipeline: + outputs = booster.execute_pipeline( + dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) + loss = outputs["loss"] + else: + batch = next(dataloader_iter) + outputs = model(**batch) + loss = outputs[0] + booster.backward(loss, optimizer) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if not use_pipeline: + all_reduce_mean(loss) + if print_flag: + pbar.set_postfix({"loss": loss.item()}) + writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) + + if args.save_interval > 0 and (step + 1) % args.save_interval == 0: + coordinator.print_on_master(f"Saving checkpoint") + save( + booster, + model, + optimizer, + lr_scheduler, + epoch, + step + 1, + args.batch_size, + coordinator, + args.save_dir, + ) + coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") + # the continue epochs are not resumed, so we need to reset the sampler start index and start step + dataloader.sampler.set_start_index(0) + start_step = 0 + + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + main() diff --git a/examples/language/llama2/requirements.txt b/examples/language/llama2/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b475682dad02b6868594ae17c1f2a4c0fd65350 --- /dev/null +++ b/examples/language/llama2/requirements.txt @@ -0,0 +1,9 @@ +colossalai>=0.3.2 +datasets +numpy +torch>=1.12.0,<=2.0.0 +tqdm +transformers +flash-attn>=2.0.0,<=2.0.5 +SentencePiece==0.1.99 +tensorboard==2.14.0 diff --git a/examples/language/llama2/scripts/benchmark_70B/3d.sh b/examples/language/llama2/scripts/benchmark_70B/3d.sh new file mode 100644 index 0000000000000000000000000000000000000000..d50c57042d1a931e1698a3db450156c06d410ef4 --- /dev/null +++ b/examples/language/llama2/scripts/benchmark_70B/3d.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# TODO: fix this +echo "3D parallel for LLaMA-2 is not ready yet" +exit 1 + +################ +#Load your environments and modules here +################ + +HOSTFILE=$(realpath hosts.txt) + +cd ../.. + +export OMP_NUM_THREADS=8 + +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 4 diff --git a/examples/language/llama2/scripts/benchmark_70B/gemini.sh b/examples/language/llama2/scripts/benchmark_70B/gemini.sh new file mode 100644 index 0000000000000000000000000000000000000000..c80d4d9f25bf844bdbf3b0bbaff70c88c9f2fe4e --- /dev/null +++ b/examples/language/llama2/scripts/benchmark_70B/gemini.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +################ +#Load your environments and modules here +################ + +HOSTFILE=$(realpath hosts.txt) + +cd ../.. + +export OMP_NUM_THREADS=8 + +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -g -x -b 2 diff --git a/examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh b/examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh new file mode 100644 index 0000000000000000000000000000000000000000..ce3b2f2170cc2a00b9c7a70dd9c82c5158f1f03e --- /dev/null +++ b/examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +################ +#Load your environments and modules here +################ + +HOSTFILE=$(realpath hosts.txt) + +cd ../.. + +export OMP_NUM_THREADS=8 + +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p gemini_auto -g -x -b 2 diff --git a/examples/language/llama2/scripts/benchmark_7B/gemini.sh b/examples/language/llama2/scripts/benchmark_7B/gemini.sh new file mode 100644 index 0000000000000000000000000000000000000000..db4968a8df7f97a46b9e5bea67b1d868bf4eb212 --- /dev/null +++ b/examples/language/llama2/scripts/benchmark_7B/gemini.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +################ +#Load your environments and modules here +################ + +HOSTFILE=$(realpath hosts.txt) + +cd ../.. + +export OMP_NUM_THREADS=8 + +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -g -x -b 16 diff --git a/examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh b/examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh new file mode 100644 index 0000000000000000000000000000000000000000..59ec1c1a75c2bc96cbf5855d3f89076ba34e798a --- /dev/null +++ b/examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +################ +#Load your environments and modules here +################ + +HOSTFILE=$(realpath hosts.txt) + +cd ../.. + +export OMP_NUM_THREADS=8 + +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -p gemini_auto -g -x -b 16 diff --git a/examples/language/llama2/test_ci.sh b/examples/language/llama2/test_ci.sh new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md index c2fd254571c7fb7b9df3de8539fae32826a70e3f..af1e794374ed2164bf918ba796e98285d0edf692 100644 --- a/examples/language/opt/README.md +++ b/examples/language/opt/README.md @@ -19,15 +19,32 @@ Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/fa The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. -We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before -the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). ## Our Modifications -We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. -## Quick Start -You can launch training by using the following bash script +We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before +the tokenization). + +We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, HybridParallelPlugin and GeminiPlugin. + +## Run Demo + +By running the following script: +```bash +bash run_demo.sh +``` +You will finetune a [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) model on this [dataset](https://huggingface.co/datasets/hugginglearners/netflix-shows), which contains more than 8000 comments on Netflix shows. + +The script can be modified if you want to try another set of hyperparameters or change to another OPT model with different size. + +The demo code is adapted from this [blog](https://medium.com/geekculture/fine-tune-eleutherai-gpt-neo-to-generate-netflix-movie-descriptions-in-only-47-lines-of-code-40c9b4c32475) and the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). + + + +## Run Benchmark +You can run benchmark for OPT model by running the following script: ```bash -bash ./run_gemini.sh +bash run_benchmark.sh ``` +The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your set of hyperparameters for testing. diff --git a/examples/language/opt/args.py b/examples/language/opt/args.py new file mode 100644 index 0000000000000000000000000000000000000000..fc3d42fae220900f6a95d89e160197cec8d896e3 --- /dev/null +++ b/examples/language/opt/args.py @@ -0,0 +1,70 @@ +import argparse + + +def parse_demo_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name_or_path", + type=str, + default="facebook/opt-350m", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_path", type=str, default="./output_model.bin", help="The path of your saved model after finetuning." + ) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'.", + ) + parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs.") + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--warmup_ratio", type=float, default=0.1, help="Ratio of warmup steps against total training steps." + ) + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + + args = parser.parse_args() + return args + + +def parse_benchmark_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name_or_path", + type=str, + default="facebook/opt-125m", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'.", + ) + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).") + args = parser.parse_args() + + return args diff --git a/examples/language/opt/benchmark.sh b/examples/language/opt/benchmark.sh deleted file mode 100644 index 0d04b5e9b33cad5bef7037a4a9f700290c0c338a..0000000000000000000000000000000000000000 --- a/examples/language/opt/benchmark.sh +++ /dev/null @@ -1,21 +0,0 @@ -export BS=16 -export MEMCAP=0 -export MODEL="6.7b" -export GPUNUM=1 - -for MODEL in "6.7b" "13b" "1.3b" -do -for GPUNUM in 8 1 -do -for BS in 16 24 32 8 -do -for MEMCAP in 0 40 -do -pkill -9 torchrun -pkill -9 python - -env BS=$BS MEM_CAP=$MEMCAP MODEL=$MODEL GPUNUM=$GPUNUM bash ./run_gemini.sh -done -done -done -done diff --git a/examples/language/opt/data.py b/examples/language/opt/data.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9cc59518ab43e8135be36aad2d117120fa528a --- /dev/null +++ b/examples/language/opt/data.py @@ -0,0 +1,38 @@ +import torch +from datasets import load_dataset +from torch.utils.data import Dataset + + +class NetflixDataset(Dataset): + def __init__(self, tokenizer): + super().__init__() + + self.tokenizer = tokenizer + self.input_ids = [] + self.attn_masks = [] + self.labels = [] + self.txt_list = netflix_descriptions = load_dataset("hugginglearners/netflix-shows", split="train")[ + "description" + ] + self.max_length = max([len(self.tokenizer.encode(description)) for description in netflix_descriptions]) + + for txt in self.txt_list: + encodings_dict = self.tokenizer( + "
              " + txt + "
              ", truncation=True, max_length=self.max_length, padding="max_length" + ) + self.input_ids.append(torch.tensor(encodings_dict["input_ids"])) + self.attn_masks.append(torch.tensor(encodings_dict["attention_mask"])) + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + return self.input_ids[idx], self.attn_masks[idx] + + +def netflix_collator(data): + return { + "input_ids": torch.stack([x[0] for x in data]), + "attention_mask": torch.stack([x[1] for x in data]), + "labels": torch.stack([x[0] for x in data]), + } diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..d16c9fdf99adcdd2422959234d32615607aed534 --- /dev/null +++ b/examples/language/opt/opt_benchmark.py @@ -0,0 +1,130 @@ +import time + +import torch +import tqdm +import transformers +from args import parse_benchmark_args +from transformers import AutoConfig, OPTForCausalLM +from transformers.utils.versions import require_version + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam + +require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt") + + +def format_num(num: int, bytes=False): + """Scale bytes to its proper format, e.g. 1253656 => '1.20MB'""" + factor = 1024 if bytes else 1000 + suffix = "B" if bytes else "" + for unit in ["", " K", " M", " G", " T", " P"]: + if num < factor: + return f"{num:.2f}{unit}{suffix}" + num /= factor + + +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def colo_memory_cap(size_in_GB): + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + + cuda_capacity = colo_device_memory_capacity(get_current_device()) + if size_in_GB * (1024**3) < cuda_capacity: + colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) + print(f"Limiting GPU memory usage to {size_in_GB} GB") + + +def main(): + args = parse_benchmark_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + world_size = coordinator.world_size + + # Manage loggers + disable_existing_loggers() + logger = get_dist_logger() + if coordinator.is_master(): + transformers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + + # Whether to set limit of memory capacity + if args.mem_cap > 0: + colo_memory_cap(args.mem_cap) + + # Build OPT model + config = AutoConfig.from_pretrained(args.model_name_or_path) + model = OPTForCausalLM(config=config) + logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Set plugin + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + logger.info(f"Set plugin as {args.plugin}", ranks=[0]) + + # Set optimizer + optimizer = HybridAdam(model.parameters(), lr=args.learning_rate) + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + model, optimizer, _, _, _ = booster.boost(model, optimizer) + + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + + # Start training. + logger.info(f"Start testing", ranks=[0]) + progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master()) + + torch.cuda.synchronize() + model.train() + start_time = time.time() + + for _ in range(args.max_train_steps): + input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE) + optimizer.zero_grad() + outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False) + loss = outputs["loss"] + booster.backward(loss, optimizer) + optimizer.step() + + torch.cuda.synchronize() + progress_bar.update(1) + + # Compute Statistics + end_time = time.time() + throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time)) + max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True) + + logger.info( + f"Testing finished, " + f"batch size per gpu: {args.batch_size}, " + f"plugin: {args.plugin}, " + f"throughput: {throughput}, " + f"maximum memory usage per gpu: {max_mem}.", + ranks=[0], + ) + + +if __name__ == "__main__": + main() diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..fddbc1b408e71c8913fd686aea161a08fe1e7de0 --- /dev/null +++ b/examples/language/opt/opt_train_demo.py @@ -0,0 +1,156 @@ +import datasets +import torch +import transformers +from args import parse_demo_args +from data import NetflixDataset, netflix_collator +from tqdm import tqdm +from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup +from transformers.utils.versions import require_version + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam + +require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt") +require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt") + +output_transform_fn = lambda x: x +criterion = lambda x: x.loss + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, booster, coordinator): + torch.cuda.synchronize() + + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + total_step = len(dataloader) + + model.train() + optimizer.zero_grad() + dataloader = iter(dataloader) + with tqdm( + range(total_step), desc=f"Epoch [{epoch + 1}]", disable=not (coordinator.is_master() or is_pp_last_stage) + ) as pbar: + # Forward pass + for _ in pbar: + if use_pipeline: + outputs = booster.execute_pipeline( + dataloader, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) + else: + data = next(dataloader) + data = move_to_cuda(data) + outputs = model(**data) + loss = _criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + +def main(): + args = parse_demo_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + world_size = coordinator.world_size + + # Manage loggers + disable_existing_loggers() + logger = get_dist_logger() + if coordinator.is_master(): + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Build OPT model + config = AutoConfig.from_pretrained(args.model_name_or_path) + model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config) + logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Set plugin + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == "hybrid_parallel": + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin( + tp_size=2, + pp_size=2, + num_microbatches=2, + enable_all_optimization=True, + zero_stage=0, + precision="fp16", + initial_scale=1, + ) + + logger.info(f"Set plugin as {args.plugin}", ranks=[0]) + + # Prepare tokenizer and dataloader + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) + dataset = NetflixDataset(tokenizer) + dataloader = plugin.prepare_dataloader( + dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=netflix_collator + ) + + # Set optimizer + optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) + + # Set lr scheduler + total_steps = len(dataloader) * args.num_epoch + num_warmup_steps = int(args.warmup_ratio * total_steps) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=len(dataloader) * args.num_epoch + ) + + # Define criterion + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + model, optimizer, _criterion, dataloader, lr_scheduler = booster.boost( + model=model, optimizer=optimizer, dataloader=dataloader, criterion=_criterion, lr_scheduler=lr_scheduler + ) + + # Start finetuning + logger.info(f"Start finetuning", ranks=[0]) + for epoch in range(args.num_epoch): + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, booster, coordinator) + + # Finish training and evaluate + logger.info(f"Finish finetuning", ranks=[0]) + booster.save_model(model, args.output_path, shard=True) + logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt index 137a69e80498223cd7581a62e2e27320b77682a0..45bfbc37195f856777daa505e9a95e85914d4675 100644 --- a/examples/language/opt/requirements.txt +++ b/examples/language/opt/requirements.txt @@ -1,2 +1,4 @@ -colossalai >= 0.1.12 +colossalai >= 0.3.2 torch >= 1.8.1 +datasets >= 1.8.0 +transformers >= 4.30.2 diff --git a/examples/language/opt/run_benchmark.sh b/examples/language/opt/run_benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..b79d6c13465ed948a9d0f231c0cf1d637b936c7c --- /dev/null +++ b/examples/language/opt/run_benchmark.sh @@ -0,0 +1,30 @@ +set -xe +pip install -r requirements.txt + +export BS=32 +export MEMCAP=0 +export GPUNUM=1 + +# acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b` +export MODEL="125m" + +for BS in 8 32 128 +do +for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" +do +for GPUNUM in 1 4 +do + +MODLE_PATH="facebook/opt-${MODEL}" +colossalai run \ + --nproc_per_node ${GPUNUM} \ + --master_port 29505 \ + opt_benchmark.py \ + --model_name_or_path ${MODLE_PATH} \ + --mem_cap ${MEMCAP} \ + --plugin ${PLUGIN} \ + --batch_size ${BS} + +done +done +done diff --git a/examples/language/opt/run_demo.sh b/examples/language/opt/run_demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..fe49d794f4b0c053148f146d317ab40e34b1628e --- /dev/null +++ b/examples/language/opt/run_demo.sh @@ -0,0 +1,44 @@ +set -xe +pip install -r requirements.txt + +# model name or path +MODEL="facebook/opt-350m" + +# path for saving model +OUTPUT_PATH="./output_model.bin" + +# plugin(training strategy) +# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini" +PLUGIN="hybrid_parallel" + +# number of gpus to use +GPUNUM=4 + +# batch size per gpu +BS=16 + +# learning rate +LR="5e-5" + +# number of epoch +EPOCH=10 + +# weight decay +WEIGHT_DECAY=0.01 + +# ratio of warmup steps +WARMUP_RATIO=0.1 + +# run the script for demo +colossalai run \ + --nproc_per_node ${GPUNUM} \ + --master_port 29505 \ + opt_train_demo.py \ + --model_name_or_path ${MODEL} \ + --output_path ${OUTPUT_PATH} \ + --plugin ${PLUGIN} \ + --batch_size ${BS} \ + --num_epoch ${EPOCH} \ + --learning_rate ${LR} \ + --weight_decay ${WEIGHT_DECAY} \ + --warmup_ratio ${WARMUP_RATIO} diff --git a/examples/language/opt/run_gemini.sh b/examples/language/opt/run_gemini.sh deleted file mode 100644 index 73f231292a132ec0a2837efe86dd6f5bc3eb81ba..0000000000000000000000000000000000000000 --- a/examples/language/opt/run_gemini.sh +++ /dev/null @@ -1,28 +0,0 @@ -set -x -export BS=${BS:-16} -export MEMCAP=${MEMCAP:-0} -# Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b` -export MODEL=${MODEL:-"125m"} -export GPUNUM=${GPUNUM:-1} -export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"} - -# make directory for logs -mkdir -p ./logs - -if [ ${USE_SHARD_INIT} = "true" ]; then - USE_SHARD_INIT="--shardinit" -else - USE_SHARD_INIT="" -fi - -export MODLE_PATH="facebook/opt-${MODEL}" - -# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 -torchrun \ - --nproc_per_node ${GPUNUM} \ - --master_port 19198 \ - train_gemini_opt.py \ - --mem_cap ${MEMCAP} \ - --model_name_or_path ${MODLE_PATH} \ - ${USE_SHARD_INIT} \ - --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/language/opt/test_ci.sh b/examples/language/opt/test_ci.sh index 317f602cda3c5a63b2d6130fff8a363cc8f613e2..2e3a645caf0673b2174bec107517abfc4be6ced2 100644 --- a/examples/language/opt/test_ci.sh +++ b/examples/language/opt/test_ci.sh @@ -1,4 +1,19 @@ -for GPUNUM in 2 1 +set -xe +pip install -r requirements.txt + +BS=4 +for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" do -env BS=2 MODEL="125m" GPUNUM=$GPUNUM bash ./run_gemini.sh +for GPUNUM in 1 4 +do + +colossalai run \ + --nproc_per_node ${GPUNUM} \ + --master_port 29505 \ + opt_benchmark.py \ + --model_name_or_path "facebook/opt-125m" \ + --plugin ${PLUGIN} \ + --batch_size ${BS} + +done done diff --git a/examples/language/opt/train_gemini_opt.py b/examples/language/opt/train_gemini_opt.py deleted file mode 100755 index 3614b689de26fbf8fb89a7c554bd9de4802da582..0000000000000000000000000000000000000000 --- a/examples/language/opt/train_gemini_opt.py +++ /dev/null @@ -1,233 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) -on a text file or a dataset without using HuggingFace Trainer. - -Here is the full list of checkpoints on the hub that can be fine-tuned by this script: -https://huggingface.co/models?filter=text-generation -""" -# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. - -import time -from functools import partial - -import datasets -import torch -import torch.distributed as dist -import transformers -from transformers import CONFIG_MAPPING, MODEL_MAPPING, AutoConfig, OPTForCausalLM -from transformers.utils.versions import require_version - -import colossalai -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.tensor import ProcessGroup, ShardSpec -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP - - -def get_data(batch_size, seq_len, vocab_size): - input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) - attention_mask = torch.ones_like(input_ids) - return input_ids, attention_mask - - -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") - -MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) -MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - -def get_time_stamp(): - torch.cuda.synchronize() - return time.time() - - -def get_tflops(model_numel, batch_size, seq_len, step_time): - return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) - - -def parse_args(): - parser = colossalai.get_default_parser() - parser.add_argument( - "--model_name_or_path", - type=str, - help="Path to pretrained model or model identifier from huggingface.co/models.", - required=True, - ) - parser.add_argument( - "--config_name", - type=str, - default=None, - help="Pretrained config name or path if not the same as model_name", - ) - parser.add_argument( - "--batch_size", - type=int, - default=8, - help="Batch size (per dp group) for the training dataloader.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") - parser.add_argument( - "--max_train_steps", - type=int, - default=20, - help="Total number of training steps to perform.", - ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--model_type", - type=str, - default=None, - help="Model type to use if training from scratch.", - choices=MODEL_TYPES, - ) - parser.add_argument( - "--shardinit", - action="store_true", - help="Initialize the model with tensor parallel", - ) - parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") - parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") - args = parser.parse_args() - - return args - - -def colo_memory_cap(size_in_GB): - from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device - cuda_capacity = colo_device_memory_capacity(get_current_device()) - if size_in_GB * (1024**3) < cuda_capacity: - colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) - print("Using {} GB of GPU memory".format(size_in_GB)) - - -def main(): - args = parse_args() - disable_existing_loggers() - colossalai.launch_from_torch({}) - logger = get_dist_logger() - is_main_process = dist.get_rank() == 0 - - if is_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - if args.mem_cap > 0: - colo_memory_cap(args.mem_cap) - - # If passed along, set the training seed now. - if args.seed is not None: - torch.mannul_seed(args.seed) - logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}") - - # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Load pretrained model - # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - if args.config_name: - config = AutoConfig.from_pretrained(args.config_name) - elif args.model_name_or_path: - config = AutoConfig.from_pretrained(args.model_name_or_path) - else: - config = CONFIG_MAPPING[args.model_type]() - logger.warning("You are instantiating a new config instance from scratch.") - logger.info("Model config has been created", ranks=[0]) - - if args.init_in_cpu: - init_dev = torch.device('cpu') - else: - init_dev = get_current_device() - - # shard init parameters - if args.shardinit: - logger.info("Sharding initialization !", ranks=[0]) - else: - logger.info("Skipping sharding initialization", ranks=[0]) - - world_size = torch.distributed.get_world_size() - shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None - default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None - - # build model - if args.model_name_or_path is None: - logger.info("Train a new model from scratch", ranks=[0]) - with ColoInitContext(device=init_dev, - dtype=torch.half, - default_dist_spec=default_dist_spec, - default_pg=shard_pg): - model = OPTForCausalLM(config) - else: - logger.info("Finetune a pre-trained model", ranks=[0]) - with ColoInitContext(device=init_dev, - dtype=torch.half, - default_dist_spec=default_dist_spec, - default_pg=shard_pg): - model = OPTForCausalLM.from_pretrained(args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - local_files_only=False) - - # enable gradient checkpointing - model.gradient_checkpointing_enable() - - numel = sum([p.numel() for p in model.parameters()]) - PLACEMENT_POLICY = 'cpu' - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=PLACEMENT_POLICY, - pin_memory=True, - strict_ddp_mode=args.shardinit) - optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0) - - SEQ_LEN = 1024 - VOCAB_SIZE = 50257 - - get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN) - - model.train() - for step in range(args.max_train_steps): - st_time = time.time() - input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE) - - outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False) - loss = outputs['loss'] - optimizer.backward(loss) - - optimizer.step() - optimizer.zero_grad() - torch.cuda.synchronize() - step_time = time.time() - st_time - step_tflops = get_tflops_func(step_time) - - logger.info("step {} finished, Tflops {}".format(step, step_tflops), ranks=[0]) - - logger.info("Training finished", ranks=[0]) - - -if __name__ == "__main__": - main() diff --git a/examples/language/palm/README.md b/examples/language/palm/README.md index 486bf240f89c1abbe442512660ee9f1aec6c1b9d..58ca902bf3eb74840f5e0a463b21f2093e69f0e9 100644 --- a/examples/language/palm/README.md +++ b/examples/language/palm/README.md @@ -43,6 +43,9 @@ palm = PaLM( ) ``` +## New API +We have modified our previous implementation of PaLM with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find the new API in train.py. We also offer a shell script test_ci.sh for you to go through all our plugins for the booster. For more information about the booster API you can refer to https://colossalai.org/docs/basics/booster_api/. + ## Test on Enwik8 ```bash diff --git a/examples/language/palm/palm_pytorch/autoregressive_wrapper.py b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py index dc4f3d856fecaa09581cd6771bc5c9460c55865b..17251c2f4fb3df58da4b0155fdd8e81cf09dcaca 100644 --- a/examples/language/palm/palm_pytorch/autoregressive_wrapper.py +++ b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py @@ -11,7 +11,6 @@ def exists(val): def eval_decorator(fn): - def inner(model, *args, **kwargs): was_training = model.training model.eval() @@ -34,7 +33,6 @@ def top_k(logits, thres=0.9): class AutoregressiveWrapper(nn.Module): - def __init__(self, net, max_seq_len=2048, pad_value=0): super().__init__() self.max_seq_len = max_seq_len diff --git a/examples/language/palm/palm_pytorch/palm_pytorch.py b/examples/language/palm/palm_pytorch/palm_pytorch.py index c37974711e11b7046e474f27fbc138b246e63d9f..6be966d67790575c8bf24199f00987056ef903fa 100644 --- a/examples/language/palm/palm_pytorch/palm_pytorch.py +++ b/examples/language/palm/palm_pytorch/palm_pytorch.py @@ -1,14 +1,13 @@ import torch import torch.nn.functional as F from einops import rearrange -from torch import einsum, matmul, nn +from torch import matmul, nn # normalization # they use layernorm without bias, something that pytorch does not offer class LayerNorm(nn.Module): - def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps @@ -24,7 +23,6 @@ class LayerNorm(nn.Module): class ParallelResidual(nn.Module): - def __init__(self, *fns): super().__init__() self.fns = nn.ModuleList(fns) @@ -38,16 +36,15 @@ class ParallelResidual(nn.Module): class RotaryEmbedding(nn.Module): - def __init__(self, dim): super().__init__() - inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) def forward(self, max_seq_len, *, device): seq = torch.arange(max_seq_len, device=device) - #freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) - #freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) + # freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) + # freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq) freqs = matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j)) return torch.cat((freqs, freqs), dim=-1) @@ -69,7 +66,6 @@ def apply_rotary_pos_emb(pos, t): class SwiGLU(nn.Module): - def forward(self, x): x, gate = x.chunk(2, dim=-1) return F.silu(gate) * x @@ -87,7 +83,6 @@ def FeedForward(dim, mult=4): # attention class Attention(nn.Module): - def __init__(self, dim, dim_head=64, heads=8): super().__init__() inner_dim = dim_head * heads @@ -160,7 +155,7 @@ class Attention(nn.Module): # similarity - #sim = einsum("b h i d, b j d -> b h i j", q, k) + # sim = einsum("b h i d, b j d -> b h i j", q, k) sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2)) sim = sim.reshape(b, h, i, j) @@ -178,7 +173,7 @@ class Attention(nn.Module): # aggregate values - #out = einsum("b h i j, b j d -> b h i d", attn, v) + # out = einsum("b h i j, b j d -> b h i d", attn, v) out = matmul(attn.reshape(b_, h_ * i_, j_), v) out = out.reshape(b_, h_, i_, d_) @@ -193,12 +188,17 @@ class Attention(nn.Module): def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4): net = nn.Sequential( - nn.Embedding(num_tokens, dim), *[ + nn.Embedding(num_tokens, dim), + *[ ParallelResidual( Attention(dim=dim, dim_head=dim_head, heads=heads), FeedForward(dim=dim, mult=ff_mult), - ) for _ in range(depth) - ], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False)) + ) + for _ in range(depth) + ], + LayerNorm(dim), + nn.Linear(dim, num_tokens, bias=False), + ) # they used embedding weight tied projection out to logits, not common, but works net[-1].weight = net[0].weight diff --git a/examples/language/palm/run.sh b/examples/language/palm/run.sh index 7a533509e009cfb7fe8d1c88d63b0457a2c21a65..0b9871c777233190f0ff05ba192464d6ba04b3b6 100644 --- a/examples/language/palm/run.sh +++ b/examples/language/palm/run.sh @@ -3,9 +3,11 @@ export DISTPAN="colossalai" # The following options only valid when DISTPAN="colossalai" export TPDEGREE=1 -export GPUNUM=1 +export GPUNUM=4 export PLACEMENT='cpu' export USE_SHARD_INIT=False -export BATCH_SIZE=4 +export BATCH_SIZE=1 -env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log +env OMP_NUM_THREADS=12 colossalai run --nproc_per_node ${GPUNUM} --master_port 29505 train.py \ +--dummy_data=True --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --plugin='gemini' \ +--placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log diff --git a/examples/language/palm/test_ci.sh b/examples/language/palm/test_ci.sh index f21095578077eda89c08eeeebba0013327e26ea9..6bcd140fe7fd19dca975336d7b177dff40a72b53 100644 --- a/examples/language/palm/test_ci.sh +++ b/examples/language/palm/test_ci.sh @@ -4,6 +4,6 @@ for BATCH_SIZE in 2 do for GPUNUM in 1 4 do -env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log +env OMP_NUM_THREADS=12 colossalai run --nproc_per_node ${GPUNUM} --master_port 29505 train.py --dummy_data=True --batch_size=${BATCH_SIZE} --plugin='gemini' 2>&1 | tee run.log done done diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 7923e4fc855d17ef8fbd7411c0d3c3c3bde51d67..7af02e24e6cf01b4f212e9971c69e66878ae2b2a 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -1,5 +1,6 @@ +import argparse import gzip -import random +from contextlib import nullcontext from functools import partial from time import time @@ -8,16 +9,17 @@ import torch import torch.nn as nn import torch.optim as optim import tqdm -from packaging import version from palm_pytorch import PaLM from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper from torch.utils.data import DataLoader, Dataset import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec -from colossalai.utils import MultiTimer, get_current_device -from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP +from colossalai.nn import HybridAdam +from colossalai.utils import get_current_device # constants @@ -32,31 +34,26 @@ SEQ_LEN = 1024 def parse_args(): - parser = colossalai.get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument( "--distplan", type=str, - default='colossalai', + default="colossalai", help="The distributed plan [colossalai, pytorch].", ) parser.add_argument( - "--tp_degree", - type=int, - default=1, - help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", + "--offload_optim_frac", + type=float, + default=1.0, + help="Fraction of optimizer states to be offloaded. This is only used for gemini.", ) parser.add_argument( - "--placement", + "-p", + "--plugin", type=str, - default='cpu', - help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", - ) - parser.add_argument( - "--shardinit", - type=bool, - default=False, - help= - "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", ) parser.add_argument( "--batch_size", @@ -101,73 +98,6 @@ def get_model_size(model: nn.Module): return total_numel -# Gemini + ZeRO DDP -def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): - cai_version = colossalai.__version__ - if version.parse(cai_version) > version.parse("0.1.10"): - from colossalai.nn.parallel import GeminiDDP - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=placememt_policy, - pin_memory=True, - search_range_mb=32) - elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): - from colossalai.gemini import ChunkManager, GeminiManager - chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) - gemini_manager = GeminiManager(placememt_policy, chunk_manager) - chunk_manager = ChunkManager(chunk_size, - pg, - enable_distributed_storage=True, - init_device=GeminiManager.get_default_device(placememt_policy)) - model = ZeroDDP(model, gemini_manager) - else: - raise NotImplemented(f"CAI version {cai_version} is not supported") - return model - - -# Parameter Sharding Strategies for Tensor Parallelism -def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(-1, param, pg) - - -# Tensor Parallel -def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): - """tensor_parallelize - Sharding the Model Parameters. - Args: - model (torch.nn.Module): a torch module to be sharded - """ - for mn, module in model.named_modules(): - for pn, param in module.named_parameters(recurse=False): - if hasattr(param, 'visited'): - continue - param.set_dist_spec(ReplicaSpec()) - if 'net.0' in mn: - split_param_col_tp1d(param, pg) # colmn slice - elif 'to_q' in mn: - split_param_col_tp1d(param, pg) # colmn slice - elif 'to_kv' in mn: - split_param_row_tp1d(param, pg) # row slice - elif 'to_out' in mn: - split_param_row_tp1d(param, pg) # row slice - elif '1.1' in mn: - split_param_col_tp1d(param, pg) # colmn slice - elif '1.2' in mn: - split_param_row_tp1d(param, pg) # row slice - else: - param.set_dist_spec(ReplicaSpec()) - param.visited = True - - args = parse_args() if args.distplan not in ["colossalai", "pytorch"]: raise TypeError(f"{args.distplan} is error") @@ -195,7 +125,6 @@ print("generate dataset ready!") class TextSamplerDataset(Dataset): - def __init__(self, data, seq_len): super().__init__() self.data = data @@ -203,7 +132,7 @@ class TextSamplerDataset(Dataset): def __getitem__(self, index): rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) - full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long() + full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long() return full_seq.cuda() def __len__(self): @@ -218,22 +147,29 @@ val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size)) if args.distplan == "colossalai": # instantiate GPT-like decoder model - default_pg = ProcessGroup(tp_degree=args.tp_degree) - default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None - ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg) + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + logger.info(f"plugin: {plugin}") + booster = Booster(plugin=plugin, **booster_kwargs) + + ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext() with ctx: model = PaLM(num_tokens=50304, dim=4096, depth=64) model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN) - pg = default_pg - tensor_parallelize(model, pg) - model = gemini_zero_dpp(model, pg, args.placement) - # optimizer - #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5) - optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5) + optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5) + model, optimizer, _, _, _ = booster.boost(model, optimizer) + else: model = PaLM(num_tokens=256, dim=512, depth=8) model = AutoregressiveWrapper(model, max_seq_len=2048) @@ -248,7 +184,6 @@ get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN) model.train() tflops_list = [] for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"): - if args.distplan == "colossalai": optimizer.zero_grad() start = time() @@ -297,12 +232,12 @@ logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") # loss = model(next(val_loader)) # print(f"validation loss: {loss.item()}") - # if i % GENERATE_EVERY == 0: - # model.eval() - # inp = random.choice(val_dataset)[:-1] - # prime = decode_tokens(inp) - # print(f"%s \n\n %s", (prime, "*" * 100)) +# if i % GENERATE_EVERY == 0: +# model.eval() +# inp = random.choice(val_dataset)[:-1] +# prime = decode_tokens(inp) +# print(f"%s \n\n %s", (prime, "*" * 100)) - # sample = model.generate(inp[None, ...], GENERATE_LENGTH) - # output_str = decode_tokens(sample[0]) - # print(output_str) +# sample = model.generate(inp[None, ...], GENERATE_LENGTH) +# output_str = decode_tokens(sample[0]) +# print(output_str) diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md index f4843331fd54f795938c07677c2be60e55497950..a54c7b4da3bd8e6c1d2079ab1b65ec6ae8b126be 100644 --- a/examples/tutorial/README.md +++ b/examples/tutorial/README.md @@ -4,7 +4,8 @@ ## Introduction -Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc. +Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), +[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) ,etc. [Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates @@ -29,7 +30,11 @@ quickly deploy large AI model training and inference, reducing large AI model tr - Fine-tuning and Inference for OPT [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/opt) [[video]](https://www.youtube.com/watch?v=jbEFNVzl67Y) - Optimized AlphaFold [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/fastfold) [[video]](https://www.youtube.com/watch?v=-zP13LfJP7w) - Optimized Stable Diffusion [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) [[video]](https://www.youtube.com/watch?v=8KHeUjjc-XQ) - + - ColossalChat: Cloning ChatGPT with a Complete RLHF Pipeline +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) +[[blog]](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) +[[demo]](https://www.youtube.com/watch?v=HcTiHzApHm0) +[[video]](https://www.youtube.com/watch?v=-qFBZFmOJfg) ## Discussion diff --git a/examples/tutorial/auto_parallel/README.md b/examples/tutorial/auto_parallel/README.md index 6a12e0dd5a4839d617ad2e1aca8468ee3393436e..13561567636e6d65a091054147775b6fd0b56b6a 100644 --- a/examples/tutorial/auto_parallel/README.md +++ b/examples/tutorial/auto_parallel/README.md @@ -13,7 +13,7 @@ ## 📚 Overview -This tutorial folder contains a simple demo to run auto-parallelism with ResNet. Meanwhile, this diretory also contains demo scripts to run automatic activation checkpointing, but both features are still experimental for now and no guarantee that they will work for your version of Colossal-AI. +This tutorial folder contains a simple demo to run auto-parallelism with ResNet. Meanwhile, this directory also contains demo scripts to run automatic activation checkpointing, but both features are still experimental for now and no guarantee that they will work for your version of Colossal-AI. ## 🚀 Quick Start diff --git a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py index 5a68aae18041946150150e9b63b04f0b9e387e40..29101ce08434615fdea25df49d75d61517d782ea 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py @@ -20,20 +20,22 @@ def _benchmark(rank, world_size, port): only result in minor performance drop. So at last we might be able to find better training batch size for our model (combine with large batch training optimizer such as LAMB). """ - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = tm.resnet152() gm = symbolic_trace(model) raw_graph = deepcopy(gm.graph) peak_mems, through_puts, batch_sizes = [], [], [512, 1024, 2048] for batch_size in batch_sizes: batch_size = int(batch_size) - gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta')) + gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device="meta")) solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95) gm.graph = solver.solve() - peak_mem, step_time = bench(gm, - torch.nn.CrossEntropyLoss(), - partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)), - num_steps=5) + peak_mem, step_time = bench( + gm, + torch.nn.CrossEntropyLoss(), + partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)), + num_steps=5, + ) peak_mems.append(peak_mem) through_puts.append(batch_size / step_time * 1.0e3) gm.graph = deepcopy(raw_graph) @@ -41,7 +43,7 @@ def _benchmark(rank, world_size, port): # print results print("===============benchmark summary================") for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts): - print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s') + print(f"batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s") def auto_activation_checkpoint_batchsize_benchmark(): diff --git a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py index aa5c47294a8279a34e912e1cb1fb0aec4f7dcdff..cd03a917912e9726702aca99db20995a559f8a0e 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py @@ -1,4 +1,3 @@ -import time from argparse import ArgumentParser from functools import partial @@ -8,7 +7,6 @@ import torchvision.models as tm from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium import colossalai -from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.fx import metainfo_trace, symbolic_trace from colossalai.testing import spawn @@ -19,37 +17,33 @@ def _benchmark(rank, world_size, port, args): The benchmark will sample in a range of memory budget for each model and output the benchmark summary and data visualization of peak memory vs. budget memory and relative step time vs. peak memory. """ - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - if args.model == 'resnet50': + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + if args.model == "resnet50": model = tm.resnet50() data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224)) gm = symbolic_trace(model) - gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta')) + gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device="meta")) loss = torch.nn.CrossEntropyLoss() else: model = gpt2_medium() data_gen = partial(data_gen_gpt2, batch_size=8, seq_len=1024, vocab_size=50257) - data, mask = data_gen(device='meta')[0] - gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + data, mask = data_gen(device="meta")[0] + gm = symbolic_trace(model, meta_args={"input_ids": data, "attention_mask": mask}) gm = metainfo_trace(gm, data, mask) loss = GPTLMLoss() - free_memory = 11000 * 1024**2 if args.model == 'resnet50' else 56000 * 1024**2 - start_factor = 4 if args.model == 'resnet50' else 10 + free_memory = 11000 * 1024**2 if args.model == "resnet50" else 56000 * 1024**2 + start_factor = 4 if args.model == "resnet50" else 10 # trace and benchmark - budgets, peak_hist, step_hist = bench_rotor(gm, - loss, - data_gen, - num_steps=5, - sample_points=15, - free_memory=free_memory, - start_factor=start_factor) + budgets, peak_hist, step_hist = bench_rotor( + gm, loss, data_gen, num_steps=5, sample_points=15, free_memory=free_memory, start_factor=start_factor + ) # print summary print("==============benchmark summary==============") for budget, peak, step in zip(budgets, peak_hist, step_hist): - print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS') + print(f"memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS") # plot valid results fig, axs = plt.subplots(1, 2, figsize=(16, 8)) @@ -57,14 +51,14 @@ def _benchmark(rank, world_size, port, args): # plot peak memory vs. budget memory axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:]) - axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--') + axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle="--") axs[0].set_xlabel("Budget Memory (MB)") axs[0].set_ylabel("Peak Memory (MB)") axs[0].set_title("Peak Memory vs. Budget Memory") # plot relative step time vs. budget memory axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]]) - axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--') + axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle="--") axs[1].set_xlabel("Peak Memory (MB)") axs[1].set_ylabel("Relative Step Time") axs[1].set_title("Step Time vs. Peak Memory") @@ -81,7 +75,7 @@ def auto_activation_checkpoint_benchmark(args): if __name__ == "__main__": parser = ArgumentParser("Auto Activation Checkpoint Solver Benchmark") - parser.add_argument("--model", type=str, default='gpt2', choices=['gpt2', 'resnet50']) + parser.add_argument("--model", type=str, default="gpt2", choices=["gpt2", "resnet50"]) args = parser.parse_args() auto_activation_checkpoint_benchmark(args) diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py index a6a9ad0a312cba1d8771b4e868fcb0d0e92507ee..3c5b786b561a47200bd0b445c3026303fba47dbc 100644 --- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py +++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py @@ -4,8 +4,8 @@ from tqdm import tqdm import colossalai from colossalai.auto_parallel.tensor_shard.initialize import initialize_model -from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh +from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingLR @@ -17,14 +17,14 @@ def synthesize_data(): def main(): - colossalai.launch_from_torch(config='./config.py') + colossalai.launch_from_torch(config="./config.py") logger = get_dist_logger() # trace the model with meta data model = resnet50(num_classes=10).cuda() - input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')} + input_sample = {"x": torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to("meta")} device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True) model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True) @@ -88,8 +88,9 @@ def main(): logger.info( f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", - ranks=[0]) + ranks=[0], + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/auto_parallel/bench_utils.py b/examples/tutorial/auto_parallel/bench_utils.py index 69859f885ae601efcec049b4771df831ed0da009..96cfd49c6787de7cecff6b1a223a8acd2858a5df 100644 --- a/examples/tutorial/auto_parallel/bench_utils.py +++ b/examples/tutorial/auto_parallel/bench_utils.py @@ -1,22 +1,19 @@ import time from copy import deepcopy -from functools import partial from typing import Callable, Tuple import numpy as np import torch import torch.nn as nn -import torchvision.models as tm from transformers import GPT2Config, GPT2LMHeadModel from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.fx import metainfo_trace -def bench(gm: torch.fx.GraphModule, - criterion: torch.nn.Module, - data_gen: Callable, - num_steps: int = 5) -> Tuple[int, int]: +def bench( + gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5 +) -> Tuple[int, int]: """Benchmarking a given graph module Args: gm (torch.fx.GraphModule): The graph module to benchmark. @@ -28,7 +25,7 @@ def bench(gm: torch.fx.GraphModule, """ gm.train() gm.cuda() - step_time = float('inf') + step_time = float("inf") torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -58,13 +55,15 @@ def bench(gm: torch.fx.GraphModule, return peak_mem, step_time * 1.0e3 -def bench_rotor(gm: torch.fx.GraphModule, - criterion: torch.nn.Module, - data_gen: Callable, - num_steps: int = 5, - sample_points: int = 20, - free_memory: int = torch.cuda.mem_get_info()[0], - start_factor: int = 4) -> Tuple[np.array, list, list]: +def bench_rotor( + gm: torch.fx.GraphModule, + criterion: torch.nn.Module, + data_gen: Callable, + num_steps: int = 5, + sample_points: int = 20, + free_memory: int = torch.cuda.mem_get_info()[0], + start_factor: int = 4, +) -> Tuple[np.array, list, list]: """Auto Checkpoint Rotor Algorithm benchmarking Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data. Args: @@ -88,7 +87,7 @@ def bench_rotor(gm: torch.fx.GraphModule, gm.graph = solver.solve(verbose=False) peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps) except: - peak_memory, step_time = budget / 1024**2, float('inf') + peak_memory, step_time = budget / 1024**2, float("inf") peak_hist.append(peak_memory) step_hist.append(step_time) gm.graph = deepcopy(raw_graph) @@ -100,22 +99,27 @@ class GPTLMModel(nn.Module): GPT Model """ - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) if checkpoint: self.model.gradient_checkpointing_enable() @@ -152,7 +156,7 @@ def gpt2_6b(checkpoint=False): return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) -def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'): +def data_gen_gpt2(batch_size, seq_len, vocab_size, device="cuda:0"): """ Generate random data for gpt2 benchmarking """ @@ -161,7 +165,7 @@ def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'): return (input_ids, attention_mask), attention_mask -def data_gen_resnet(batch_size, shape, device='cuda:0'): +def data_gen_resnet(batch_size, shape, device="cuda:0"): """ Generate random data for resnet benchmarking """ diff --git a/examples/tutorial/auto_parallel/setup.py b/examples/tutorial/auto_parallel/setup.py index 6e6cff32ed23ece8a78a066a66433e9ffdd04210..94d5ec0c0e9ebed4fdefb0572d330b376c23b862 100644 --- a/examples/tutorial/auto_parallel/setup.py +++ b/examples/tutorial/auto_parallel/setup.py @@ -1,13 +1,13 @@ from setuptools import find_packages, setup setup( - name='auto_parallel', - version='0.0.1', - description='', + name="auto_parallel", + version="0.0.1", + description="", packages=find_packages(), install_requires=[ - 'torch', - 'numpy', - 'tqdm', + "torch", + "numpy", + "tqdm", ], ) diff --git a/examples/tutorial/auto_parallel/test_ci.sh b/examples/tutorial/auto_parallel/test_ci.sh index bf6275b673ff7f559708bae2b5bc85dbac23c3ae..b27e36217117606dd701b60e57ff50ff2440c29e 100644 --- a/examples/tutorial/auto_parallel/test_ci.sh +++ b/examples/tutorial/auto_parallel/test_ci.sh @@ -1,6 +1,8 @@ #!/bin/bash set -euxo pipefail -pip install -r requirements.txt -conda install -c conda-forge coin-or-cbc -colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py +echo "this test is outdated" + +# pip install -r requirements.txt +# conda install -c conda-forge coin-or-cbc +# colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py diff --git a/examples/tutorial/download_cifar10.py b/examples/tutorial/download_cifar10.py index 5c6b6988ade531f9a6e77955803b7c2dbd88ca9a..78ea3d1e062e8cf2fb6e6f41c14720dcadf4fe52 100644 --- a/examples/tutorial/download_cifar10.py +++ b/examples/tutorial/download_cifar10.py @@ -5,9 +5,9 @@ from torchvision.datasets import CIFAR10 def main(): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_root = os.path.join(dir_path, 'data') + data_root = os.path.join(dir_path, "data") dataset = CIFAR10(root=data_root, download=True) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/fastfold/FastFold b/examples/tutorial/fastfold/FastFold index 05681304651b1b29d7d887db169045ea3dd28fce..eba496808a91bbcd9661cf832349a418b197015f 160000 --- a/examples/tutorial/fastfold/FastFold +++ b/examples/tutorial/fastfold/FastFold @@ -1 +1 @@ -Subproject commit 05681304651b1b29d7d887db169045ea3dd28fce +Subproject commit eba496808a91bbcd9661cf832349a418b197015f diff --git a/examples/tutorial/hybrid_parallel/config.py b/examples/tutorial/hybrid_parallel/config.py index fe9abf2f1955fc0c9e15bf7d8669c5d05c36ce76..15f9d0bc75ee1e9387b74ed79d5ebd7eed85d8b2 100644 --- a/examples/tutorial/hybrid_parallel/config.py +++ b/examples/tutorial/hybrid_parallel/config.py @@ -1,4 +1,4 @@ -from colossalai.amp import AMP_TYPE +from colossalai.legacy.amp import AMP_TYPE # hyperparameters # BATCH_SIZE is as per GPU @@ -18,11 +18,11 @@ NUM_HEADS = 4 MLP_RATIO = 2 NUM_CLASSES = 10 CHECKPOINT = False -SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token # parallel setting TENSOR_PARALLEL_SIZE = 2 -TENSOR_PARALLEL_MODE = '1d' +TENSOR_PARALLEL_MODE = "1d" parallel = dict( pipeline=2, @@ -33,4 +33,4 @@ fp16 = dict(mode=AMP_TYPE.NAIVE) clip_grad_norm = 1.0 # pipeline config -NUM_MICRO_BATCHES = parallel['pipeline'] +NUM_MICRO_BATCHES = parallel["pipeline"] diff --git a/examples/tutorial/hybrid_parallel/test_ci.sh b/examples/tutorial/hybrid_parallel/test_ci.sh index e0dbef354e2d85721a3deda62b2c392dad09bb1d..24cee1da3de44fd4492a37f1036eb5e79d1a0866 100644 --- a/examples/tutorial/hybrid_parallel/test_ci.sh +++ b/examples/tutorial/hybrid_parallel/test_ci.sh @@ -1,5 +1,7 @@ #!/bin/bash set -euxo pipefail -pip install -r requirements.txt -colossalai run --nproc_per_node 4 train.py --config config.py +echo "legacy example" + +# pip install -r requirements.txt +# colossalai run --nproc_per_node 4 train.py --config config.py diff --git a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py index 4953d5350f31ac222330fbcbffc156e3187fd80b..95f1bf8ee17c8965c6f3554573fe792732913960 100644 --- a/examples/tutorial/hybrid_parallel/train.py +++ b/examples/tutorial/hybrid_parallel/train.py @@ -5,17 +5,16 @@ from titans.model.vit.vit import _create_vit_model from tqdm import tqdm import colossalai -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn import CrossEntropyLoss +from colossalai.legacy.pipeline.pipelinable import PipelinableContext from colossalai.logging import get_dist_logger -from colossalai.nn import CrossEntropyLoss from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.pipeline.pipelinable import PipelinableContext from colossalai.utils import is_using_pp -class DummyDataloader(): - +class DummyDataloader: def __init__(self, length, batch_size): self.length = length self.batch_size = batch_size @@ -50,7 +49,7 @@ def main(): logger = get_dist_logger() logger.info("initialized distributed environment", ranks=[0]) - if hasattr(gpc.config, 'LOG_PATH'): + if hasattr(gpc.config, "LOG_PATH"): if gpc.get_global_rank() == 0: log_path = gpc.config.LOG_PATH if not os.path.exists(log_path): @@ -60,15 +59,17 @@ def main(): use_pipeline = is_using_pp() # create model - model_kwargs = dict(img_size=gpc.config.IMG_SIZE, - patch_size=gpc.config.PATCH_SIZE, - hidden_size=gpc.config.HIDDEN_SIZE, - depth=gpc.config.DEPTH, - num_heads=gpc.config.NUM_HEADS, - mlp_ratio=gpc.config.MLP_RATIO, - num_classes=10, - init_method='jax', - checkpoint=gpc.config.CHECKPOINT) + model_kwargs = dict( + img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=10, + init_method="jax", + checkpoint=gpc.config.CHECKPOINT, + ) if use_pipeline: pipelinable = PipelinableContext() @@ -102,16 +103,18 @@ def main(): optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, - total_steps=gpc.config.NUM_EPOCHS, - warmup_steps=gpc.config.WARMUP_EPOCHS) + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS + ) # initialize - engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader) + engine, train_dataloader, test_dataloader, _ = colossalai.initialize( + model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + ) logger.info("Engine is built", ranks=[0]) @@ -121,7 +124,7 @@ def main(): data_iter = iter(train_dataloader) if gpc.get_global_rank() == 0: - description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + description = "Epoch {} / {}".format(epoch, gpc.config.NUM_EPOCHS) progress = tqdm(range(len(train_dataloader)), desc=description) else: progress = range(len(train_dataloader)) @@ -133,5 +136,5 @@ def main(): gpc.destroy() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/large_batch_optimizer/config.py b/examples/tutorial/large_batch_optimizer/config.py index 2efa0ffd0556c6d245dea8e86ffc19cf5ab68cc3..c6d9f94505f1abf208ca88b69d818fb80b63df4e 100644 --- a/examples/tutorial/large_batch_optimizer/config.py +++ b/examples/tutorial/large_batch_optimizer/config.py @@ -1,4 +1,4 @@ -from colossalai.amp import AMP_TYPE +from colossalai.legacy.amp import AMP_TYPE # hyperparameters # BATCH_SIZE is as per GPU diff --git a/examples/tutorial/large_batch_optimizer/test_ci.sh b/examples/tutorial/large_batch_optimizer/test_ci.sh index 89f426c542b18f61225ed86eefe15038cba41cfe..f4393938220dc5293fda295b7a0d4d5e67eb3562 100644 --- a/examples/tutorial/large_batch_optimizer/test_ci.sh +++ b/examples/tutorial/large_batch_optimizer/test_ci.sh @@ -1,8 +1,9 @@ #!/bin/bash set -euxo pipefail +echo "this test is outdated" -pip install -r requirements.txt +# pip install -r requirements.txt # run test -colossalai run --nproc_per_node 4 --master_port 29500 train.py --config config.py --optimizer lars -colossalai run --nproc_per_node 4 --master_port 29501 train.py --config config.py --optimizer lamb +# colossalai run --nproc_per_node 4 --master_port 29500 train.py --config config.py --optimizer lars +# colossalai run --nproc_per_node 4 --master_port 29501 train.py --config config.py --optimizer lamb diff --git a/examples/tutorial/large_batch_optimizer/train.py b/examples/tutorial/large_batch_optimizer/train.py index 35e54582f49443ce994e77dbd3e1722ae9a7bd01..dd114b5af86d639a3cdee3280fa89d1ffb8bd7c6 100644 --- a/examples/tutorial/large_batch_optimizer/train.py +++ b/examples/tutorial/large_batch_optimizer/train.py @@ -4,14 +4,13 @@ from torchvision.models import resnet18 from tqdm import tqdm import colossalai -from colossalai.core import global_context as gpc +from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import Lamb, Lars -class DummyDataloader(): - +class DummyDataloader: def __init__(self, length, batch_size): self.length = length self.batch_size = batch_size @@ -39,10 +38,9 @@ class DummyDataloader(): def main(): # initialize distributed setting parser = colossalai.get_default_parser() - parser.add_argument('--optimizer', - choices=['lars', 'lamb'], - help="Choose your large-batch optimizer", - required=True) + parser.add_argument( + "--optimizer", choices=["lars", "lamb"], help="Choose your large-batch optimizer", required=True + ) args = parser.parse_args() # launch from torch @@ -70,16 +68,18 @@ def main(): optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, - total_steps=gpc.config.NUM_EPOCHS, - warmup_steps=gpc.config.WARMUP_EPOCHS) + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS + ) # initialize - engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader) + engine, train_dataloader, test_dataloader, _ = colossalai.initialize( + model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + ) logger.info("Engine is built", ranks=[0]) @@ -89,7 +89,7 @@ def main(): data_iter = iter(train_dataloader) if gpc.get_global_rank() == 0: - description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + description = "Epoch {} / {}".format(epoch, gpc.config.NUM_EPOCHS) progress = tqdm(range(len(train_dataloader)), desc=description) else: progress = range(len(train_dataloader)) @@ -100,5 +100,5 @@ def main(): lr_scheduler.step() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/new_api/cifar_resnet/.gitignore b/examples/tutorial/new_api/cifar_resnet/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a79cf5236c088af99b3891cb3dc536aaaded808c --- /dev/null +++ b/examples/tutorial/new_api/cifar_resnet/.gitignore @@ -0,0 +1,4 @@ +data +checkpoint +ckpt-fp16 +ckpt-fp32 diff --git a/examples/tutorial/new_api/cifar_resnet/README.md b/examples/tutorial/new_api/cifar_resnet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4ed86aa7a0ad8595fba636c6be6d072db47b2ac3 --- /dev/null +++ b/examples/tutorial/new_api/cifar_resnet/README.md @@ -0,0 +1,56 @@ +# Train ResNet on CIFAR-10 from scratch + +## 🚀 Quick Start + +This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch. + +- Training Arguments + - `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `low_level_zero`. Defaults to `torch_ddp`. + - `-r`, `--resume`: Resume from checkpoint file path. Defaults to `-1`, which means not resuming. + - `-c`, `--checkpoint`: The folder to save checkpoints. Defaults to `./checkpoint`. + - `-i`, `--interval`: Epoch interval to save checkpoints. Defaults to `5`. If set to `0`, no checkpoint will be saved. + - `--target_acc`: Target accuracy. Raise exception if not reached. Defaults to `None`. + +- Eval Arguments + - `-e`, `--epoch`: select the epoch to evaluate + - `-c`, `--checkpoint`: the folder where checkpoints are found + +### Install requirements + +```bash +pip install -r requirements.txt +``` + +### Train + +```bash +# train with torch DDP with fp32 +colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32 + +# train with torch DDP with mixed precision training +colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 -p torch_ddp_fp16 + +# train with low level zero +colossalai run --nproc_per_node 2 train.py -c ./ckpt-low_level_zero -p low_level_zero +``` + +### Eval + +```bash +# evaluate fp32 training +python eval.py -c ./ckpt-fp32 -e 80 + +# evaluate fp16 mixed precision training +python eval.py -c ./ckpt-fp16 -e 80 + +# evaluate low level zero training +python eval.py -c ./ckpt-low_level_zero -e 80 +``` + +Expected accuracy performance will be: + +| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | +| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | +| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | + +**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** diff --git a/examples/tutorial/new_api/cifar_resnet/eval.py b/examples/tutorial/new_api/cifar_resnet/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..526e41a2850fe253991cb0636d98368a5970c408 --- /dev/null +++ b/examples/tutorial/new_api/cifar_resnet/eval.py @@ -0,0 +1,47 @@ +import argparse + +import torch +import torchvision +import torchvision.transforms as transforms + +# ============================== +# Parse Arguments +# ============================== +parser = argparse.ArgumentParser() +parser.add_argument("-e", "--epoch", type=int, default=80, help="resume from the epoch's checkpoint") +parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") +args = parser.parse_args() + +# ============================== +# Prepare Test Dataset +# ============================== +# CIFAR-10 dataset +test_dataset = torchvision.datasets.CIFAR10(root="./data/", train=False, transform=transforms.ToTensor()) + +# Data loader +test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False) + +# ============================== +# Load Model +# ============================== +model = torchvision.models.resnet18(num_classes=10).cuda() +state_dict = torch.load(f"{args.checkpoint}/model_{args.epoch}.pth") +model.load_state_dict(state_dict) + +# ============================== +# Run Evaluation +# ============================== +model.eval() + +with torch.no_grad(): + correct = 0 + total = 0 + for images, labels in test_loader: + images = images.cuda() + labels = labels.cuda() + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + print("Accuracy of the model on the test images: {} %".format(100 * correct / total)) diff --git a/examples/tutorial/new_api/cifar_resnet/requirements.txt b/examples/tutorial/new_api/cifar_resnet/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..85522f4129c44912990ddf0d052a6362c4888b31 --- /dev/null +++ b/examples/tutorial/new_api/cifar_resnet/requirements.txt @@ -0,0 +1,4 @@ +colossalai +torch +torchvision +tqdm diff --git a/examples/tutorial/new_api/cifar_resnet/test_ci.sh b/examples/tutorial/new_api/cifar_resnet/test_ci.sh new file mode 100755 index 0000000000000000000000000000000000000000..3954b84ff1baa52477dd0aca33ef8fcdceb77c45 --- /dev/null +++ b/examples/tutorial/new_api/cifar_resnet/test_ci.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -xe + +export DATA=/data/scratch/cifar-10 + +pip install -r requirements.txt + +for plugin in "torch_ddp" "torch_ddp_fp16" "low_level_zero"; do + colossalai run --nproc_per_node 4 train.py --interval 0 --target_acc 0.84 --plugin $plugin +done diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4407a51c315384e1ff715e13ecef1ff21d453e96 --- /dev/null +++ b/examples/tutorial/new_api/cifar_resnet/train.py @@ -0,0 +1,207 @@ +import argparse +import os +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from torch.optim import Optimizer +from torch.optim.lr_scheduler import MultiStepLR +from torch.utils.data import DataLoader +from tqdm import tqdm + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 80 +LEARNING_RATE = 1e-3 + + +def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): + # transform + transform_train = transforms.Compose( + [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()] + ) + transform_test = transforms.ToTensor() + + # CIFAR-10 dataset + data_path = os.environ.get("DATA", "./data") + with coordinator.priority_execution(): + train_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=True, transform=transform_train, download=True + ) + test_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=False, transform=transform_test, download=True + ) + + # Data loader + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) + test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False) + return train_dataloader, test_dataloader + + +@torch.no_grad() +def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: + model.eval() + correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + for images, labels in test_dataloader: + images = images.cuda() + labels = labels.cuda() + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + dist.all_reduce(correct) + dist.all_reduce(total) + accuracy = correct.item() / total.item() + if coordinator.is_master(): + print(f"Accuracy of the model on the test images: {accuracy * 100:.2f} %") + return accuracy + + +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: nn.Module, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + model.train() + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: + for images, labels in pbar: + images = images.cuda() + labels = labels.cuda() + # Forward pass + outputs = model(images) + loss = criterion(outputs, labels) + + # Backward and optimize + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + + # Print log info + pbar.set_postfix({"loss": loss.item()}) + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + # FIXME(ver217): gemini is not supported resnet now + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "low_level_zero"], + help="plugin to use", + ) + parser.add_argument("-r", "--resume", type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") + parser.add_argument("-i", "--interval", type=int, default=5, help="interval of saving checkpoint") + parser.add_argument( + "--target_acc", type=float, default=None, help="target accuracy. Raise exception if not reached" + ) + args = parser.parse_args() + + # ============================== + # Prepare Checkpoint Directory + # ============================== + if args.interval > 0: + Path(args.checkpoint).mkdir(parents=True, exist_ok=True) + + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}) + coordinator = DistCoordinator() + + # update the learning rate with linear scaling + # old_gpu_num / old_lr = new_gpu_num / new_lr + global LEARNING_RATE + LEARNING_RATE *= coordinator.world_size + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + train_dataloader, test_dataloader = build_dataloader(100, coordinator, plugin) + + # ==================================== + # Prepare model, optimizer, criterion + # ==================================== + # resent50 + model = torchvision.models.resnet18(num_classes=10) + + # Loss and optimizer + criterion = nn.CrossEntropyLoss() + optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE) + + # lr scheduler + lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3) + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=criterion, lr_scheduler=lr_scheduler + ) + + # ============================== + # Resume from checkpoint + # ============================== + if args.resume >= 0: + booster.load_model(model, f"{args.checkpoint}/model_{args.resume}.pth") + booster.load_optimizer(optimizer, f"{args.checkpoint}/optimizer_{args.resume}.pth") + booster.load_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{args.resume}.pth") + + # ============================== + # Train model + # ============================== + start_epoch = args.resume if args.resume >= 0 else 0 + for epoch in range(start_epoch, NUM_EPOCHS): + train_epoch(epoch, model, optimizer, criterion, train_dataloader, booster, coordinator) + lr_scheduler.step() + + # save checkpoint + if args.interval > 0 and (epoch + 1) % args.interval == 0: + booster.save_model(model, f"{args.checkpoint}/model_{epoch + 1}.pth") + booster.save_optimizer(optimizer, f"{args.checkpoint}/optimizer_{epoch + 1}.pth") + booster.save_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{epoch + 1}.pth") + + accuracy = evaluate(model, test_dataloader, coordinator) + if args.target_acc is not None: + assert accuracy >= args.target_acc, f"Accuracy {accuracy} is lower than target accuracy {args.target_acc}" + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/new_api/cifar_vit/README.md b/examples/tutorial/new_api/cifar_vit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fa76447c508f08bd7c5e5c3e920099eb409a6dbf --- /dev/null +++ b/examples/tutorial/new_api/cifar_vit/README.md @@ -0,0 +1,37 @@ +# Train ViT on CIFAR-10 from scratch + +## 🚀 Quick Start + +This example provides a training script, which provides an example of training ViT on CIFAR10 dataset from scratch. + +- Training Arguments + - `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `low_level_zero`. Defaults to `torch_ddp`. + - `-r`, `--resume`: Resume from checkpoint file path. Defaults to `-1`, which means not resuming. + - `-c`, `--checkpoint`: The folder to save checkpoints. Defaults to `./checkpoint`. + - `-i`, `--interval`: Epoch interval to save checkpoints. Defaults to `5`. If set to `0`, no checkpoint will be saved. + - `--target_acc`: Target accuracy. Raise exception if not reached. Defaults to `None`. + +### Install requirements + +```bash +pip install -r requirements.txt +``` + +### Train + +```bash +# train with torch DDP with fp32 +colossalai run --nproc_per_node 4 train.py -c ./ckpt-fp32 + +# train with torch DDP with mixed precision training +colossalai run --nproc_per_node 4 train.py -c ./ckpt-fp16 -p torch_ddp_fp16 + +# train with low level zero +colossalai run --nproc_per_node 4 train.py -c ./ckpt-low_level_zero -p low_level_zero +``` + +Expected accuracy performance will be: + +| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | +| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | +| ViT | 83.00% | 84.03% | 84.00% | 84.43% | diff --git a/examples/tutorial/new_api/cifar_vit/requirements.txt b/examples/tutorial/new_api/cifar_vit/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6d53ce7b5a7da3efc1acbe8422793154beb6da76 --- /dev/null +++ b/examples/tutorial/new_api/cifar_vit/requirements.txt @@ -0,0 +1,5 @@ +colossalai +timm +torch +torchvision +tqdm diff --git a/examples/tutorial/new_api/cifar_vit/test_ci.sh b/examples/tutorial/new_api/cifar_vit/test_ci.sh new file mode 100755 index 0000000000000000000000000000000000000000..43239d4005863ff012bc4b0d5cdb414d3cedee43 --- /dev/null +++ b/examples/tutorial/new_api/cifar_vit/test_ci.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -xe + +export DATA=/data/scratch/cifar-10 + +pip install -r requirements.txt + +for plugin in "torch_ddp" "torch_ddp_fp16" "low_level_zero"; do + colossalai run --nproc_per_node 4 train.py --interval 0 --target_acc 0.83 --plugin $plugin +done diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py new file mode 100644 index 0000000000000000000000000000000000000000..700e4d2e0cd96fe11d8f9bdc0bcc8d2e9276fe9e --- /dev/null +++ b/examples/tutorial/new_api/cifar_vit/train.py @@ -0,0 +1,227 @@ +import argparse +import os +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from timm.models.vision_transformer import _cfg, _create_vision_transformer +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from tqdm import tqdm + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase +from colossalai.cluster import DistCoordinator +from colossalai.nn.lr_scheduler import LinearWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 60 +WARMUP_EPOCHS = 5 +LEARNING_RATE = 1e-3 + + +def vit_cifar(**kwargs): + pretrained_cfg = _cfg(num_classes=10, input_size=(3, 32, 32), crop_pct=1.0) + model_kwargs = dict(patch_size=4, embed_dim=512, depth=6, num_heads=8, drop_rate=0.1, mlp_ratio=1.0, **kwargs) + model = _create_vision_transformer("vit_cifar", pretrained_cfg=pretrained_cfg, **model_kwargs) + return model + + +def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): + # transform + transform_train = transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), + ] + ) + transform_test = transforms.Compose( + [ + transforms.Resize(32), + transforms.ToTensor(), + transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), + ] + ) + + # CIFAR-10 dataset + data_path = os.environ.get("DATA", "./data") + with coordinator.priority_execution(): + train_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=True, transform=transform_train, download=True + ) + test_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=False, transform=transform_test, download=True + ) + + # Data loader + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) + test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False) + return train_dataloader, test_dataloader + + +@torch.no_grad() +def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: + model.eval() + correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + for images, labels in test_dataloader: + images = images.cuda() + labels = labels.cuda() + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + dist.all_reduce(correct) + dist.all_reduce(total) + accuracy = correct.item() / total.item() + if coordinator.is_master(): + print(f"Accuracy of the model on the test images: {accuracy * 100:.2f} %") + return accuracy + + +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: nn.Module, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + model.train() + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: + for images, labels in pbar: + images = images.cuda() + labels = labels.cuda() + # Forward pass + outputs = model(images) + loss = criterion(outputs, labels) + + # Backward and optimize + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + + # Print log info + pbar.set_postfix({"loss": loss.item()}) + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + # FIXME(ver217): gemini is not supported resnet now + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "low_level_zero"], + help="plugin to use", + ) + parser.add_argument("-r", "--resume", type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") + parser.add_argument("-i", "--interval", type=int, default=5, help="interval of saving checkpoint") + parser.add_argument( + "--target_acc", type=float, default=None, help="target accuracy. Raise exception if not reached" + ) + args = parser.parse_args() + + # ============================== + # Prepare Checkpoint Directory + # ============================== + if args.interval > 0: + Path(args.checkpoint).mkdir(parents=True, exist_ok=True) + + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}) + coordinator = DistCoordinator() + + # update the learning rate with linear scaling + # old_gpu_num / old_lr = new_gpu_num / new_lr + global LEARNING_RATE + LEARNING_RATE *= coordinator.world_size + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + train_dataloader, test_dataloader = build_dataloader(512, coordinator, plugin) + + # ==================================== + # Prepare model, optimizer, criterion + # ==================================== + # resent50 + model = torchvision.models.resnet18(num_classes=10) + + # Loss and optimizer + criterion = nn.CrossEntropyLoss() + optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE) + + # lr scheduler + lr_scheduler = LinearWarmupLR(optimizer, NUM_EPOCHS, WARMUP_EPOCHS) + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost( + model, optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler + ) + + # ============================== + # Resume from checkpoint + # ============================== + if args.resume >= 0: + booster.load_model(model, f"{args.checkpoint}/model_{args.resume}.pth") + booster.load_optimizer(optimizer, f"{args.checkpoint}/optimizer_{args.resume}.pth") + booster.load_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{args.resume}.pth") + + # ============================== + # Train model + # ============================== + start_epoch = args.resume if args.resume >= 0 else 0 + for epoch in range(start_epoch, NUM_EPOCHS): + train_epoch(epoch, model, optimizer, criterion, train_dataloader, booster, coordinator) + lr_scheduler.step() + + # save checkpoint + if args.interval > 0 and (epoch + 1) % args.interval == 0: + booster.save_model(model, f"{args.checkpoint}/model_{epoch + 1}.pth") + booster.save_optimizer(optimizer, f"{args.checkpoint}/optimizer_{epoch + 1}.pth") + booster.save_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{epoch + 1}.pth") + + accuracy = evaluate(model, test_dataloader, coordinator) + if args.target_acc is not None: + assert accuracy >= args.target_acc, f"Accuracy {accuracy} is lower than target accuracy {args.target_acc}" + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/new_api/glue_bert/README.md b/examples/tutorial/new_api/glue_bert/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0030eead9f5be2d4d25c3cc1d177b105b5c01d5b --- /dev/null +++ b/examples/tutorial/new_api/glue_bert/README.md @@ -0,0 +1,39 @@ +# Finetune BERT on GLUE + +## 🚀 Quick Start + +This example provides a training script, which provides an example of finetuning BERT on GLUE dataset. + +- Training Arguments + - `-t`, `--task`: GLUE task to run. Defaults to `mrpc`. + - `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `gemini`, `low_level_zero`. Defaults to `torch_ddp`. + - `--target_f1`: Target f1 score. Raise exception if not reached. Defaults to `None`. + + +### Install requirements + +```bash +pip install -r requirements.txt +``` + +### Train + +```bash +# train with torch DDP with fp32 +colossalai run --nproc_per_node 4 finetune.py + +# train with torch DDP with mixed precision training +colossalai run --nproc_per_node 4 finetune.py -p torch_ddp_fp16 + +# train with gemini +colossalai run --nproc_per_node 4 finetune.py -p gemini + +# train with low level zero +colossalai run --nproc_per_node 4 finetune.py -p low_level_zero +``` + +Expected F1-score will be: + +| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Gemini | Booster Low Level Zero | +| ----------------- | ------------------------ | --------------------- | --------------------- |--------------- | ---------------------- | +| bert-base-uncased | 0.86 | 0.88 | 0.87 | 0.88 | 0.89 | diff --git a/examples/tutorial/new_api/glue_bert/data.py b/examples/tutorial/new_api/glue_bert/data.py new file mode 100644 index 0000000000000000000000000000000000000000..ef51f938dc4f8be5811916d692c8234fd811cd84 --- /dev/null +++ b/examples/tutorial/new_api/glue_bert/data.py @@ -0,0 +1,123 @@ +import datasets +from transformers import AutoTokenizer, PreTrainedTokenizer + +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase + + +class GLUEDataBuilder: + task_text_field_map = { + "cola": ["sentence"], + "sst2": ["sentence"], + "mrpc": ["sentence1", "sentence2"], + "qqp": ["question1", "question2"], + "stsb": ["sentence1", "sentence2"], + "mnli": ["premise", "hypothesis"], + "qnli": ["question", "sentence"], + "rte": ["sentence1", "sentence2"], + "wnli": ["sentence1", "sentence2"], + "ax": ["premise", "hypothesis"], + } + + glue_task_num_labels = { + "cola": 2, + "sst2": 2, + "mrpc": 2, + "qqp": 2, + "stsb": 1, + "mnli": 3, + "qnli": 2, + "rte": 2, + "wnli": 2, + "ax": 3, + } + + loader_columns = [ + "datasets_idx", + "input_ids", + "token_type_ids", + "attention_mask", + "start_positions", + "end_positions", + "labels", + ] + + def __init__( + self, + model_name_or_path: str, + plugin: DPPluginBase, + task_name: str = "mrpc", + max_seq_length: int = 128, + train_batch_size: int = 32, + eval_batch_size: int = 32, + **kwargs, + ): + super().__init__() + self.model_name_or_path = model_name_or_path + self.task_name = task_name + self.max_seq_length = max_seq_length + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size + self.plugin = plugin + + self.text_fields = self.task_text_field_map[task_name] + self.num_labels = self.glue_task_num_labels[task_name] + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + self.setup() + + def setup(self): + self.dataset = datasets.load_dataset("glue", self.task_name) + + for split in self.dataset.keys(): + self.dataset[split] = self.dataset[split].map( + self.convert_to_features, + batched=True, + remove_columns=["label"], + ) + self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] + self.dataset[split].set_format(type="torch", columns=self.columns) + + self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] + + def prepare_data(self): + datasets.load_dataset("glue", self.task_name) + AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + + def train_dataloader(self): + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) + + def val_dataloader(self): + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def test_dataloader(self): + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def convert_to_features(self, example_batch): + # Either encode single sentence or sentence pairs + if len(self.text_fields) > 1: + texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) + else: + texts_or_text_pairs = example_batch[self.text_fields[0]] + + # Tokenize the text/text pairs + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) + + # Rename label to labels to make it easier to pass to model forward + features["labels"] = example_batch["label"] + + return features diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..990822c9febaf94548546490812b1ad72ca33027 --- /dev/null +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -0,0 +1,212 @@ +import argparse +from typing import List, Union + +import datasets +import torch +import torch.distributed as dist +import torch.nn as nn +from data import GLUEDataBuilder +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 1 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 + + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + + +@torch.no_grad() +def evaluate( + model: nn.Module, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + coordinator: DistCoordinator, +): + metric = datasets.load_metric("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) + model.eval() + + def evaluate_subset(dataloader: DataLoader): + accum_loss = torch.zeros(1, device=get_current_device()) + for batch in dataloader: + batch = move_to_cuda(batch) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + labels = batch["labels"] + + metric.add_batch(predictions=preds, references=labels) + + results = metric.compute() + dist.all_reduce(accum_loss.div_(len(dataloader))) + if coordinator.is_master(): + results["loss"] = accum_loss.item() / coordinator.world_size + return results + + if isinstance(test_dataloader, DataLoader): + return evaluate_subset(test_dataloader) + else: + assert len(test_dataloader) == len(eval_splits) + final_results = {} + for split, sub_loader in zip(eval_splits, test_dataloader): + results = evaluate_subset(sub_loader) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) + return final_results + + +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + lr_scheduler, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + model.train() + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: + for batch in pbar: + # Forward pass + batch = move_to_cuda(batch) + outputs = model(**batch) + loss = outputs[0] + + # Backward and optimize + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + # Print log info + pbar.set_postfix({"loss": loss.item()}) + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) + parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") + args = parser.parse_args() + + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}, seed=42) + coordinator = DistCoordinator() + + # local_batch_size = BATCH_SIZE // coordinator.world_size + lr = LEARNING_RATE * coordinator.world_size + model_name = "bert-base-uncased" + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): + plugin = TorchDDPPlugin() + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + data_builder = GLUEDataBuilder( + model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE + ) + train_dataloader = data_builder.train_dataloader() + test_dataloader = data_builder.test_dataloader() + + # ==================================== + # Prepare model, optimizer + # ==================================== + # bert pretrained model + config = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) + model = BertForSequenceClassification.from_pretrained(model_name, config=config) + + # optimizer + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) + + # lr scheduler + total_steps = len(train_dataloader) * NUM_EPOCHS + num_warmup_steps = int(WARMUP_FRACTION * total_steps) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + ) + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler) + + # ============================== + # Train model + # ============================== + for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator) + + results = evaluate( + model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, coordinator + ) + + if coordinator.is_master(): + print(results) + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/new_api/glue_bert/requirements.txt b/examples/tutorial/new_api/glue_bert/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..950c2d378f0898e7efc0c67dff11a748ff6bba75 --- /dev/null +++ b/examples/tutorial/new_api/glue_bert/requirements.txt @@ -0,0 +1,7 @@ +colossalai +datasets +torch +tqdm +transformers +scipy +scikit-learn diff --git a/examples/tutorial/new_api/glue_bert/test_ci.sh b/examples/tutorial/new_api/glue_bert/test_ci.sh new file mode 100755 index 0000000000000000000000000000000000000000..56dd431f1e603fb9209d9ed4f95a3065bcd2ecef --- /dev/null +++ b/examples/tutorial/new_api/glue_bert/test_ci.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -xe + +pip install -r requirements.txt + +for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do + torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.80 --plugin $plugin +done diff --git a/examples/tutorial/new_api/test_ci.sh b/examples/tutorial/new_api/test_ci.sh index 8b4475e9f1473acd9456e050cb338395f2fbb60e..a08844dbe5fa6b8474426b489ca8ca5d4d8f15df 100644 --- a/examples/tutorial/new_api/test_ci.sh +++ b/examples/tutorial/new_api/test_ci.sh @@ -1,2 +1,6 @@ -#!/usr/bin/env -echo "The CI integration will be completed when the API is stable" +#!/bin/bash +set -xe + +# FIXME(ver217): only run bert finetune to save time + +cd glue_bert && bash ./test_ci.sh && cd .. diff --git a/examples/tutorial/new_api/torch_ddp/README.md b/examples/tutorial/new_api/torch_ddp/README.md deleted file mode 100644 index e120bacb0c84269565c6185efb3666a87680d35e..0000000000000000000000000000000000000000 --- a/examples/tutorial/new_api/torch_ddp/README.md +++ /dev/null @@ -1,44 +0,0 @@ -# Distributed Data Parallel - -## 🚀 Quick Start - -This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch. - -- Training Arguments - - `-r`, `--resume`: resume from checkpoint file path - - `-c`, `--checkpoint`: the folder to save checkpoints - - `-i`, `--interval`: epoch interval to save checkpoints - - `-f`, `--fp16`: use fp16 - -- Eval Arguments - - `-e`, `--epoch`: select the epoch to evaluate - - `-c`, `--checkpoint`: the folder where checkpoints are found - - -### Train - -```bash -# train with torch DDP with fp32 -colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32 - -# train with torch DDP with mixed precision training -colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 --fp16 -``` - -### Eval - -```bash -# evaluate fp32 training -python eval.py -c ./ckpt-fp32 -e 80 - -# evaluate fp16 mixed precision training -python eval.py -c ./ckpt-fp16 -e 80 -``` - -Expected accuracy performance will be: - -| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | -| --------- | ------------------------ | --------------------- | --------------------- | -| ResNet-18 | 85.85% | 85.03% | 85.12% | - -**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** diff --git a/examples/tutorial/new_api/torch_ddp/eval.py b/examples/tutorial/new_api/torch_ddp/eval.py deleted file mode 100644 index 657708ec3ff26a699a7abe12a3c3330a96194a23..0000000000000000000000000000000000000000 --- a/examples/tutorial/new_api/torch_ddp/eval.py +++ /dev/null @@ -1,48 +0,0 @@ -import argparse - -import torch -import torch.nn as nn -import torchvision -import torchvision.transforms as transforms - -# ============================== -# Parse Arguments -# ============================== -parser = argparse.ArgumentParser() -parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint") -parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") -args = parser.parse_args() - -# ============================== -# Prepare Test Dataset -# ============================== -# CIFAR-10 dataset -test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor()) - -# Data loader -test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False) - -# ============================== -# Load Model -# ============================== -model = torchvision.models.resnet18(num_classes=10).cuda() -state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth') -model.load_state_dict(state_dict) - -# ============================== -# Run Evaluation -# ============================== -model.eval() - -with torch.no_grad(): - correct = 0 - total = 0 - for images, labels in test_loader: - images = images.cuda() - labels = labels.cuda() - outputs = model(images) - _, predicted = torch.max(outputs.data, 1) - total += labels.size(0) - correct += (predicted == labels).sum().item() - - print('Accuracy of the model on the test images: {} %'.format(100 * correct / total)) diff --git a/examples/tutorial/new_api/torch_ddp/train.py b/examples/tutorial/new_api/torch_ddp/train.py deleted file mode 100644 index 4741c3151cbbbbb210cd2d5a7b4417625925ae28..0000000000000000000000000000000000000000 --- a/examples/tutorial/new_api/torch_ddp/train.py +++ /dev/null @@ -1,128 +0,0 @@ -import argparse -from pathlib import Path - -import torch -import torch.nn as nn -import torchvision -import torchvision.transforms as transforms -from torch.optim.lr_scheduler import MultiStepLR - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import TorchDDPPlugin -from colossalai.cluster import DistCoordinator - -# ============================== -# Parse Arguments -# ============================== -parser = argparse.ArgumentParser() -parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") -parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") -parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") -parser.add_argument('-f', '--fp16', action='store_true', help="use fp16") -args = parser.parse_args() - -# ============================== -# Prepare Checkpoint Directory -# ============================== -Path(args.checkpoint).mkdir(parents=True, exist_ok=True) - -# ============================== -# Prepare Hyperparameters -# ============================== -NUM_EPOCHS = 80 -LEARNING_RATE = 1e-3 -START_EPOCH = args.resume if args.resume >= 0 else 0 - -# ============================== -# Launch Distributed Environment -# ============================== -colossalai.launch_from_torch(config={}) -coordinator = DistCoordinator() - -# update the learning rate with linear scaling -# old_gpu_num / old_lr = new_gpu_num / new_lr -LEARNING_RATE *= coordinator.world_size - -# ============================== -# Prepare Booster -# ============================== -plugin = TorchDDPPlugin() -if args.fp16: - booster = Booster(mixed_precision='fp16', plugin=plugin) -else: - booster = Booster(plugin=plugin) - -# ============================== -# Prepare Train Dataset -# ============================== -transform = transforms.Compose( - [transforms.Pad(4), - transforms.RandomHorizontalFlip(), - transforms.RandomCrop(32), - transforms.ToTensor()]) - -# CIFAR-10 dataset -with coordinator.priority_execution(): - train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transform, download=True) - -# ==================================== -# Prepare model, optimizer, criterion -# ==================================== -# resent50 -model = torchvision.models.resnet18(num_classes=10).cuda() - -# Loss and optimizer -criterion = nn.CrossEntropyLoss() -optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) - -# lr scheduler -lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3) - -# prepare dataloader with torch ddp plugin -train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=100, shuffle=True) - -# ============================== -# Resume from checkpoint -# ============================== -if args.resume >= 0: - booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') - booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') - booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') - -# ============================== -# Boost with ColossalAI -# ============================== -model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, - train_dataloader, lr_scheduler) - -# ============================== -# Train model -# ============================== -total_step = len(train_dataloader) - -for epoch in range(START_EPOCH, NUM_EPOCHS): - for i, (images, labels) in enumerate(train_dataloader): - images = images.cuda() - labels = labels.cuda() - - # Forward pass - outputs = model(images) - loss = criterion(outputs, labels) - - # Backward and optimize - optimizer.zero_grad() - booster.backward(loss, optimizer) - optimizer.step() - - if (i + 1) % 100 == 0: - print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch + 1, NUM_EPOCHS, i + 1, total_step, - loss.item())) - - lr_scheduler.step() - - # save checkpoint every 5 epoch - if (epoch + 1) % args.interval == 0: - booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') - booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') - booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') diff --git a/examples/tutorial/opt/inference/batch.py b/examples/tutorial/opt/inference/batch.py index 1a0876ca833890fd2bfac3b4d9f342a05c67928f..e4e857b264a0f0d5e3bb20e4af0dbebd7178c146 100644 --- a/examples/tutorial/opt/inference/batch.py +++ b/examples/tutorial/opt/inference/batch.py @@ -1,5 +1,6 @@ +from typing import Any, Deque, Hashable, List, Tuple + import torch -from typing import List, Deque, Tuple, Hashable, Any from energonai import BatchManager, SubmitEntry, TaskEntry @@ -10,15 +11,15 @@ class BatchManagerForGeneration(BatchManager): self.pad_token_id = pad_token_id def _left_padding(self, batch_inputs): - max_len = max(len(inputs['input_ids']) for inputs in batch_inputs) - outputs = {'input_ids': [], 'attention_mask': []} + max_len = max(len(inputs["input_ids"]) for inputs in batch_inputs) + outputs = {"input_ids": [], "attention_mask": []} for inputs in batch_inputs: - input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] + input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"] padding_len = max_len - len(input_ids) input_ids = [self.pad_token_id] * padding_len + input_ids attention_mask = [0] * padding_len + attention_mask - outputs['input_ids'].append(input_ids) - outputs['attention_mask'].append(attention_mask) + outputs["input_ids"].append(input_ids) + outputs["attention_mask"].append(attention_mask) for k in outputs: outputs[k] = torch.tensor(outputs[k]) return outputs, max_len @@ -26,7 +27,7 @@ class BatchManagerForGeneration(BatchManager): @staticmethod def _make_batch_key(entry: SubmitEntry) -> tuple: data = entry.data - return (data['top_k'], data['top_p'], data['temperature']) + return (data["top_k"], data["top_p"], data["temperature"]) def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: entry = q.popleft() @@ -37,7 +38,7 @@ class BatchManagerForGeneration(BatchManager): break if self._make_batch_key(entry) != self._make_batch_key(q[0]): break - if q[0].data['max_tokens'] > entry.data['max_tokens']: + if q[0].data["max_tokens"] > entry.data["max_tokens"]: break e = q.popleft() batch.append(e.data) @@ -45,12 +46,12 @@ class BatchManagerForGeneration(BatchManager): inputs, max_len = self._left_padding(batch) trunc_lens = [] for data in batch: - trunc_lens.append(max_len + data['max_tokens']) - inputs['top_k'] = entry.data['top_k'] - inputs['top_p'] = entry.data['top_p'] - inputs['temperature'] = entry.data['temperature'] - inputs['max_tokens'] = max_len + entry.data['max_tokens'] - return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens} + trunc_lens.append(max_len + data["max_tokens"]) + inputs["top_k"] = entry.data["top_k"] + inputs["top_p"] = entry.data["top_p"] + inputs["temperature"] = entry.data["temperature"] + inputs["max_tokens"] = max_len + entry.data["max_tokens"] + return TaskEntry(tuple(uids), inputs), {"trunc_lens": trunc_lens} def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]: retval = [] diff --git a/examples/tutorial/opt/inference/benchmark/locustfile.py b/examples/tutorial/opt/inference/benchmark/locustfile.py index 4d829e5d83bf73c45b20d05bfa549b5ee3b869ba..76ef9d8cb3d685754fce6f081aa4cd225fe86613 100644 --- a/examples/tutorial/opt/inference/benchmark/locustfile.py +++ b/examples/tutorial/opt/inference/benchmark/locustfile.py @@ -1,15 +1,14 @@ from locust import HttpUser, task -from json import JSONDecodeError class GenerationUser(HttpUser): @task def generate(self): - prompt = 'Question: What is the longest river on the earth? Answer:' + prompt = "Question: What is the longest river on the earth? Answer:" for i in range(4, 9): - data = {'max_tokens': 2**i, 'prompt': prompt} - with self.client.post('/generation', json=data, catch_response=True) as response: + data = {"max_tokens": 2**i, "prompt": prompt} + with self.client.post("/generation", json=data, catch_response=True) as response: if response.status_code in (200, 406): response.success() else: - response.failure('Response wrong') + response.failure("Response wrong") diff --git a/examples/tutorial/opt/inference/cache.py b/examples/tutorial/opt/inference/cache.py index 30febc44fbb3d7afb703c641e582562f90c6b8d0..1eb7dac2ea04b780971e8ad286cea7f4e76bd448 100644 --- a/examples/tutorial/opt/inference/cache.py +++ b/examples/tutorial/opt/inference/cache.py @@ -1,7 +1,7 @@ from collections import OrderedDict -from threading import Lock from contextlib import contextmanager -from typing import List, Any, Hashable, Dict +from threading import Lock +from typing import Any, Dict, Hashable, List class MissCacheError(Exception): diff --git a/examples/tutorial/opt/inference/opt_fastapi.py b/examples/tutorial/opt/inference/opt_fastapi.py index cbfc2a22e7c0c98070d940f3a42cfd66ce839028..6475284e535b88b9ebf6a4241ef1b1e332300261 100644 --- a/examples/tutorial/opt/inference/opt_fastapi.py +++ b/examples/tutorial/opt/inference/opt_fastapi.py @@ -4,20 +4,21 @@ import random from typing import Optional import uvicorn +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError from energonai import QueueFullError, launch_engine from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B from fastapi import FastAPI, HTTPException, Request from pydantic import BaseModel, Field from transformers import GPT2Tokenizer -from batch import BatchManagerForGeneration -from cache import ListCache, MissCacheError - class GenerationTaskReq(BaseModel): max_tokens: int = Field(gt=0, le=256, example=64) prompt: str = Field( - min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + min_length=1, + example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:", + ) top_k: Optional[int] = Field(default=None, gt=0, example=50) top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) @@ -26,7 +27,7 @@ class GenerationTaskReq(BaseModel): app = FastAPI() -@app.post('/generation') +@app.post("/generation") async def generate(data: GenerationTaskReq, request: Request): logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}') key = (data.prompt, data.max_tokens) @@ -35,13 +36,13 @@ async def generate(data: GenerationTaskReq, request: Request): raise MissCacheError() outputs = cache.get(key) output = random.choice(outputs) - logger.info('Cache hit') + logger.info("Cache hit") except MissCacheError: inputs = tokenizer(data.prompt, truncation=True, max_length=512) - inputs['max_tokens'] = data.max_tokens - inputs['top_k'] = data.top_k - inputs['top_p'] = data.top_p - inputs['temperature'] = data.temperature + inputs["max_tokens"] = data.max_tokens + inputs["top_k"] = data.top_k + inputs["top_p"] = data.top_p + inputs["temperature"] = data.temperature try: uid = id(data) engine.submit(uid, inputs) @@ -52,7 +53,7 @@ async def generate(data: GenerationTaskReq, request: Request): except QueueFullError as e: raise HTTPException(status_code=406, detail=e.args[0]) - return {'text': output} + return {"text": output} @app.on_event("shutdown") @@ -64,60 +65,72 @@ async def shutdown(*_): def get_model_fn(model_name: str): - model_map = { - 'opt-125m': opt_125M, - 'opt-6.7b': opt_6B, - 'opt-30b': opt_30B, - 'opt-175b': opt_175B - } + model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B} return model_map[model_name] def print_args(args: argparse.Namespace): - print('\n==> Args:') + print("\n==> Args:") for k, v in args.__dict__.items(): - print(f'{k} = {v}') + print(f"{k} = {v}") FIXED_CACHE_KEYS = [ - ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), - ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), - ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) + ( + "Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:", + 64, + ), + ( + "A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.", + 64, + ), + ( + "English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", + 64, + ), ] -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) - parser.add_argument('--tp', type=int, default=1) - parser.add_argument('--master_host', default='localhost') - parser.add_argument('--master_port', type=int, default=19990) - parser.add_argument('--rpc_port', type=int, default=19980) - parser.add_argument('--max_batch_size', type=int, default=8) - parser.add_argument('--pipe_size', type=int, default=1) - parser.add_argument('--queue_size', type=int, default=0) - parser.add_argument('--http_host', default='0.0.0.0') - parser.add_argument('--http_port', type=int, default=7070) - parser.add_argument('--checkpoint', default=None) - parser.add_argument('--cache_size', type=int, default=0) - parser.add_argument('--cache_list_size', type=int, default=1) + parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"]) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--master_host", default="localhost") + parser.add_argument("--master_port", type=int, default=19990) + parser.add_argument("--rpc_port", type=int, default=19980) + parser.add_argument("--max_batch_size", type=int, default=8) + parser.add_argument("--pipe_size", type=int, default=1) + parser.add_argument("--queue_size", type=int, default=0) + parser.add_argument("--http_host", default="0.0.0.0") + parser.add_argument("--http_port", type=int, default=7070) + parser.add_argument("--checkpoint", default=None) + parser.add_argument("--cache_size", type=int, default=0) + parser.add_argument("--cache_list_size", type=int, default=1) args = parser.parse_args() print_args(args) model_kwargs = {} if args.checkpoint is not None: - model_kwargs['checkpoint'] = args.checkpoint + model_kwargs["checkpoint"] = args.checkpoint logger = logging.getLogger(__name__) - tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b") if args.cache_size > 0: cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) else: cache = None - engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), - batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, - pad_token_id=tokenizer.pad_token_id), - pipe_size=args.pipe_size, - queue_size=args.queue_size, - **model_kwargs) + engine = launch_engine( + args.tp, + 1, + args.master_host, + args.master_port, + args.rpc_port, + get_model_fn(args.model), + batch_manager=BatchManagerForGeneration( + max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id + ), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs, + ) config = uvicorn.Config(app, host=args.http_host, port=args.http_port) server = uvicorn.Server(config=config) server.run() diff --git a/examples/tutorial/opt/inference/opt_server.py b/examples/tutorial/opt/inference/opt_server.py index 8dab82622c59c313aee836a4509a27fd48ccad1c..7f591b9be1112d8bd8663bdd2fc2a2843d4a5c54 100644 --- a/examples/tutorial/opt/inference/opt_server.py +++ b/examples/tutorial/opt/inference/opt_server.py @@ -1,33 +1,36 @@ -import logging import argparse +import logging import random -from torch import Tensor -from pydantic import BaseModel, Field from typing import Optional -from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B -from transformers import GPT2Tokenizer -from energonai import launch_engine, QueueFullError + +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError +from energonai import QueueFullError, launch_engine +from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B +from pydantic import BaseModel, Field from sanic import Sanic from sanic.request import Request from sanic.response import json -from sanic_ext import validate, openapi -from batch import BatchManagerForGeneration -from cache import ListCache, MissCacheError +from sanic_ext import openapi, validate +from torch import Tensor +from transformers import GPT2Tokenizer class GenerationTaskReq(BaseModel): max_tokens: int = Field(gt=0, le=256, example=64) prompt: str = Field( - min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + min_length=1, + example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:", + ) top_k: Optional[int] = Field(default=None, gt=0, example=50) top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) -app = Sanic('opt') +app = Sanic("opt") -@app.post('/generation') +@app.post("/generation") @openapi.body(GenerationTaskReq) @validate(json=GenerationTaskReq) async def generate(request: Request, body: GenerationTaskReq): @@ -38,13 +41,13 @@ async def generate(request: Request, body: GenerationTaskReq): raise MissCacheError() outputs = cache.get(key) output = random.choice(outputs) - logger.info('Cache hit') + logger.info("Cache hit") except MissCacheError: inputs = tokenizer(body.prompt, truncation=True, max_length=512) - inputs['max_tokens'] = body.max_tokens - inputs['top_k'] = body.top_k - inputs['top_p'] = body.top_p - inputs['temperature'] = body.temperature + inputs["max_tokens"] = body.max_tokens + inputs["top_k"] = body.top_k + inputs["top_p"] = body.top_p + inputs["temperature"] = body.temperature try: uid = id(body) engine.submit(uid, inputs) @@ -54,9 +57,9 @@ async def generate(request: Request, body: GenerationTaskReq): if cache is not None: cache.add(key, output) except QueueFullError as e: - return json({'detail': e.args[0]}, status=406) + return json({"detail": e.args[0]}, status=406) - return json({'text': output}) + return json({"text": output}) @app.after_server_stop @@ -65,58 +68,70 @@ def shutdown(*_): def get_model_fn(model_name: str): - model_map = { - 'opt-125m': opt_125M, - 'opt-6.7b': opt_6B, - 'opt-30b': opt_30B, - 'opt-175b': opt_175B - } + model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B} return model_map[model_name] def print_args(args: argparse.Namespace): - print('\n==> Args:') + print("\n==> Args:") for k, v in args.__dict__.items(): - print(f'{k} = {v}') + print(f"{k} = {v}") FIXED_CACHE_KEYS = [ - ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), - ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), - ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) + ( + "Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:", + 64, + ), + ( + "A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.", + 64, + ), + ( + "English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", + 64, + ), ] -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) - parser.add_argument('--tp', type=int, default=1) - parser.add_argument('--master_host', default='localhost') - parser.add_argument('--master_port', type=int, default=19990) - parser.add_argument('--rpc_port', type=int, default=19980) - parser.add_argument('--max_batch_size', type=int, default=8) - parser.add_argument('--pipe_size', type=int, default=1) - parser.add_argument('--queue_size', type=int, default=0) - parser.add_argument('--http_host', default='0.0.0.0') - parser.add_argument('--http_port', type=int, default=7070) - parser.add_argument('--checkpoint', default=None) - parser.add_argument('--cache_size', type=int, default=0) - parser.add_argument('--cache_list_size', type=int, default=1) + parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"]) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--master_host", default="localhost") + parser.add_argument("--master_port", type=int, default=19990) + parser.add_argument("--rpc_port", type=int, default=19980) + parser.add_argument("--max_batch_size", type=int, default=8) + parser.add_argument("--pipe_size", type=int, default=1) + parser.add_argument("--queue_size", type=int, default=0) + parser.add_argument("--http_host", default="0.0.0.0") + parser.add_argument("--http_port", type=int, default=7070) + parser.add_argument("--checkpoint", default=None) + parser.add_argument("--cache_size", type=int, default=0) + parser.add_argument("--cache_list_size", type=int, default=1) args = parser.parse_args() print_args(args) model_kwargs = {} if args.checkpoint is not None: - model_kwargs['checkpoint'] = args.checkpoint + model_kwargs["checkpoint"] = args.checkpoint logger = logging.getLogger(__name__) - tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b") if args.cache_size > 0: cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) else: cache = None - engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), - batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, - pad_token_id=tokenizer.pad_token_id), - pipe_size=args.pipe_size, - queue_size=args.queue_size, - **model_kwargs) + engine = launch_engine( + args.tp, + 1, + args.master_host, + args.master_port, + args.rpc_port, + get_model_fn(args.model), + batch_manager=BatchManagerForGeneration( + max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id + ), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs, + ) app.run(args.http_host, args.http_port) diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/README.md b/examples/tutorial/opt/inference/script/process-opt-175b/README.md index bc3cba72df33c3242ed35625e3467aaefb4849e1..665c459fec693b38f0f33850e70b245884fe4ffd 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/README.md +++ b/examples/tutorial/opt/inference/script/process-opt-175b/README.md @@ -43,4 +43,3 @@ Finally, you will get 8 files in `` with following checksums: 5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt ``` - diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py index a17ddd4fa1735fcd1a114f1fd3c87870c257d6bd..36c9001fe3f18bcbee7d022547c3915f897c4222 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py +++ b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py @@ -14,42 +14,45 @@ def load_json(path: str): def parse_shape_info(flat_dir: str): - data = load_json(os.path.join(flat_dir, 'shape.json')) + data = load_json(os.path.join(flat_dir, "shape.json")) flat_info = defaultdict(lambda: defaultdict(list)) for k, shape in data.items(): - matched = re.match(r'decoder.layers.\d+', k) + matched = re.match(r"decoder.layers.\d+", k) if matched is None: - flat_key = 'flat_param_0' + flat_key = "flat_param_0" else: - flat_key = f'{matched[0]}.flat_param_0' - flat_info[flat_key]['names'].append(k) - flat_info[flat_key]['shapes'].append(shape) - flat_info[flat_key]['numels'].append(int(np.prod(shape))) + flat_key = f"{matched[0]}.flat_param_0" + flat_info[flat_key]["names"].append(k) + flat_info[flat_key]["shapes"].append(shape) + flat_info[flat_key]["numels"].append(int(np.prod(shape))) return flat_info def convert(flat_dir: str, output_dir: str, part: int): - flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt') - output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt') - flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json')) + flat_path = os.path.join(flat_dir, f"reshard-model_part-{part}-shard0.pt") + output_path = os.path.join(output_dir, f"reshard-model_part-{part}.pt") + flat_meta = load_json(os.path.join(flat_dir, "flat-meta.json")) flat_sd = torch.load(flat_path) - print(f'Loaded flat state dict from {flat_path}') + print(f"Loaded flat state dict from {flat_path}") output_sd = {} for flat_key, param_meta in flat_meta.items(): - flat_param = flat_sd['model'][flat_key] - assert sum(param_meta['numels']) == flat_param.numel( + flat_param = flat_sd["model"][flat_key] + assert ( + sum(param_meta["numels"]) == flat_param.numel() ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}' - for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])): + for name, shape, param in zip( + param_meta["names"], param_meta["shapes"], flat_param.split(param_meta["numels"]) + ): output_sd[name] = param.view(shape) torch.save(output_sd, output_path) - print(f'Saved unflat state dict to {output_path}') + print(f"Saved unflat state dict to {output_path}") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('flat_dir') - parser.add_argument('output_dir') - parser.add_argument('part', type=int) + parser.add_argument("flat_dir") + parser.add_argument("output_dir") + parser.add_argument("part", type=int) args = parser.parse_args() convert(args.flat_dir, args.output_dir, args.part) diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json index 59d285565cfdb8d68ce0e5dc91aab2046c136e9c..ce70451cc4e5c662100075db1f4015b2c1928d05 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json +++ b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json @@ -1 +1,6944 @@ -{"flat_param_0": {"names": ["decoder.embed_tokens.weight", "decoder.embed_positions.weight", "decoder.layer_norm.weight", "decoder.layer_norm.bias"], "shapes": [[6284, 12288], [2050, 12288], [12288], [12288]], "numels": [77217792, 25190400, 12288, 12288]}, "decoder.layers.0.flat_param_0": {"names": ["decoder.layers.0.self_attn.qkv_proj.weight", "decoder.layers.0.self_attn.qkv_proj.bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.self_attn_layer_norm.weight", "decoder.layers.0.self_attn_layer_norm.bias", "decoder.layers.0.fc1.weight", "decoder.layers.0.fc1.bias", "decoder.layers.0.fc2.weight", "decoder.layers.0.fc2.bias", "decoder.layers.0.final_layer_norm.weight", "decoder.layers.0.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.1.flat_param_0": {"names": ["decoder.layers.1.self_attn.qkv_proj.weight", "decoder.layers.1.self_attn.qkv_proj.bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.self_attn_layer_norm.weight", "decoder.layers.1.self_attn_layer_norm.bias", "decoder.layers.1.fc1.weight", "decoder.layers.1.fc1.bias", "decoder.layers.1.fc2.weight", "decoder.layers.1.fc2.bias", "decoder.layers.1.final_layer_norm.weight", "decoder.layers.1.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.2.flat_param_0": {"names": ["decoder.layers.2.self_attn.qkv_proj.weight", "decoder.layers.2.self_attn.qkv_proj.bias", "decoder.layers.2.self_attn.out_proj.weight", "decoder.layers.2.self_attn.out_proj.bias", "decoder.layers.2.self_attn_layer_norm.weight", "decoder.layers.2.self_attn_layer_norm.bias", "decoder.layers.2.fc1.weight", "decoder.layers.2.fc1.bias", "decoder.layers.2.fc2.weight", "decoder.layers.2.fc2.bias", "decoder.layers.2.final_layer_norm.weight", "decoder.layers.2.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.3.flat_param_0": {"names": ["decoder.layers.3.self_attn.qkv_proj.weight", "decoder.layers.3.self_attn.qkv_proj.bias", "decoder.layers.3.self_attn.out_proj.weight", "decoder.layers.3.self_attn.out_proj.bias", "decoder.layers.3.self_attn_layer_norm.weight", "decoder.layers.3.self_attn_layer_norm.bias", "decoder.layers.3.fc1.weight", "decoder.layers.3.fc1.bias", "decoder.layers.3.fc2.weight", "decoder.layers.3.fc2.bias", "decoder.layers.3.final_layer_norm.weight", "decoder.layers.3.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.4.flat_param_0": {"names": ["decoder.layers.4.self_attn.qkv_proj.weight", "decoder.layers.4.self_attn.qkv_proj.bias", "decoder.layers.4.self_attn.out_proj.weight", "decoder.layers.4.self_attn.out_proj.bias", "decoder.layers.4.self_attn_layer_norm.weight", "decoder.layers.4.self_attn_layer_norm.bias", "decoder.layers.4.fc1.weight", "decoder.layers.4.fc1.bias", "decoder.layers.4.fc2.weight", "decoder.layers.4.fc2.bias", "decoder.layers.4.final_layer_norm.weight", "decoder.layers.4.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.5.flat_param_0": {"names": ["decoder.layers.5.self_attn.qkv_proj.weight", "decoder.layers.5.self_attn.qkv_proj.bias", "decoder.layers.5.self_attn.out_proj.weight", "decoder.layers.5.self_attn.out_proj.bias", "decoder.layers.5.self_attn_layer_norm.weight", "decoder.layers.5.self_attn_layer_norm.bias", "decoder.layers.5.fc1.weight", "decoder.layers.5.fc1.bias", "decoder.layers.5.fc2.weight", "decoder.layers.5.fc2.bias", "decoder.layers.5.final_layer_norm.weight", "decoder.layers.5.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.6.flat_param_0": {"names": ["decoder.layers.6.self_attn.qkv_proj.weight", "decoder.layers.6.self_attn.qkv_proj.bias", "decoder.layers.6.self_attn.out_proj.weight", "decoder.layers.6.self_attn.out_proj.bias", "decoder.layers.6.self_attn_layer_norm.weight", "decoder.layers.6.self_attn_layer_norm.bias", "decoder.layers.6.fc1.weight", "decoder.layers.6.fc1.bias", "decoder.layers.6.fc2.weight", "decoder.layers.6.fc2.bias", "decoder.layers.6.final_layer_norm.weight", "decoder.layers.6.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.7.flat_param_0": {"names": ["decoder.layers.7.self_attn.qkv_proj.weight", "decoder.layers.7.self_attn.qkv_proj.bias", "decoder.layers.7.self_attn.out_proj.weight", "decoder.layers.7.self_attn.out_proj.bias", "decoder.layers.7.self_attn_layer_norm.weight", "decoder.layers.7.self_attn_layer_norm.bias", "decoder.layers.7.fc1.weight", "decoder.layers.7.fc1.bias", "decoder.layers.7.fc2.weight", "decoder.layers.7.fc2.bias", "decoder.layers.7.final_layer_norm.weight", "decoder.layers.7.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.8.flat_param_0": {"names": ["decoder.layers.8.self_attn.qkv_proj.weight", "decoder.layers.8.self_attn.qkv_proj.bias", "decoder.layers.8.self_attn.out_proj.weight", "decoder.layers.8.self_attn.out_proj.bias", "decoder.layers.8.self_attn_layer_norm.weight", "decoder.layers.8.self_attn_layer_norm.bias", "decoder.layers.8.fc1.weight", "decoder.layers.8.fc1.bias", "decoder.layers.8.fc2.weight", "decoder.layers.8.fc2.bias", "decoder.layers.8.final_layer_norm.weight", "decoder.layers.8.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.9.flat_param_0": {"names": ["decoder.layers.9.self_attn.qkv_proj.weight", "decoder.layers.9.self_attn.qkv_proj.bias", "decoder.layers.9.self_attn.out_proj.weight", "decoder.layers.9.self_attn.out_proj.bias", "decoder.layers.9.self_attn_layer_norm.weight", "decoder.layers.9.self_attn_layer_norm.bias", "decoder.layers.9.fc1.weight", "decoder.layers.9.fc1.bias", "decoder.layers.9.fc2.weight", "decoder.layers.9.fc2.bias", "decoder.layers.9.final_layer_norm.weight", "decoder.layers.9.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.10.flat_param_0": {"names": ["decoder.layers.10.self_attn.qkv_proj.weight", "decoder.layers.10.self_attn.qkv_proj.bias", "decoder.layers.10.self_attn.out_proj.weight", "decoder.layers.10.self_attn.out_proj.bias", "decoder.layers.10.self_attn_layer_norm.weight", "decoder.layers.10.self_attn_layer_norm.bias", "decoder.layers.10.fc1.weight", "decoder.layers.10.fc1.bias", "decoder.layers.10.fc2.weight", "decoder.layers.10.fc2.bias", "decoder.layers.10.final_layer_norm.weight", "decoder.layers.10.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.11.flat_param_0": {"names": ["decoder.layers.11.self_attn.qkv_proj.weight", "decoder.layers.11.self_attn.qkv_proj.bias", "decoder.layers.11.self_attn.out_proj.weight", "decoder.layers.11.self_attn.out_proj.bias", "decoder.layers.11.self_attn_layer_norm.weight", "decoder.layers.11.self_attn_layer_norm.bias", "decoder.layers.11.fc1.weight", "decoder.layers.11.fc1.bias", "decoder.layers.11.fc2.weight", "decoder.layers.11.fc2.bias", "decoder.layers.11.final_layer_norm.weight", "decoder.layers.11.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.12.flat_param_0": {"names": ["decoder.layers.12.self_attn.qkv_proj.weight", "decoder.layers.12.self_attn.qkv_proj.bias", "decoder.layers.12.self_attn.out_proj.weight", "decoder.layers.12.self_attn.out_proj.bias", "decoder.layers.12.self_attn_layer_norm.weight", "decoder.layers.12.self_attn_layer_norm.bias", "decoder.layers.12.fc1.weight", "decoder.layers.12.fc1.bias", "decoder.layers.12.fc2.weight", "decoder.layers.12.fc2.bias", "decoder.layers.12.final_layer_norm.weight", "decoder.layers.12.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.13.flat_param_0": {"names": ["decoder.layers.13.self_attn.qkv_proj.weight", "decoder.layers.13.self_attn.qkv_proj.bias", "decoder.layers.13.self_attn.out_proj.weight", "decoder.layers.13.self_attn.out_proj.bias", "decoder.layers.13.self_attn_layer_norm.weight", "decoder.layers.13.self_attn_layer_norm.bias", "decoder.layers.13.fc1.weight", "decoder.layers.13.fc1.bias", "decoder.layers.13.fc2.weight", "decoder.layers.13.fc2.bias", "decoder.layers.13.final_layer_norm.weight", "decoder.layers.13.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.14.flat_param_0": {"names": ["decoder.layers.14.self_attn.qkv_proj.weight", "decoder.layers.14.self_attn.qkv_proj.bias", "decoder.layers.14.self_attn.out_proj.weight", "decoder.layers.14.self_attn.out_proj.bias", "decoder.layers.14.self_attn_layer_norm.weight", "decoder.layers.14.self_attn_layer_norm.bias", "decoder.layers.14.fc1.weight", "decoder.layers.14.fc1.bias", "decoder.layers.14.fc2.weight", "decoder.layers.14.fc2.bias", "decoder.layers.14.final_layer_norm.weight", "decoder.layers.14.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.15.flat_param_0": {"names": ["decoder.layers.15.self_attn.qkv_proj.weight", "decoder.layers.15.self_attn.qkv_proj.bias", "decoder.layers.15.self_attn.out_proj.weight", "decoder.layers.15.self_attn.out_proj.bias", "decoder.layers.15.self_attn_layer_norm.weight", "decoder.layers.15.self_attn_layer_norm.bias", "decoder.layers.15.fc1.weight", "decoder.layers.15.fc1.bias", "decoder.layers.15.fc2.weight", "decoder.layers.15.fc2.bias", "decoder.layers.15.final_layer_norm.weight", "decoder.layers.15.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.16.flat_param_0": {"names": ["decoder.layers.16.self_attn.qkv_proj.weight", "decoder.layers.16.self_attn.qkv_proj.bias", "decoder.layers.16.self_attn.out_proj.weight", "decoder.layers.16.self_attn.out_proj.bias", "decoder.layers.16.self_attn_layer_norm.weight", "decoder.layers.16.self_attn_layer_norm.bias", "decoder.layers.16.fc1.weight", "decoder.layers.16.fc1.bias", "decoder.layers.16.fc2.weight", "decoder.layers.16.fc2.bias", "decoder.layers.16.final_layer_norm.weight", "decoder.layers.16.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.17.flat_param_0": {"names": ["decoder.layers.17.self_attn.qkv_proj.weight", "decoder.layers.17.self_attn.qkv_proj.bias", "decoder.layers.17.self_attn.out_proj.weight", "decoder.layers.17.self_attn.out_proj.bias", "decoder.layers.17.self_attn_layer_norm.weight", "decoder.layers.17.self_attn_layer_norm.bias", "decoder.layers.17.fc1.weight", "decoder.layers.17.fc1.bias", "decoder.layers.17.fc2.weight", "decoder.layers.17.fc2.bias", "decoder.layers.17.final_layer_norm.weight", "decoder.layers.17.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.18.flat_param_0": {"names": ["decoder.layers.18.self_attn.qkv_proj.weight", "decoder.layers.18.self_attn.qkv_proj.bias", "decoder.layers.18.self_attn.out_proj.weight", "decoder.layers.18.self_attn.out_proj.bias", "decoder.layers.18.self_attn_layer_norm.weight", "decoder.layers.18.self_attn_layer_norm.bias", "decoder.layers.18.fc1.weight", "decoder.layers.18.fc1.bias", "decoder.layers.18.fc2.weight", "decoder.layers.18.fc2.bias", "decoder.layers.18.final_layer_norm.weight", "decoder.layers.18.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.19.flat_param_0": {"names": ["decoder.layers.19.self_attn.qkv_proj.weight", "decoder.layers.19.self_attn.qkv_proj.bias", "decoder.layers.19.self_attn.out_proj.weight", "decoder.layers.19.self_attn.out_proj.bias", "decoder.layers.19.self_attn_layer_norm.weight", "decoder.layers.19.self_attn_layer_norm.bias", "decoder.layers.19.fc1.weight", "decoder.layers.19.fc1.bias", "decoder.layers.19.fc2.weight", "decoder.layers.19.fc2.bias", "decoder.layers.19.final_layer_norm.weight", "decoder.layers.19.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.20.flat_param_0": {"names": ["decoder.layers.20.self_attn.qkv_proj.weight", "decoder.layers.20.self_attn.qkv_proj.bias", "decoder.layers.20.self_attn.out_proj.weight", "decoder.layers.20.self_attn.out_proj.bias", "decoder.layers.20.self_attn_layer_norm.weight", "decoder.layers.20.self_attn_layer_norm.bias", "decoder.layers.20.fc1.weight", "decoder.layers.20.fc1.bias", "decoder.layers.20.fc2.weight", "decoder.layers.20.fc2.bias", "decoder.layers.20.final_layer_norm.weight", "decoder.layers.20.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.21.flat_param_0": {"names": ["decoder.layers.21.self_attn.qkv_proj.weight", "decoder.layers.21.self_attn.qkv_proj.bias", "decoder.layers.21.self_attn.out_proj.weight", "decoder.layers.21.self_attn.out_proj.bias", "decoder.layers.21.self_attn_layer_norm.weight", "decoder.layers.21.self_attn_layer_norm.bias", "decoder.layers.21.fc1.weight", "decoder.layers.21.fc1.bias", "decoder.layers.21.fc2.weight", "decoder.layers.21.fc2.bias", "decoder.layers.21.final_layer_norm.weight", "decoder.layers.21.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.22.flat_param_0": {"names": ["decoder.layers.22.self_attn.qkv_proj.weight", "decoder.layers.22.self_attn.qkv_proj.bias", "decoder.layers.22.self_attn.out_proj.weight", "decoder.layers.22.self_attn.out_proj.bias", "decoder.layers.22.self_attn_layer_norm.weight", "decoder.layers.22.self_attn_layer_norm.bias", "decoder.layers.22.fc1.weight", "decoder.layers.22.fc1.bias", "decoder.layers.22.fc2.weight", "decoder.layers.22.fc2.bias", "decoder.layers.22.final_layer_norm.weight", "decoder.layers.22.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.23.flat_param_0": {"names": ["decoder.layers.23.self_attn.qkv_proj.weight", "decoder.layers.23.self_attn.qkv_proj.bias", "decoder.layers.23.self_attn.out_proj.weight", "decoder.layers.23.self_attn.out_proj.bias", "decoder.layers.23.self_attn_layer_norm.weight", "decoder.layers.23.self_attn_layer_norm.bias", "decoder.layers.23.fc1.weight", "decoder.layers.23.fc1.bias", "decoder.layers.23.fc2.weight", "decoder.layers.23.fc2.bias", "decoder.layers.23.final_layer_norm.weight", "decoder.layers.23.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.24.flat_param_0": {"names": ["decoder.layers.24.self_attn.qkv_proj.weight", "decoder.layers.24.self_attn.qkv_proj.bias", "decoder.layers.24.self_attn.out_proj.weight", "decoder.layers.24.self_attn.out_proj.bias", "decoder.layers.24.self_attn_layer_norm.weight", "decoder.layers.24.self_attn_layer_norm.bias", "decoder.layers.24.fc1.weight", "decoder.layers.24.fc1.bias", "decoder.layers.24.fc2.weight", "decoder.layers.24.fc2.bias", "decoder.layers.24.final_layer_norm.weight", "decoder.layers.24.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.25.flat_param_0": {"names": ["decoder.layers.25.self_attn.qkv_proj.weight", "decoder.layers.25.self_attn.qkv_proj.bias", "decoder.layers.25.self_attn.out_proj.weight", "decoder.layers.25.self_attn.out_proj.bias", "decoder.layers.25.self_attn_layer_norm.weight", "decoder.layers.25.self_attn_layer_norm.bias", "decoder.layers.25.fc1.weight", "decoder.layers.25.fc1.bias", "decoder.layers.25.fc2.weight", "decoder.layers.25.fc2.bias", "decoder.layers.25.final_layer_norm.weight", "decoder.layers.25.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.26.flat_param_0": {"names": ["decoder.layers.26.self_attn.qkv_proj.weight", "decoder.layers.26.self_attn.qkv_proj.bias", "decoder.layers.26.self_attn.out_proj.weight", "decoder.layers.26.self_attn.out_proj.bias", "decoder.layers.26.self_attn_layer_norm.weight", "decoder.layers.26.self_attn_layer_norm.bias", "decoder.layers.26.fc1.weight", "decoder.layers.26.fc1.bias", "decoder.layers.26.fc2.weight", "decoder.layers.26.fc2.bias", "decoder.layers.26.final_layer_norm.weight", "decoder.layers.26.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.27.flat_param_0": {"names": ["decoder.layers.27.self_attn.qkv_proj.weight", "decoder.layers.27.self_attn.qkv_proj.bias", "decoder.layers.27.self_attn.out_proj.weight", "decoder.layers.27.self_attn.out_proj.bias", "decoder.layers.27.self_attn_layer_norm.weight", "decoder.layers.27.self_attn_layer_norm.bias", "decoder.layers.27.fc1.weight", "decoder.layers.27.fc1.bias", "decoder.layers.27.fc2.weight", "decoder.layers.27.fc2.bias", "decoder.layers.27.final_layer_norm.weight", "decoder.layers.27.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.28.flat_param_0": {"names": ["decoder.layers.28.self_attn.qkv_proj.weight", "decoder.layers.28.self_attn.qkv_proj.bias", "decoder.layers.28.self_attn.out_proj.weight", "decoder.layers.28.self_attn.out_proj.bias", "decoder.layers.28.self_attn_layer_norm.weight", "decoder.layers.28.self_attn_layer_norm.bias", "decoder.layers.28.fc1.weight", "decoder.layers.28.fc1.bias", "decoder.layers.28.fc2.weight", "decoder.layers.28.fc2.bias", "decoder.layers.28.final_layer_norm.weight", "decoder.layers.28.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.29.flat_param_0": {"names": ["decoder.layers.29.self_attn.qkv_proj.weight", "decoder.layers.29.self_attn.qkv_proj.bias", "decoder.layers.29.self_attn.out_proj.weight", "decoder.layers.29.self_attn.out_proj.bias", "decoder.layers.29.self_attn_layer_norm.weight", "decoder.layers.29.self_attn_layer_norm.bias", "decoder.layers.29.fc1.weight", "decoder.layers.29.fc1.bias", "decoder.layers.29.fc2.weight", "decoder.layers.29.fc2.bias", "decoder.layers.29.final_layer_norm.weight", "decoder.layers.29.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.30.flat_param_0": {"names": ["decoder.layers.30.self_attn.qkv_proj.weight", "decoder.layers.30.self_attn.qkv_proj.bias", "decoder.layers.30.self_attn.out_proj.weight", "decoder.layers.30.self_attn.out_proj.bias", "decoder.layers.30.self_attn_layer_norm.weight", "decoder.layers.30.self_attn_layer_norm.bias", "decoder.layers.30.fc1.weight", "decoder.layers.30.fc1.bias", "decoder.layers.30.fc2.weight", "decoder.layers.30.fc2.bias", "decoder.layers.30.final_layer_norm.weight", "decoder.layers.30.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.31.flat_param_0": {"names": ["decoder.layers.31.self_attn.qkv_proj.weight", "decoder.layers.31.self_attn.qkv_proj.bias", "decoder.layers.31.self_attn.out_proj.weight", "decoder.layers.31.self_attn.out_proj.bias", "decoder.layers.31.self_attn_layer_norm.weight", "decoder.layers.31.self_attn_layer_norm.bias", "decoder.layers.31.fc1.weight", "decoder.layers.31.fc1.bias", "decoder.layers.31.fc2.weight", "decoder.layers.31.fc2.bias", "decoder.layers.31.final_layer_norm.weight", "decoder.layers.31.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.32.flat_param_0": {"names": ["decoder.layers.32.self_attn.qkv_proj.weight", "decoder.layers.32.self_attn.qkv_proj.bias", "decoder.layers.32.self_attn.out_proj.weight", "decoder.layers.32.self_attn.out_proj.bias", "decoder.layers.32.self_attn_layer_norm.weight", "decoder.layers.32.self_attn_layer_norm.bias", "decoder.layers.32.fc1.weight", "decoder.layers.32.fc1.bias", "decoder.layers.32.fc2.weight", "decoder.layers.32.fc2.bias", "decoder.layers.32.final_layer_norm.weight", "decoder.layers.32.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.33.flat_param_0": {"names": ["decoder.layers.33.self_attn.qkv_proj.weight", "decoder.layers.33.self_attn.qkv_proj.bias", "decoder.layers.33.self_attn.out_proj.weight", "decoder.layers.33.self_attn.out_proj.bias", "decoder.layers.33.self_attn_layer_norm.weight", "decoder.layers.33.self_attn_layer_norm.bias", "decoder.layers.33.fc1.weight", "decoder.layers.33.fc1.bias", "decoder.layers.33.fc2.weight", "decoder.layers.33.fc2.bias", "decoder.layers.33.final_layer_norm.weight", "decoder.layers.33.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.34.flat_param_0": {"names": ["decoder.layers.34.self_attn.qkv_proj.weight", "decoder.layers.34.self_attn.qkv_proj.bias", "decoder.layers.34.self_attn.out_proj.weight", "decoder.layers.34.self_attn.out_proj.bias", "decoder.layers.34.self_attn_layer_norm.weight", "decoder.layers.34.self_attn_layer_norm.bias", "decoder.layers.34.fc1.weight", "decoder.layers.34.fc1.bias", "decoder.layers.34.fc2.weight", "decoder.layers.34.fc2.bias", "decoder.layers.34.final_layer_norm.weight", "decoder.layers.34.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.35.flat_param_0": {"names": ["decoder.layers.35.self_attn.qkv_proj.weight", "decoder.layers.35.self_attn.qkv_proj.bias", "decoder.layers.35.self_attn.out_proj.weight", "decoder.layers.35.self_attn.out_proj.bias", "decoder.layers.35.self_attn_layer_norm.weight", "decoder.layers.35.self_attn_layer_norm.bias", "decoder.layers.35.fc1.weight", "decoder.layers.35.fc1.bias", "decoder.layers.35.fc2.weight", "decoder.layers.35.fc2.bias", "decoder.layers.35.final_layer_norm.weight", "decoder.layers.35.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.36.flat_param_0": {"names": ["decoder.layers.36.self_attn.qkv_proj.weight", "decoder.layers.36.self_attn.qkv_proj.bias", "decoder.layers.36.self_attn.out_proj.weight", "decoder.layers.36.self_attn.out_proj.bias", "decoder.layers.36.self_attn_layer_norm.weight", "decoder.layers.36.self_attn_layer_norm.bias", "decoder.layers.36.fc1.weight", "decoder.layers.36.fc1.bias", "decoder.layers.36.fc2.weight", "decoder.layers.36.fc2.bias", "decoder.layers.36.final_layer_norm.weight", "decoder.layers.36.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.37.flat_param_0": {"names": ["decoder.layers.37.self_attn.qkv_proj.weight", "decoder.layers.37.self_attn.qkv_proj.bias", "decoder.layers.37.self_attn.out_proj.weight", "decoder.layers.37.self_attn.out_proj.bias", "decoder.layers.37.self_attn_layer_norm.weight", "decoder.layers.37.self_attn_layer_norm.bias", "decoder.layers.37.fc1.weight", "decoder.layers.37.fc1.bias", "decoder.layers.37.fc2.weight", "decoder.layers.37.fc2.bias", "decoder.layers.37.final_layer_norm.weight", "decoder.layers.37.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.38.flat_param_0": {"names": ["decoder.layers.38.self_attn.qkv_proj.weight", "decoder.layers.38.self_attn.qkv_proj.bias", "decoder.layers.38.self_attn.out_proj.weight", "decoder.layers.38.self_attn.out_proj.bias", "decoder.layers.38.self_attn_layer_norm.weight", "decoder.layers.38.self_attn_layer_norm.bias", "decoder.layers.38.fc1.weight", "decoder.layers.38.fc1.bias", "decoder.layers.38.fc2.weight", "decoder.layers.38.fc2.bias", "decoder.layers.38.final_layer_norm.weight", "decoder.layers.38.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.39.flat_param_0": {"names": ["decoder.layers.39.self_attn.qkv_proj.weight", "decoder.layers.39.self_attn.qkv_proj.bias", "decoder.layers.39.self_attn.out_proj.weight", "decoder.layers.39.self_attn.out_proj.bias", "decoder.layers.39.self_attn_layer_norm.weight", "decoder.layers.39.self_attn_layer_norm.bias", "decoder.layers.39.fc1.weight", "decoder.layers.39.fc1.bias", "decoder.layers.39.fc2.weight", "decoder.layers.39.fc2.bias", "decoder.layers.39.final_layer_norm.weight", "decoder.layers.39.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.40.flat_param_0": {"names": ["decoder.layers.40.self_attn.qkv_proj.weight", "decoder.layers.40.self_attn.qkv_proj.bias", "decoder.layers.40.self_attn.out_proj.weight", "decoder.layers.40.self_attn.out_proj.bias", "decoder.layers.40.self_attn_layer_norm.weight", "decoder.layers.40.self_attn_layer_norm.bias", "decoder.layers.40.fc1.weight", "decoder.layers.40.fc1.bias", "decoder.layers.40.fc2.weight", "decoder.layers.40.fc2.bias", "decoder.layers.40.final_layer_norm.weight", "decoder.layers.40.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.41.flat_param_0": {"names": ["decoder.layers.41.self_attn.qkv_proj.weight", "decoder.layers.41.self_attn.qkv_proj.bias", "decoder.layers.41.self_attn.out_proj.weight", "decoder.layers.41.self_attn.out_proj.bias", "decoder.layers.41.self_attn_layer_norm.weight", "decoder.layers.41.self_attn_layer_norm.bias", "decoder.layers.41.fc1.weight", "decoder.layers.41.fc1.bias", "decoder.layers.41.fc2.weight", "decoder.layers.41.fc2.bias", "decoder.layers.41.final_layer_norm.weight", "decoder.layers.41.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.42.flat_param_0": {"names": ["decoder.layers.42.self_attn.qkv_proj.weight", "decoder.layers.42.self_attn.qkv_proj.bias", "decoder.layers.42.self_attn.out_proj.weight", "decoder.layers.42.self_attn.out_proj.bias", "decoder.layers.42.self_attn_layer_norm.weight", "decoder.layers.42.self_attn_layer_norm.bias", "decoder.layers.42.fc1.weight", "decoder.layers.42.fc1.bias", "decoder.layers.42.fc2.weight", "decoder.layers.42.fc2.bias", "decoder.layers.42.final_layer_norm.weight", "decoder.layers.42.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.43.flat_param_0": {"names": ["decoder.layers.43.self_attn.qkv_proj.weight", "decoder.layers.43.self_attn.qkv_proj.bias", "decoder.layers.43.self_attn.out_proj.weight", "decoder.layers.43.self_attn.out_proj.bias", "decoder.layers.43.self_attn_layer_norm.weight", "decoder.layers.43.self_attn_layer_norm.bias", "decoder.layers.43.fc1.weight", "decoder.layers.43.fc1.bias", "decoder.layers.43.fc2.weight", "decoder.layers.43.fc2.bias", "decoder.layers.43.final_layer_norm.weight", "decoder.layers.43.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.44.flat_param_0": {"names": ["decoder.layers.44.self_attn.qkv_proj.weight", "decoder.layers.44.self_attn.qkv_proj.bias", "decoder.layers.44.self_attn.out_proj.weight", "decoder.layers.44.self_attn.out_proj.bias", "decoder.layers.44.self_attn_layer_norm.weight", "decoder.layers.44.self_attn_layer_norm.bias", "decoder.layers.44.fc1.weight", "decoder.layers.44.fc1.bias", "decoder.layers.44.fc2.weight", "decoder.layers.44.fc2.bias", "decoder.layers.44.final_layer_norm.weight", "decoder.layers.44.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.45.flat_param_0": {"names": ["decoder.layers.45.self_attn.qkv_proj.weight", "decoder.layers.45.self_attn.qkv_proj.bias", "decoder.layers.45.self_attn.out_proj.weight", "decoder.layers.45.self_attn.out_proj.bias", "decoder.layers.45.self_attn_layer_norm.weight", "decoder.layers.45.self_attn_layer_norm.bias", "decoder.layers.45.fc1.weight", "decoder.layers.45.fc1.bias", "decoder.layers.45.fc2.weight", "decoder.layers.45.fc2.bias", "decoder.layers.45.final_layer_norm.weight", "decoder.layers.45.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.46.flat_param_0": {"names": ["decoder.layers.46.self_attn.qkv_proj.weight", "decoder.layers.46.self_attn.qkv_proj.bias", "decoder.layers.46.self_attn.out_proj.weight", "decoder.layers.46.self_attn.out_proj.bias", "decoder.layers.46.self_attn_layer_norm.weight", "decoder.layers.46.self_attn_layer_norm.bias", "decoder.layers.46.fc1.weight", "decoder.layers.46.fc1.bias", "decoder.layers.46.fc2.weight", "decoder.layers.46.fc2.bias", "decoder.layers.46.final_layer_norm.weight", "decoder.layers.46.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.47.flat_param_0": {"names": ["decoder.layers.47.self_attn.qkv_proj.weight", "decoder.layers.47.self_attn.qkv_proj.bias", "decoder.layers.47.self_attn.out_proj.weight", "decoder.layers.47.self_attn.out_proj.bias", "decoder.layers.47.self_attn_layer_norm.weight", "decoder.layers.47.self_attn_layer_norm.bias", "decoder.layers.47.fc1.weight", "decoder.layers.47.fc1.bias", "decoder.layers.47.fc2.weight", "decoder.layers.47.fc2.bias", "decoder.layers.47.final_layer_norm.weight", "decoder.layers.47.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.48.flat_param_0": {"names": ["decoder.layers.48.self_attn.qkv_proj.weight", "decoder.layers.48.self_attn.qkv_proj.bias", "decoder.layers.48.self_attn.out_proj.weight", "decoder.layers.48.self_attn.out_proj.bias", "decoder.layers.48.self_attn_layer_norm.weight", "decoder.layers.48.self_attn_layer_norm.bias", "decoder.layers.48.fc1.weight", "decoder.layers.48.fc1.bias", "decoder.layers.48.fc2.weight", "decoder.layers.48.fc2.bias", "decoder.layers.48.final_layer_norm.weight", "decoder.layers.48.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.49.flat_param_0": {"names": ["decoder.layers.49.self_attn.qkv_proj.weight", "decoder.layers.49.self_attn.qkv_proj.bias", "decoder.layers.49.self_attn.out_proj.weight", "decoder.layers.49.self_attn.out_proj.bias", "decoder.layers.49.self_attn_layer_norm.weight", "decoder.layers.49.self_attn_layer_norm.bias", "decoder.layers.49.fc1.weight", "decoder.layers.49.fc1.bias", "decoder.layers.49.fc2.weight", "decoder.layers.49.fc2.bias", "decoder.layers.49.final_layer_norm.weight", "decoder.layers.49.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.50.flat_param_0": {"names": ["decoder.layers.50.self_attn.qkv_proj.weight", "decoder.layers.50.self_attn.qkv_proj.bias", "decoder.layers.50.self_attn.out_proj.weight", "decoder.layers.50.self_attn.out_proj.bias", "decoder.layers.50.self_attn_layer_norm.weight", "decoder.layers.50.self_attn_layer_norm.bias", "decoder.layers.50.fc1.weight", "decoder.layers.50.fc1.bias", "decoder.layers.50.fc2.weight", "decoder.layers.50.fc2.bias", "decoder.layers.50.final_layer_norm.weight", "decoder.layers.50.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.51.flat_param_0": {"names": ["decoder.layers.51.self_attn.qkv_proj.weight", "decoder.layers.51.self_attn.qkv_proj.bias", "decoder.layers.51.self_attn.out_proj.weight", "decoder.layers.51.self_attn.out_proj.bias", "decoder.layers.51.self_attn_layer_norm.weight", "decoder.layers.51.self_attn_layer_norm.bias", "decoder.layers.51.fc1.weight", "decoder.layers.51.fc1.bias", "decoder.layers.51.fc2.weight", "decoder.layers.51.fc2.bias", "decoder.layers.51.final_layer_norm.weight", "decoder.layers.51.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.52.flat_param_0": {"names": ["decoder.layers.52.self_attn.qkv_proj.weight", "decoder.layers.52.self_attn.qkv_proj.bias", "decoder.layers.52.self_attn.out_proj.weight", "decoder.layers.52.self_attn.out_proj.bias", "decoder.layers.52.self_attn_layer_norm.weight", "decoder.layers.52.self_attn_layer_norm.bias", "decoder.layers.52.fc1.weight", "decoder.layers.52.fc1.bias", "decoder.layers.52.fc2.weight", "decoder.layers.52.fc2.bias", "decoder.layers.52.final_layer_norm.weight", "decoder.layers.52.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.53.flat_param_0": {"names": ["decoder.layers.53.self_attn.qkv_proj.weight", "decoder.layers.53.self_attn.qkv_proj.bias", "decoder.layers.53.self_attn.out_proj.weight", "decoder.layers.53.self_attn.out_proj.bias", "decoder.layers.53.self_attn_layer_norm.weight", "decoder.layers.53.self_attn_layer_norm.bias", "decoder.layers.53.fc1.weight", "decoder.layers.53.fc1.bias", "decoder.layers.53.fc2.weight", "decoder.layers.53.fc2.bias", "decoder.layers.53.final_layer_norm.weight", "decoder.layers.53.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.54.flat_param_0": {"names": ["decoder.layers.54.self_attn.qkv_proj.weight", "decoder.layers.54.self_attn.qkv_proj.bias", "decoder.layers.54.self_attn.out_proj.weight", "decoder.layers.54.self_attn.out_proj.bias", "decoder.layers.54.self_attn_layer_norm.weight", "decoder.layers.54.self_attn_layer_norm.bias", "decoder.layers.54.fc1.weight", "decoder.layers.54.fc1.bias", "decoder.layers.54.fc2.weight", "decoder.layers.54.fc2.bias", "decoder.layers.54.final_layer_norm.weight", "decoder.layers.54.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.55.flat_param_0": {"names": ["decoder.layers.55.self_attn.qkv_proj.weight", "decoder.layers.55.self_attn.qkv_proj.bias", "decoder.layers.55.self_attn.out_proj.weight", "decoder.layers.55.self_attn.out_proj.bias", "decoder.layers.55.self_attn_layer_norm.weight", "decoder.layers.55.self_attn_layer_norm.bias", "decoder.layers.55.fc1.weight", "decoder.layers.55.fc1.bias", "decoder.layers.55.fc2.weight", "decoder.layers.55.fc2.bias", "decoder.layers.55.final_layer_norm.weight", "decoder.layers.55.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.56.flat_param_0": {"names": ["decoder.layers.56.self_attn.qkv_proj.weight", "decoder.layers.56.self_attn.qkv_proj.bias", "decoder.layers.56.self_attn.out_proj.weight", "decoder.layers.56.self_attn.out_proj.bias", "decoder.layers.56.self_attn_layer_norm.weight", "decoder.layers.56.self_attn_layer_norm.bias", "decoder.layers.56.fc1.weight", "decoder.layers.56.fc1.bias", "decoder.layers.56.fc2.weight", "decoder.layers.56.fc2.bias", "decoder.layers.56.final_layer_norm.weight", "decoder.layers.56.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.57.flat_param_0": {"names": ["decoder.layers.57.self_attn.qkv_proj.weight", "decoder.layers.57.self_attn.qkv_proj.bias", "decoder.layers.57.self_attn.out_proj.weight", "decoder.layers.57.self_attn.out_proj.bias", "decoder.layers.57.self_attn_layer_norm.weight", "decoder.layers.57.self_attn_layer_norm.bias", "decoder.layers.57.fc1.weight", "decoder.layers.57.fc1.bias", "decoder.layers.57.fc2.weight", "decoder.layers.57.fc2.bias", "decoder.layers.57.final_layer_norm.weight", "decoder.layers.57.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.58.flat_param_0": {"names": ["decoder.layers.58.self_attn.qkv_proj.weight", "decoder.layers.58.self_attn.qkv_proj.bias", "decoder.layers.58.self_attn.out_proj.weight", "decoder.layers.58.self_attn.out_proj.bias", "decoder.layers.58.self_attn_layer_norm.weight", "decoder.layers.58.self_attn_layer_norm.bias", "decoder.layers.58.fc1.weight", "decoder.layers.58.fc1.bias", "decoder.layers.58.fc2.weight", "decoder.layers.58.fc2.bias", "decoder.layers.58.final_layer_norm.weight", "decoder.layers.58.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.59.flat_param_0": {"names": ["decoder.layers.59.self_attn.qkv_proj.weight", "decoder.layers.59.self_attn.qkv_proj.bias", "decoder.layers.59.self_attn.out_proj.weight", "decoder.layers.59.self_attn.out_proj.bias", "decoder.layers.59.self_attn_layer_norm.weight", "decoder.layers.59.self_attn_layer_norm.bias", "decoder.layers.59.fc1.weight", "decoder.layers.59.fc1.bias", "decoder.layers.59.fc2.weight", "decoder.layers.59.fc2.bias", "decoder.layers.59.final_layer_norm.weight", "decoder.layers.59.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.60.flat_param_0": {"names": ["decoder.layers.60.self_attn.qkv_proj.weight", "decoder.layers.60.self_attn.qkv_proj.bias", "decoder.layers.60.self_attn.out_proj.weight", "decoder.layers.60.self_attn.out_proj.bias", "decoder.layers.60.self_attn_layer_norm.weight", "decoder.layers.60.self_attn_layer_norm.bias", "decoder.layers.60.fc1.weight", "decoder.layers.60.fc1.bias", "decoder.layers.60.fc2.weight", "decoder.layers.60.fc2.bias", "decoder.layers.60.final_layer_norm.weight", "decoder.layers.60.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.61.flat_param_0": {"names": ["decoder.layers.61.self_attn.qkv_proj.weight", "decoder.layers.61.self_attn.qkv_proj.bias", "decoder.layers.61.self_attn.out_proj.weight", "decoder.layers.61.self_attn.out_proj.bias", "decoder.layers.61.self_attn_layer_norm.weight", "decoder.layers.61.self_attn_layer_norm.bias", "decoder.layers.61.fc1.weight", "decoder.layers.61.fc1.bias", "decoder.layers.61.fc2.weight", "decoder.layers.61.fc2.bias", "decoder.layers.61.final_layer_norm.weight", "decoder.layers.61.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.62.flat_param_0": {"names": ["decoder.layers.62.self_attn.qkv_proj.weight", "decoder.layers.62.self_attn.qkv_proj.bias", "decoder.layers.62.self_attn.out_proj.weight", "decoder.layers.62.self_attn.out_proj.bias", "decoder.layers.62.self_attn_layer_norm.weight", "decoder.layers.62.self_attn_layer_norm.bias", "decoder.layers.62.fc1.weight", "decoder.layers.62.fc1.bias", "decoder.layers.62.fc2.weight", "decoder.layers.62.fc2.bias", "decoder.layers.62.final_layer_norm.weight", "decoder.layers.62.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.63.flat_param_0": {"names": ["decoder.layers.63.self_attn.qkv_proj.weight", "decoder.layers.63.self_attn.qkv_proj.bias", "decoder.layers.63.self_attn.out_proj.weight", "decoder.layers.63.self_attn.out_proj.bias", "decoder.layers.63.self_attn_layer_norm.weight", "decoder.layers.63.self_attn_layer_norm.bias", "decoder.layers.63.fc1.weight", "decoder.layers.63.fc1.bias", "decoder.layers.63.fc2.weight", "decoder.layers.63.fc2.bias", "decoder.layers.63.final_layer_norm.weight", "decoder.layers.63.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.64.flat_param_0": {"names": ["decoder.layers.64.self_attn.qkv_proj.weight", "decoder.layers.64.self_attn.qkv_proj.bias", "decoder.layers.64.self_attn.out_proj.weight", "decoder.layers.64.self_attn.out_proj.bias", "decoder.layers.64.self_attn_layer_norm.weight", "decoder.layers.64.self_attn_layer_norm.bias", "decoder.layers.64.fc1.weight", "decoder.layers.64.fc1.bias", "decoder.layers.64.fc2.weight", "decoder.layers.64.fc2.bias", "decoder.layers.64.final_layer_norm.weight", "decoder.layers.64.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.65.flat_param_0": {"names": ["decoder.layers.65.self_attn.qkv_proj.weight", "decoder.layers.65.self_attn.qkv_proj.bias", "decoder.layers.65.self_attn.out_proj.weight", "decoder.layers.65.self_attn.out_proj.bias", "decoder.layers.65.self_attn_layer_norm.weight", "decoder.layers.65.self_attn_layer_norm.bias", "decoder.layers.65.fc1.weight", "decoder.layers.65.fc1.bias", "decoder.layers.65.fc2.weight", "decoder.layers.65.fc2.bias", "decoder.layers.65.final_layer_norm.weight", "decoder.layers.65.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.66.flat_param_0": {"names": ["decoder.layers.66.self_attn.qkv_proj.weight", "decoder.layers.66.self_attn.qkv_proj.bias", "decoder.layers.66.self_attn.out_proj.weight", "decoder.layers.66.self_attn.out_proj.bias", "decoder.layers.66.self_attn_layer_norm.weight", "decoder.layers.66.self_attn_layer_norm.bias", "decoder.layers.66.fc1.weight", "decoder.layers.66.fc1.bias", "decoder.layers.66.fc2.weight", "decoder.layers.66.fc2.bias", "decoder.layers.66.final_layer_norm.weight", "decoder.layers.66.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.67.flat_param_0": {"names": ["decoder.layers.67.self_attn.qkv_proj.weight", "decoder.layers.67.self_attn.qkv_proj.bias", "decoder.layers.67.self_attn.out_proj.weight", "decoder.layers.67.self_attn.out_proj.bias", "decoder.layers.67.self_attn_layer_norm.weight", "decoder.layers.67.self_attn_layer_norm.bias", "decoder.layers.67.fc1.weight", "decoder.layers.67.fc1.bias", "decoder.layers.67.fc2.weight", "decoder.layers.67.fc2.bias", "decoder.layers.67.final_layer_norm.weight", "decoder.layers.67.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.68.flat_param_0": {"names": ["decoder.layers.68.self_attn.qkv_proj.weight", "decoder.layers.68.self_attn.qkv_proj.bias", "decoder.layers.68.self_attn.out_proj.weight", "decoder.layers.68.self_attn.out_proj.bias", "decoder.layers.68.self_attn_layer_norm.weight", "decoder.layers.68.self_attn_layer_norm.bias", "decoder.layers.68.fc1.weight", "decoder.layers.68.fc1.bias", "decoder.layers.68.fc2.weight", "decoder.layers.68.fc2.bias", "decoder.layers.68.final_layer_norm.weight", "decoder.layers.68.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.69.flat_param_0": {"names": ["decoder.layers.69.self_attn.qkv_proj.weight", "decoder.layers.69.self_attn.qkv_proj.bias", "decoder.layers.69.self_attn.out_proj.weight", "decoder.layers.69.self_attn.out_proj.bias", "decoder.layers.69.self_attn_layer_norm.weight", "decoder.layers.69.self_attn_layer_norm.bias", "decoder.layers.69.fc1.weight", "decoder.layers.69.fc1.bias", "decoder.layers.69.fc2.weight", "decoder.layers.69.fc2.bias", "decoder.layers.69.final_layer_norm.weight", "decoder.layers.69.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.70.flat_param_0": {"names": ["decoder.layers.70.self_attn.qkv_proj.weight", "decoder.layers.70.self_attn.qkv_proj.bias", "decoder.layers.70.self_attn.out_proj.weight", "decoder.layers.70.self_attn.out_proj.bias", "decoder.layers.70.self_attn_layer_norm.weight", "decoder.layers.70.self_attn_layer_norm.bias", "decoder.layers.70.fc1.weight", "decoder.layers.70.fc1.bias", "decoder.layers.70.fc2.weight", "decoder.layers.70.fc2.bias", "decoder.layers.70.final_layer_norm.weight", "decoder.layers.70.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.71.flat_param_0": {"names": ["decoder.layers.71.self_attn.qkv_proj.weight", "decoder.layers.71.self_attn.qkv_proj.bias", "decoder.layers.71.self_attn.out_proj.weight", "decoder.layers.71.self_attn.out_proj.bias", "decoder.layers.71.self_attn_layer_norm.weight", "decoder.layers.71.self_attn_layer_norm.bias", "decoder.layers.71.fc1.weight", "decoder.layers.71.fc1.bias", "decoder.layers.71.fc2.weight", "decoder.layers.71.fc2.bias", "decoder.layers.71.final_layer_norm.weight", "decoder.layers.71.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.72.flat_param_0": {"names": ["decoder.layers.72.self_attn.qkv_proj.weight", "decoder.layers.72.self_attn.qkv_proj.bias", "decoder.layers.72.self_attn.out_proj.weight", "decoder.layers.72.self_attn.out_proj.bias", "decoder.layers.72.self_attn_layer_norm.weight", "decoder.layers.72.self_attn_layer_norm.bias", "decoder.layers.72.fc1.weight", "decoder.layers.72.fc1.bias", "decoder.layers.72.fc2.weight", "decoder.layers.72.fc2.bias", "decoder.layers.72.final_layer_norm.weight", "decoder.layers.72.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.73.flat_param_0": {"names": ["decoder.layers.73.self_attn.qkv_proj.weight", "decoder.layers.73.self_attn.qkv_proj.bias", "decoder.layers.73.self_attn.out_proj.weight", "decoder.layers.73.self_attn.out_proj.bias", "decoder.layers.73.self_attn_layer_norm.weight", "decoder.layers.73.self_attn_layer_norm.bias", "decoder.layers.73.fc1.weight", "decoder.layers.73.fc1.bias", "decoder.layers.73.fc2.weight", "decoder.layers.73.fc2.bias", "decoder.layers.73.final_layer_norm.weight", "decoder.layers.73.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.74.flat_param_0": {"names": ["decoder.layers.74.self_attn.qkv_proj.weight", "decoder.layers.74.self_attn.qkv_proj.bias", "decoder.layers.74.self_attn.out_proj.weight", "decoder.layers.74.self_attn.out_proj.bias", "decoder.layers.74.self_attn_layer_norm.weight", "decoder.layers.74.self_attn_layer_norm.bias", "decoder.layers.74.fc1.weight", "decoder.layers.74.fc1.bias", "decoder.layers.74.fc2.weight", "decoder.layers.74.fc2.bias", "decoder.layers.74.final_layer_norm.weight", "decoder.layers.74.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.75.flat_param_0": {"names": ["decoder.layers.75.self_attn.qkv_proj.weight", "decoder.layers.75.self_attn.qkv_proj.bias", "decoder.layers.75.self_attn.out_proj.weight", "decoder.layers.75.self_attn.out_proj.bias", "decoder.layers.75.self_attn_layer_norm.weight", "decoder.layers.75.self_attn_layer_norm.bias", "decoder.layers.75.fc1.weight", "decoder.layers.75.fc1.bias", "decoder.layers.75.fc2.weight", "decoder.layers.75.fc2.bias", "decoder.layers.75.final_layer_norm.weight", "decoder.layers.75.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.76.flat_param_0": {"names": ["decoder.layers.76.self_attn.qkv_proj.weight", "decoder.layers.76.self_attn.qkv_proj.bias", "decoder.layers.76.self_attn.out_proj.weight", "decoder.layers.76.self_attn.out_proj.bias", "decoder.layers.76.self_attn_layer_norm.weight", "decoder.layers.76.self_attn_layer_norm.bias", "decoder.layers.76.fc1.weight", "decoder.layers.76.fc1.bias", "decoder.layers.76.fc2.weight", "decoder.layers.76.fc2.bias", "decoder.layers.76.final_layer_norm.weight", "decoder.layers.76.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.77.flat_param_0": {"names": ["decoder.layers.77.self_attn.qkv_proj.weight", "decoder.layers.77.self_attn.qkv_proj.bias", "decoder.layers.77.self_attn.out_proj.weight", "decoder.layers.77.self_attn.out_proj.bias", "decoder.layers.77.self_attn_layer_norm.weight", "decoder.layers.77.self_attn_layer_norm.bias", "decoder.layers.77.fc1.weight", "decoder.layers.77.fc1.bias", "decoder.layers.77.fc2.weight", "decoder.layers.77.fc2.bias", "decoder.layers.77.final_layer_norm.weight", "decoder.layers.77.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.78.flat_param_0": {"names": ["decoder.layers.78.self_attn.qkv_proj.weight", "decoder.layers.78.self_attn.qkv_proj.bias", "decoder.layers.78.self_attn.out_proj.weight", "decoder.layers.78.self_attn.out_proj.bias", "decoder.layers.78.self_attn_layer_norm.weight", "decoder.layers.78.self_attn_layer_norm.bias", "decoder.layers.78.fc1.weight", "decoder.layers.78.fc1.bias", "decoder.layers.78.fc2.weight", "decoder.layers.78.fc2.bias", "decoder.layers.78.final_layer_norm.weight", "decoder.layers.78.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.79.flat_param_0": {"names": ["decoder.layers.79.self_attn.qkv_proj.weight", "decoder.layers.79.self_attn.qkv_proj.bias", "decoder.layers.79.self_attn.out_proj.weight", "decoder.layers.79.self_attn.out_proj.bias", "decoder.layers.79.self_attn_layer_norm.weight", "decoder.layers.79.self_attn_layer_norm.bias", "decoder.layers.79.fc1.weight", "decoder.layers.79.fc1.bias", "decoder.layers.79.fc2.weight", "decoder.layers.79.fc2.bias", "decoder.layers.79.final_layer_norm.weight", "decoder.layers.79.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.80.flat_param_0": {"names": ["decoder.layers.80.self_attn.qkv_proj.weight", "decoder.layers.80.self_attn.qkv_proj.bias", "decoder.layers.80.self_attn.out_proj.weight", "decoder.layers.80.self_attn.out_proj.bias", "decoder.layers.80.self_attn_layer_norm.weight", "decoder.layers.80.self_attn_layer_norm.bias", "decoder.layers.80.fc1.weight", "decoder.layers.80.fc1.bias", "decoder.layers.80.fc2.weight", "decoder.layers.80.fc2.bias", "decoder.layers.80.final_layer_norm.weight", "decoder.layers.80.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.81.flat_param_0": {"names": ["decoder.layers.81.self_attn.qkv_proj.weight", "decoder.layers.81.self_attn.qkv_proj.bias", "decoder.layers.81.self_attn.out_proj.weight", "decoder.layers.81.self_attn.out_proj.bias", "decoder.layers.81.self_attn_layer_norm.weight", "decoder.layers.81.self_attn_layer_norm.bias", "decoder.layers.81.fc1.weight", "decoder.layers.81.fc1.bias", "decoder.layers.81.fc2.weight", "decoder.layers.81.fc2.bias", "decoder.layers.81.final_layer_norm.weight", "decoder.layers.81.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.82.flat_param_0": {"names": ["decoder.layers.82.self_attn.qkv_proj.weight", "decoder.layers.82.self_attn.qkv_proj.bias", "decoder.layers.82.self_attn.out_proj.weight", "decoder.layers.82.self_attn.out_proj.bias", "decoder.layers.82.self_attn_layer_norm.weight", "decoder.layers.82.self_attn_layer_norm.bias", "decoder.layers.82.fc1.weight", "decoder.layers.82.fc1.bias", "decoder.layers.82.fc2.weight", "decoder.layers.82.fc2.bias", "decoder.layers.82.final_layer_norm.weight", "decoder.layers.82.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.83.flat_param_0": {"names": ["decoder.layers.83.self_attn.qkv_proj.weight", "decoder.layers.83.self_attn.qkv_proj.bias", "decoder.layers.83.self_attn.out_proj.weight", "decoder.layers.83.self_attn.out_proj.bias", "decoder.layers.83.self_attn_layer_norm.weight", "decoder.layers.83.self_attn_layer_norm.bias", "decoder.layers.83.fc1.weight", "decoder.layers.83.fc1.bias", "decoder.layers.83.fc2.weight", "decoder.layers.83.fc2.bias", "decoder.layers.83.final_layer_norm.weight", "decoder.layers.83.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.84.flat_param_0": {"names": ["decoder.layers.84.self_attn.qkv_proj.weight", "decoder.layers.84.self_attn.qkv_proj.bias", "decoder.layers.84.self_attn.out_proj.weight", "decoder.layers.84.self_attn.out_proj.bias", "decoder.layers.84.self_attn_layer_norm.weight", "decoder.layers.84.self_attn_layer_norm.bias", "decoder.layers.84.fc1.weight", "decoder.layers.84.fc1.bias", "decoder.layers.84.fc2.weight", "decoder.layers.84.fc2.bias", "decoder.layers.84.final_layer_norm.weight", "decoder.layers.84.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.85.flat_param_0": {"names": ["decoder.layers.85.self_attn.qkv_proj.weight", "decoder.layers.85.self_attn.qkv_proj.bias", "decoder.layers.85.self_attn.out_proj.weight", "decoder.layers.85.self_attn.out_proj.bias", "decoder.layers.85.self_attn_layer_norm.weight", "decoder.layers.85.self_attn_layer_norm.bias", "decoder.layers.85.fc1.weight", "decoder.layers.85.fc1.bias", "decoder.layers.85.fc2.weight", "decoder.layers.85.fc2.bias", "decoder.layers.85.final_layer_norm.weight", "decoder.layers.85.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.86.flat_param_0": {"names": ["decoder.layers.86.self_attn.qkv_proj.weight", "decoder.layers.86.self_attn.qkv_proj.bias", "decoder.layers.86.self_attn.out_proj.weight", "decoder.layers.86.self_attn.out_proj.bias", "decoder.layers.86.self_attn_layer_norm.weight", "decoder.layers.86.self_attn_layer_norm.bias", "decoder.layers.86.fc1.weight", "decoder.layers.86.fc1.bias", "decoder.layers.86.fc2.weight", "decoder.layers.86.fc2.bias", "decoder.layers.86.final_layer_norm.weight", "decoder.layers.86.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.87.flat_param_0": {"names": ["decoder.layers.87.self_attn.qkv_proj.weight", "decoder.layers.87.self_attn.qkv_proj.bias", "decoder.layers.87.self_attn.out_proj.weight", "decoder.layers.87.self_attn.out_proj.bias", "decoder.layers.87.self_attn_layer_norm.weight", "decoder.layers.87.self_attn_layer_norm.bias", "decoder.layers.87.fc1.weight", "decoder.layers.87.fc1.bias", "decoder.layers.87.fc2.weight", "decoder.layers.87.fc2.bias", "decoder.layers.87.final_layer_norm.weight", "decoder.layers.87.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.88.flat_param_0": {"names": ["decoder.layers.88.self_attn.qkv_proj.weight", "decoder.layers.88.self_attn.qkv_proj.bias", "decoder.layers.88.self_attn.out_proj.weight", "decoder.layers.88.self_attn.out_proj.bias", "decoder.layers.88.self_attn_layer_norm.weight", "decoder.layers.88.self_attn_layer_norm.bias", "decoder.layers.88.fc1.weight", "decoder.layers.88.fc1.bias", "decoder.layers.88.fc2.weight", "decoder.layers.88.fc2.bias", "decoder.layers.88.final_layer_norm.weight", "decoder.layers.88.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.89.flat_param_0": {"names": ["decoder.layers.89.self_attn.qkv_proj.weight", "decoder.layers.89.self_attn.qkv_proj.bias", "decoder.layers.89.self_attn.out_proj.weight", "decoder.layers.89.self_attn.out_proj.bias", "decoder.layers.89.self_attn_layer_norm.weight", "decoder.layers.89.self_attn_layer_norm.bias", "decoder.layers.89.fc1.weight", "decoder.layers.89.fc1.bias", "decoder.layers.89.fc2.weight", "decoder.layers.89.fc2.bias", "decoder.layers.89.final_layer_norm.weight", "decoder.layers.89.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.90.flat_param_0": {"names": ["decoder.layers.90.self_attn.qkv_proj.weight", "decoder.layers.90.self_attn.qkv_proj.bias", "decoder.layers.90.self_attn.out_proj.weight", "decoder.layers.90.self_attn.out_proj.bias", "decoder.layers.90.self_attn_layer_norm.weight", "decoder.layers.90.self_attn_layer_norm.bias", "decoder.layers.90.fc1.weight", "decoder.layers.90.fc1.bias", "decoder.layers.90.fc2.weight", "decoder.layers.90.fc2.bias", "decoder.layers.90.final_layer_norm.weight", "decoder.layers.90.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.91.flat_param_0": {"names": ["decoder.layers.91.self_attn.qkv_proj.weight", "decoder.layers.91.self_attn.qkv_proj.bias", "decoder.layers.91.self_attn.out_proj.weight", "decoder.layers.91.self_attn.out_proj.bias", "decoder.layers.91.self_attn_layer_norm.weight", "decoder.layers.91.self_attn_layer_norm.bias", "decoder.layers.91.fc1.weight", "decoder.layers.91.fc1.bias", "decoder.layers.91.fc2.weight", "decoder.layers.91.fc2.bias", "decoder.layers.91.final_layer_norm.weight", "decoder.layers.91.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.92.flat_param_0": {"names": ["decoder.layers.92.self_attn.qkv_proj.weight", "decoder.layers.92.self_attn.qkv_proj.bias", "decoder.layers.92.self_attn.out_proj.weight", "decoder.layers.92.self_attn.out_proj.bias", "decoder.layers.92.self_attn_layer_norm.weight", "decoder.layers.92.self_attn_layer_norm.bias", "decoder.layers.92.fc1.weight", "decoder.layers.92.fc1.bias", "decoder.layers.92.fc2.weight", "decoder.layers.92.fc2.bias", "decoder.layers.92.final_layer_norm.weight", "decoder.layers.92.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.93.flat_param_0": {"names": ["decoder.layers.93.self_attn.qkv_proj.weight", "decoder.layers.93.self_attn.qkv_proj.bias", "decoder.layers.93.self_attn.out_proj.weight", "decoder.layers.93.self_attn.out_proj.bias", "decoder.layers.93.self_attn_layer_norm.weight", "decoder.layers.93.self_attn_layer_norm.bias", "decoder.layers.93.fc1.weight", "decoder.layers.93.fc1.bias", "decoder.layers.93.fc2.weight", "decoder.layers.93.fc2.bias", "decoder.layers.93.final_layer_norm.weight", "decoder.layers.93.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.94.flat_param_0": {"names": ["decoder.layers.94.self_attn.qkv_proj.weight", "decoder.layers.94.self_attn.qkv_proj.bias", "decoder.layers.94.self_attn.out_proj.weight", "decoder.layers.94.self_attn.out_proj.bias", "decoder.layers.94.self_attn_layer_norm.weight", "decoder.layers.94.self_attn_layer_norm.bias", "decoder.layers.94.fc1.weight", "decoder.layers.94.fc1.bias", "decoder.layers.94.fc2.weight", "decoder.layers.94.fc2.bias", "decoder.layers.94.final_layer_norm.weight", "decoder.layers.94.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.95.flat_param_0": {"names": ["decoder.layers.95.self_attn.qkv_proj.weight", "decoder.layers.95.self_attn.qkv_proj.bias", "decoder.layers.95.self_attn.out_proj.weight", "decoder.layers.95.self_attn.out_proj.bias", "decoder.layers.95.self_attn_layer_norm.weight", "decoder.layers.95.self_attn_layer_norm.bias", "decoder.layers.95.fc1.weight", "decoder.layers.95.fc1.bias", "decoder.layers.95.fc2.weight", "decoder.layers.95.fc2.bias", "decoder.layers.95.final_layer_norm.weight", "decoder.layers.95.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}} \ No newline at end of file +{ + "flat_param_0": { + "names": [ + "decoder.embed_tokens.weight", + "decoder.embed_positions.weight", + "decoder.layer_norm.weight", + "decoder.layer_norm.bias" + ], + "shapes": [ + [ + 6284, + 12288 + ], + [ + 2050, + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 77217792, + 25190400, + 12288, + 12288 + ] + }, + "decoder.layers.0.flat_param_0": { + "names": [ + "decoder.layers.0.self_attn.qkv_proj.weight", + "decoder.layers.0.self_attn.qkv_proj.bias", + "decoder.layers.0.self_attn.out_proj.weight", + "decoder.layers.0.self_attn.out_proj.bias", + "decoder.layers.0.self_attn_layer_norm.weight", + "decoder.layers.0.self_attn_layer_norm.bias", + "decoder.layers.0.fc1.weight", + "decoder.layers.0.fc1.bias", + "decoder.layers.0.fc2.weight", + "decoder.layers.0.fc2.bias", + "decoder.layers.0.final_layer_norm.weight", + "decoder.layers.0.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.1.flat_param_0": { + "names": [ + "decoder.layers.1.self_attn.qkv_proj.weight", + "decoder.layers.1.self_attn.qkv_proj.bias", + "decoder.layers.1.self_attn.out_proj.weight", + "decoder.layers.1.self_attn.out_proj.bias", + "decoder.layers.1.self_attn_layer_norm.weight", + "decoder.layers.1.self_attn_layer_norm.bias", + "decoder.layers.1.fc1.weight", + "decoder.layers.1.fc1.bias", + "decoder.layers.1.fc2.weight", + "decoder.layers.1.fc2.bias", + "decoder.layers.1.final_layer_norm.weight", + "decoder.layers.1.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.2.flat_param_0": { + "names": [ + "decoder.layers.2.self_attn.qkv_proj.weight", + "decoder.layers.2.self_attn.qkv_proj.bias", + "decoder.layers.2.self_attn.out_proj.weight", + "decoder.layers.2.self_attn.out_proj.bias", + "decoder.layers.2.self_attn_layer_norm.weight", + "decoder.layers.2.self_attn_layer_norm.bias", + "decoder.layers.2.fc1.weight", + "decoder.layers.2.fc1.bias", + "decoder.layers.2.fc2.weight", + "decoder.layers.2.fc2.bias", + "decoder.layers.2.final_layer_norm.weight", + "decoder.layers.2.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.3.flat_param_0": { + "names": [ + "decoder.layers.3.self_attn.qkv_proj.weight", + "decoder.layers.3.self_attn.qkv_proj.bias", + "decoder.layers.3.self_attn.out_proj.weight", + "decoder.layers.3.self_attn.out_proj.bias", + "decoder.layers.3.self_attn_layer_norm.weight", + "decoder.layers.3.self_attn_layer_norm.bias", + "decoder.layers.3.fc1.weight", + "decoder.layers.3.fc1.bias", + "decoder.layers.3.fc2.weight", + "decoder.layers.3.fc2.bias", + "decoder.layers.3.final_layer_norm.weight", + "decoder.layers.3.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.4.flat_param_0": { + "names": [ + "decoder.layers.4.self_attn.qkv_proj.weight", + "decoder.layers.4.self_attn.qkv_proj.bias", + "decoder.layers.4.self_attn.out_proj.weight", + "decoder.layers.4.self_attn.out_proj.bias", + "decoder.layers.4.self_attn_layer_norm.weight", + "decoder.layers.4.self_attn_layer_norm.bias", + "decoder.layers.4.fc1.weight", + "decoder.layers.4.fc1.bias", + "decoder.layers.4.fc2.weight", + "decoder.layers.4.fc2.bias", + "decoder.layers.4.final_layer_norm.weight", + "decoder.layers.4.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.5.flat_param_0": { + "names": [ + "decoder.layers.5.self_attn.qkv_proj.weight", + "decoder.layers.5.self_attn.qkv_proj.bias", + "decoder.layers.5.self_attn.out_proj.weight", + "decoder.layers.5.self_attn.out_proj.bias", + "decoder.layers.5.self_attn_layer_norm.weight", + "decoder.layers.5.self_attn_layer_norm.bias", + "decoder.layers.5.fc1.weight", + "decoder.layers.5.fc1.bias", + "decoder.layers.5.fc2.weight", + "decoder.layers.5.fc2.bias", + "decoder.layers.5.final_layer_norm.weight", + "decoder.layers.5.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.6.flat_param_0": { + "names": [ + "decoder.layers.6.self_attn.qkv_proj.weight", + "decoder.layers.6.self_attn.qkv_proj.bias", + "decoder.layers.6.self_attn.out_proj.weight", + "decoder.layers.6.self_attn.out_proj.bias", + "decoder.layers.6.self_attn_layer_norm.weight", + "decoder.layers.6.self_attn_layer_norm.bias", + "decoder.layers.6.fc1.weight", + "decoder.layers.6.fc1.bias", + "decoder.layers.6.fc2.weight", + "decoder.layers.6.fc2.bias", + "decoder.layers.6.final_layer_norm.weight", + "decoder.layers.6.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.7.flat_param_0": { + "names": [ + "decoder.layers.7.self_attn.qkv_proj.weight", + "decoder.layers.7.self_attn.qkv_proj.bias", + "decoder.layers.7.self_attn.out_proj.weight", + "decoder.layers.7.self_attn.out_proj.bias", + "decoder.layers.7.self_attn_layer_norm.weight", + "decoder.layers.7.self_attn_layer_norm.bias", + "decoder.layers.7.fc1.weight", + "decoder.layers.7.fc1.bias", + "decoder.layers.7.fc2.weight", + "decoder.layers.7.fc2.bias", + "decoder.layers.7.final_layer_norm.weight", + "decoder.layers.7.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.8.flat_param_0": { + "names": [ + "decoder.layers.8.self_attn.qkv_proj.weight", + "decoder.layers.8.self_attn.qkv_proj.bias", + "decoder.layers.8.self_attn.out_proj.weight", + "decoder.layers.8.self_attn.out_proj.bias", + "decoder.layers.8.self_attn_layer_norm.weight", + "decoder.layers.8.self_attn_layer_norm.bias", + "decoder.layers.8.fc1.weight", + "decoder.layers.8.fc1.bias", + "decoder.layers.8.fc2.weight", + "decoder.layers.8.fc2.bias", + "decoder.layers.8.final_layer_norm.weight", + "decoder.layers.8.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.9.flat_param_0": { + "names": [ + "decoder.layers.9.self_attn.qkv_proj.weight", + "decoder.layers.9.self_attn.qkv_proj.bias", + "decoder.layers.9.self_attn.out_proj.weight", + "decoder.layers.9.self_attn.out_proj.bias", + "decoder.layers.9.self_attn_layer_norm.weight", + "decoder.layers.9.self_attn_layer_norm.bias", + "decoder.layers.9.fc1.weight", + "decoder.layers.9.fc1.bias", + "decoder.layers.9.fc2.weight", + "decoder.layers.9.fc2.bias", + "decoder.layers.9.final_layer_norm.weight", + "decoder.layers.9.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.10.flat_param_0": { + "names": [ + "decoder.layers.10.self_attn.qkv_proj.weight", + "decoder.layers.10.self_attn.qkv_proj.bias", + "decoder.layers.10.self_attn.out_proj.weight", + "decoder.layers.10.self_attn.out_proj.bias", + "decoder.layers.10.self_attn_layer_norm.weight", + "decoder.layers.10.self_attn_layer_norm.bias", + "decoder.layers.10.fc1.weight", + "decoder.layers.10.fc1.bias", + "decoder.layers.10.fc2.weight", + "decoder.layers.10.fc2.bias", + "decoder.layers.10.final_layer_norm.weight", + "decoder.layers.10.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.11.flat_param_0": { + "names": [ + "decoder.layers.11.self_attn.qkv_proj.weight", + "decoder.layers.11.self_attn.qkv_proj.bias", + "decoder.layers.11.self_attn.out_proj.weight", + "decoder.layers.11.self_attn.out_proj.bias", + "decoder.layers.11.self_attn_layer_norm.weight", + "decoder.layers.11.self_attn_layer_norm.bias", + "decoder.layers.11.fc1.weight", + "decoder.layers.11.fc1.bias", + "decoder.layers.11.fc2.weight", + "decoder.layers.11.fc2.bias", + "decoder.layers.11.final_layer_norm.weight", + "decoder.layers.11.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.12.flat_param_0": { + "names": [ + "decoder.layers.12.self_attn.qkv_proj.weight", + "decoder.layers.12.self_attn.qkv_proj.bias", + "decoder.layers.12.self_attn.out_proj.weight", + "decoder.layers.12.self_attn.out_proj.bias", + "decoder.layers.12.self_attn_layer_norm.weight", + "decoder.layers.12.self_attn_layer_norm.bias", + "decoder.layers.12.fc1.weight", + "decoder.layers.12.fc1.bias", + "decoder.layers.12.fc2.weight", + "decoder.layers.12.fc2.bias", + "decoder.layers.12.final_layer_norm.weight", + "decoder.layers.12.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.13.flat_param_0": { + "names": [ + "decoder.layers.13.self_attn.qkv_proj.weight", + "decoder.layers.13.self_attn.qkv_proj.bias", + "decoder.layers.13.self_attn.out_proj.weight", + "decoder.layers.13.self_attn.out_proj.bias", + "decoder.layers.13.self_attn_layer_norm.weight", + "decoder.layers.13.self_attn_layer_norm.bias", + "decoder.layers.13.fc1.weight", + "decoder.layers.13.fc1.bias", + "decoder.layers.13.fc2.weight", + "decoder.layers.13.fc2.bias", + "decoder.layers.13.final_layer_norm.weight", + "decoder.layers.13.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.14.flat_param_0": { + "names": [ + "decoder.layers.14.self_attn.qkv_proj.weight", + "decoder.layers.14.self_attn.qkv_proj.bias", + "decoder.layers.14.self_attn.out_proj.weight", + "decoder.layers.14.self_attn.out_proj.bias", + "decoder.layers.14.self_attn_layer_norm.weight", + "decoder.layers.14.self_attn_layer_norm.bias", + "decoder.layers.14.fc1.weight", + "decoder.layers.14.fc1.bias", + "decoder.layers.14.fc2.weight", + "decoder.layers.14.fc2.bias", + "decoder.layers.14.final_layer_norm.weight", + "decoder.layers.14.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.15.flat_param_0": { + "names": [ + "decoder.layers.15.self_attn.qkv_proj.weight", + "decoder.layers.15.self_attn.qkv_proj.bias", + "decoder.layers.15.self_attn.out_proj.weight", + "decoder.layers.15.self_attn.out_proj.bias", + "decoder.layers.15.self_attn_layer_norm.weight", + "decoder.layers.15.self_attn_layer_norm.bias", + "decoder.layers.15.fc1.weight", + "decoder.layers.15.fc1.bias", + "decoder.layers.15.fc2.weight", + "decoder.layers.15.fc2.bias", + "decoder.layers.15.final_layer_norm.weight", + "decoder.layers.15.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.16.flat_param_0": { + "names": [ + "decoder.layers.16.self_attn.qkv_proj.weight", + "decoder.layers.16.self_attn.qkv_proj.bias", + "decoder.layers.16.self_attn.out_proj.weight", + "decoder.layers.16.self_attn.out_proj.bias", + "decoder.layers.16.self_attn_layer_norm.weight", + "decoder.layers.16.self_attn_layer_norm.bias", + "decoder.layers.16.fc1.weight", + "decoder.layers.16.fc1.bias", + "decoder.layers.16.fc2.weight", + "decoder.layers.16.fc2.bias", + "decoder.layers.16.final_layer_norm.weight", + "decoder.layers.16.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.17.flat_param_0": { + "names": [ + "decoder.layers.17.self_attn.qkv_proj.weight", + "decoder.layers.17.self_attn.qkv_proj.bias", + "decoder.layers.17.self_attn.out_proj.weight", + "decoder.layers.17.self_attn.out_proj.bias", + "decoder.layers.17.self_attn_layer_norm.weight", + "decoder.layers.17.self_attn_layer_norm.bias", + "decoder.layers.17.fc1.weight", + "decoder.layers.17.fc1.bias", + "decoder.layers.17.fc2.weight", + "decoder.layers.17.fc2.bias", + "decoder.layers.17.final_layer_norm.weight", + "decoder.layers.17.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.18.flat_param_0": { + "names": [ + "decoder.layers.18.self_attn.qkv_proj.weight", + "decoder.layers.18.self_attn.qkv_proj.bias", + "decoder.layers.18.self_attn.out_proj.weight", + "decoder.layers.18.self_attn.out_proj.bias", + "decoder.layers.18.self_attn_layer_norm.weight", + "decoder.layers.18.self_attn_layer_norm.bias", + "decoder.layers.18.fc1.weight", + "decoder.layers.18.fc1.bias", + "decoder.layers.18.fc2.weight", + "decoder.layers.18.fc2.bias", + "decoder.layers.18.final_layer_norm.weight", + "decoder.layers.18.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.19.flat_param_0": { + "names": [ + "decoder.layers.19.self_attn.qkv_proj.weight", + "decoder.layers.19.self_attn.qkv_proj.bias", + "decoder.layers.19.self_attn.out_proj.weight", + "decoder.layers.19.self_attn.out_proj.bias", + "decoder.layers.19.self_attn_layer_norm.weight", + "decoder.layers.19.self_attn_layer_norm.bias", + "decoder.layers.19.fc1.weight", + "decoder.layers.19.fc1.bias", + "decoder.layers.19.fc2.weight", + "decoder.layers.19.fc2.bias", + "decoder.layers.19.final_layer_norm.weight", + "decoder.layers.19.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.20.flat_param_0": { + "names": [ + "decoder.layers.20.self_attn.qkv_proj.weight", + "decoder.layers.20.self_attn.qkv_proj.bias", + "decoder.layers.20.self_attn.out_proj.weight", + "decoder.layers.20.self_attn.out_proj.bias", + "decoder.layers.20.self_attn_layer_norm.weight", + "decoder.layers.20.self_attn_layer_norm.bias", + "decoder.layers.20.fc1.weight", + "decoder.layers.20.fc1.bias", + "decoder.layers.20.fc2.weight", + "decoder.layers.20.fc2.bias", + "decoder.layers.20.final_layer_norm.weight", + "decoder.layers.20.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.21.flat_param_0": { + "names": [ + "decoder.layers.21.self_attn.qkv_proj.weight", + "decoder.layers.21.self_attn.qkv_proj.bias", + "decoder.layers.21.self_attn.out_proj.weight", + "decoder.layers.21.self_attn.out_proj.bias", + "decoder.layers.21.self_attn_layer_norm.weight", + "decoder.layers.21.self_attn_layer_norm.bias", + "decoder.layers.21.fc1.weight", + "decoder.layers.21.fc1.bias", + "decoder.layers.21.fc2.weight", + "decoder.layers.21.fc2.bias", + "decoder.layers.21.final_layer_norm.weight", + "decoder.layers.21.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.22.flat_param_0": { + "names": [ + "decoder.layers.22.self_attn.qkv_proj.weight", + "decoder.layers.22.self_attn.qkv_proj.bias", + "decoder.layers.22.self_attn.out_proj.weight", + "decoder.layers.22.self_attn.out_proj.bias", + "decoder.layers.22.self_attn_layer_norm.weight", + "decoder.layers.22.self_attn_layer_norm.bias", + "decoder.layers.22.fc1.weight", + "decoder.layers.22.fc1.bias", + "decoder.layers.22.fc2.weight", + "decoder.layers.22.fc2.bias", + "decoder.layers.22.final_layer_norm.weight", + "decoder.layers.22.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.23.flat_param_0": { + "names": [ + "decoder.layers.23.self_attn.qkv_proj.weight", + "decoder.layers.23.self_attn.qkv_proj.bias", + "decoder.layers.23.self_attn.out_proj.weight", + "decoder.layers.23.self_attn.out_proj.bias", + "decoder.layers.23.self_attn_layer_norm.weight", + "decoder.layers.23.self_attn_layer_norm.bias", + "decoder.layers.23.fc1.weight", + "decoder.layers.23.fc1.bias", + "decoder.layers.23.fc2.weight", + "decoder.layers.23.fc2.bias", + "decoder.layers.23.final_layer_norm.weight", + "decoder.layers.23.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.24.flat_param_0": { + "names": [ + "decoder.layers.24.self_attn.qkv_proj.weight", + "decoder.layers.24.self_attn.qkv_proj.bias", + "decoder.layers.24.self_attn.out_proj.weight", + "decoder.layers.24.self_attn.out_proj.bias", + "decoder.layers.24.self_attn_layer_norm.weight", + "decoder.layers.24.self_attn_layer_norm.bias", + "decoder.layers.24.fc1.weight", + "decoder.layers.24.fc1.bias", + "decoder.layers.24.fc2.weight", + "decoder.layers.24.fc2.bias", + "decoder.layers.24.final_layer_norm.weight", + "decoder.layers.24.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.25.flat_param_0": { + "names": [ + "decoder.layers.25.self_attn.qkv_proj.weight", + "decoder.layers.25.self_attn.qkv_proj.bias", + "decoder.layers.25.self_attn.out_proj.weight", + "decoder.layers.25.self_attn.out_proj.bias", + "decoder.layers.25.self_attn_layer_norm.weight", + "decoder.layers.25.self_attn_layer_norm.bias", + "decoder.layers.25.fc1.weight", + "decoder.layers.25.fc1.bias", + "decoder.layers.25.fc2.weight", + "decoder.layers.25.fc2.bias", + "decoder.layers.25.final_layer_norm.weight", + "decoder.layers.25.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.26.flat_param_0": { + "names": [ + "decoder.layers.26.self_attn.qkv_proj.weight", + "decoder.layers.26.self_attn.qkv_proj.bias", + "decoder.layers.26.self_attn.out_proj.weight", + "decoder.layers.26.self_attn.out_proj.bias", + "decoder.layers.26.self_attn_layer_norm.weight", + "decoder.layers.26.self_attn_layer_norm.bias", + "decoder.layers.26.fc1.weight", + "decoder.layers.26.fc1.bias", + "decoder.layers.26.fc2.weight", + "decoder.layers.26.fc2.bias", + "decoder.layers.26.final_layer_norm.weight", + "decoder.layers.26.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.27.flat_param_0": { + "names": [ + "decoder.layers.27.self_attn.qkv_proj.weight", + "decoder.layers.27.self_attn.qkv_proj.bias", + "decoder.layers.27.self_attn.out_proj.weight", + "decoder.layers.27.self_attn.out_proj.bias", + "decoder.layers.27.self_attn_layer_norm.weight", + "decoder.layers.27.self_attn_layer_norm.bias", + "decoder.layers.27.fc1.weight", + "decoder.layers.27.fc1.bias", + "decoder.layers.27.fc2.weight", + "decoder.layers.27.fc2.bias", + "decoder.layers.27.final_layer_norm.weight", + "decoder.layers.27.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.28.flat_param_0": { + "names": [ + "decoder.layers.28.self_attn.qkv_proj.weight", + "decoder.layers.28.self_attn.qkv_proj.bias", + "decoder.layers.28.self_attn.out_proj.weight", + "decoder.layers.28.self_attn.out_proj.bias", + "decoder.layers.28.self_attn_layer_norm.weight", + "decoder.layers.28.self_attn_layer_norm.bias", + "decoder.layers.28.fc1.weight", + "decoder.layers.28.fc1.bias", + "decoder.layers.28.fc2.weight", + "decoder.layers.28.fc2.bias", + "decoder.layers.28.final_layer_norm.weight", + "decoder.layers.28.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.29.flat_param_0": { + "names": [ + "decoder.layers.29.self_attn.qkv_proj.weight", + "decoder.layers.29.self_attn.qkv_proj.bias", + "decoder.layers.29.self_attn.out_proj.weight", + "decoder.layers.29.self_attn.out_proj.bias", + "decoder.layers.29.self_attn_layer_norm.weight", + "decoder.layers.29.self_attn_layer_norm.bias", + "decoder.layers.29.fc1.weight", + "decoder.layers.29.fc1.bias", + "decoder.layers.29.fc2.weight", + "decoder.layers.29.fc2.bias", + "decoder.layers.29.final_layer_norm.weight", + "decoder.layers.29.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.30.flat_param_0": { + "names": [ + "decoder.layers.30.self_attn.qkv_proj.weight", + "decoder.layers.30.self_attn.qkv_proj.bias", + "decoder.layers.30.self_attn.out_proj.weight", + "decoder.layers.30.self_attn.out_proj.bias", + "decoder.layers.30.self_attn_layer_norm.weight", + "decoder.layers.30.self_attn_layer_norm.bias", + "decoder.layers.30.fc1.weight", + "decoder.layers.30.fc1.bias", + "decoder.layers.30.fc2.weight", + "decoder.layers.30.fc2.bias", + "decoder.layers.30.final_layer_norm.weight", + "decoder.layers.30.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.31.flat_param_0": { + "names": [ + "decoder.layers.31.self_attn.qkv_proj.weight", + "decoder.layers.31.self_attn.qkv_proj.bias", + "decoder.layers.31.self_attn.out_proj.weight", + "decoder.layers.31.self_attn.out_proj.bias", + "decoder.layers.31.self_attn_layer_norm.weight", + "decoder.layers.31.self_attn_layer_norm.bias", + "decoder.layers.31.fc1.weight", + "decoder.layers.31.fc1.bias", + "decoder.layers.31.fc2.weight", + "decoder.layers.31.fc2.bias", + "decoder.layers.31.final_layer_norm.weight", + "decoder.layers.31.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.32.flat_param_0": { + "names": [ + "decoder.layers.32.self_attn.qkv_proj.weight", + "decoder.layers.32.self_attn.qkv_proj.bias", + "decoder.layers.32.self_attn.out_proj.weight", + "decoder.layers.32.self_attn.out_proj.bias", + "decoder.layers.32.self_attn_layer_norm.weight", + "decoder.layers.32.self_attn_layer_norm.bias", + "decoder.layers.32.fc1.weight", + "decoder.layers.32.fc1.bias", + "decoder.layers.32.fc2.weight", + "decoder.layers.32.fc2.bias", + "decoder.layers.32.final_layer_norm.weight", + "decoder.layers.32.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.33.flat_param_0": { + "names": [ + "decoder.layers.33.self_attn.qkv_proj.weight", + "decoder.layers.33.self_attn.qkv_proj.bias", + "decoder.layers.33.self_attn.out_proj.weight", + "decoder.layers.33.self_attn.out_proj.bias", + "decoder.layers.33.self_attn_layer_norm.weight", + "decoder.layers.33.self_attn_layer_norm.bias", + "decoder.layers.33.fc1.weight", + "decoder.layers.33.fc1.bias", + "decoder.layers.33.fc2.weight", + "decoder.layers.33.fc2.bias", + "decoder.layers.33.final_layer_norm.weight", + "decoder.layers.33.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.34.flat_param_0": { + "names": [ + "decoder.layers.34.self_attn.qkv_proj.weight", + "decoder.layers.34.self_attn.qkv_proj.bias", + "decoder.layers.34.self_attn.out_proj.weight", + "decoder.layers.34.self_attn.out_proj.bias", + "decoder.layers.34.self_attn_layer_norm.weight", + "decoder.layers.34.self_attn_layer_norm.bias", + "decoder.layers.34.fc1.weight", + "decoder.layers.34.fc1.bias", + "decoder.layers.34.fc2.weight", + "decoder.layers.34.fc2.bias", + "decoder.layers.34.final_layer_norm.weight", + "decoder.layers.34.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.35.flat_param_0": { + "names": [ + "decoder.layers.35.self_attn.qkv_proj.weight", + "decoder.layers.35.self_attn.qkv_proj.bias", + "decoder.layers.35.self_attn.out_proj.weight", + "decoder.layers.35.self_attn.out_proj.bias", + "decoder.layers.35.self_attn_layer_norm.weight", + "decoder.layers.35.self_attn_layer_norm.bias", + "decoder.layers.35.fc1.weight", + "decoder.layers.35.fc1.bias", + "decoder.layers.35.fc2.weight", + "decoder.layers.35.fc2.bias", + "decoder.layers.35.final_layer_norm.weight", + "decoder.layers.35.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.36.flat_param_0": { + "names": [ + "decoder.layers.36.self_attn.qkv_proj.weight", + "decoder.layers.36.self_attn.qkv_proj.bias", + "decoder.layers.36.self_attn.out_proj.weight", + "decoder.layers.36.self_attn.out_proj.bias", + "decoder.layers.36.self_attn_layer_norm.weight", + "decoder.layers.36.self_attn_layer_norm.bias", + "decoder.layers.36.fc1.weight", + "decoder.layers.36.fc1.bias", + "decoder.layers.36.fc2.weight", + "decoder.layers.36.fc2.bias", + "decoder.layers.36.final_layer_norm.weight", + "decoder.layers.36.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.37.flat_param_0": { + "names": [ + "decoder.layers.37.self_attn.qkv_proj.weight", + "decoder.layers.37.self_attn.qkv_proj.bias", + "decoder.layers.37.self_attn.out_proj.weight", + "decoder.layers.37.self_attn.out_proj.bias", + "decoder.layers.37.self_attn_layer_norm.weight", + "decoder.layers.37.self_attn_layer_norm.bias", + "decoder.layers.37.fc1.weight", + "decoder.layers.37.fc1.bias", + "decoder.layers.37.fc2.weight", + "decoder.layers.37.fc2.bias", + "decoder.layers.37.final_layer_norm.weight", + "decoder.layers.37.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.38.flat_param_0": { + "names": [ + "decoder.layers.38.self_attn.qkv_proj.weight", + "decoder.layers.38.self_attn.qkv_proj.bias", + "decoder.layers.38.self_attn.out_proj.weight", + "decoder.layers.38.self_attn.out_proj.bias", + "decoder.layers.38.self_attn_layer_norm.weight", + "decoder.layers.38.self_attn_layer_norm.bias", + "decoder.layers.38.fc1.weight", + "decoder.layers.38.fc1.bias", + "decoder.layers.38.fc2.weight", + "decoder.layers.38.fc2.bias", + "decoder.layers.38.final_layer_norm.weight", + "decoder.layers.38.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.39.flat_param_0": { + "names": [ + "decoder.layers.39.self_attn.qkv_proj.weight", + "decoder.layers.39.self_attn.qkv_proj.bias", + "decoder.layers.39.self_attn.out_proj.weight", + "decoder.layers.39.self_attn.out_proj.bias", + "decoder.layers.39.self_attn_layer_norm.weight", + "decoder.layers.39.self_attn_layer_norm.bias", + "decoder.layers.39.fc1.weight", + "decoder.layers.39.fc1.bias", + "decoder.layers.39.fc2.weight", + "decoder.layers.39.fc2.bias", + "decoder.layers.39.final_layer_norm.weight", + "decoder.layers.39.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.40.flat_param_0": { + "names": [ + "decoder.layers.40.self_attn.qkv_proj.weight", + "decoder.layers.40.self_attn.qkv_proj.bias", + "decoder.layers.40.self_attn.out_proj.weight", + "decoder.layers.40.self_attn.out_proj.bias", + "decoder.layers.40.self_attn_layer_norm.weight", + "decoder.layers.40.self_attn_layer_norm.bias", + "decoder.layers.40.fc1.weight", + "decoder.layers.40.fc1.bias", + "decoder.layers.40.fc2.weight", + "decoder.layers.40.fc2.bias", + "decoder.layers.40.final_layer_norm.weight", + "decoder.layers.40.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.41.flat_param_0": { + "names": [ + "decoder.layers.41.self_attn.qkv_proj.weight", + "decoder.layers.41.self_attn.qkv_proj.bias", + "decoder.layers.41.self_attn.out_proj.weight", + "decoder.layers.41.self_attn.out_proj.bias", + "decoder.layers.41.self_attn_layer_norm.weight", + "decoder.layers.41.self_attn_layer_norm.bias", + "decoder.layers.41.fc1.weight", + "decoder.layers.41.fc1.bias", + "decoder.layers.41.fc2.weight", + "decoder.layers.41.fc2.bias", + "decoder.layers.41.final_layer_norm.weight", + "decoder.layers.41.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.42.flat_param_0": { + "names": [ + "decoder.layers.42.self_attn.qkv_proj.weight", + "decoder.layers.42.self_attn.qkv_proj.bias", + "decoder.layers.42.self_attn.out_proj.weight", + "decoder.layers.42.self_attn.out_proj.bias", + "decoder.layers.42.self_attn_layer_norm.weight", + "decoder.layers.42.self_attn_layer_norm.bias", + "decoder.layers.42.fc1.weight", + "decoder.layers.42.fc1.bias", + "decoder.layers.42.fc2.weight", + "decoder.layers.42.fc2.bias", + "decoder.layers.42.final_layer_norm.weight", + "decoder.layers.42.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.43.flat_param_0": { + "names": [ + "decoder.layers.43.self_attn.qkv_proj.weight", + "decoder.layers.43.self_attn.qkv_proj.bias", + "decoder.layers.43.self_attn.out_proj.weight", + "decoder.layers.43.self_attn.out_proj.bias", + "decoder.layers.43.self_attn_layer_norm.weight", + "decoder.layers.43.self_attn_layer_norm.bias", + "decoder.layers.43.fc1.weight", + "decoder.layers.43.fc1.bias", + "decoder.layers.43.fc2.weight", + "decoder.layers.43.fc2.bias", + "decoder.layers.43.final_layer_norm.weight", + "decoder.layers.43.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.44.flat_param_0": { + "names": [ + "decoder.layers.44.self_attn.qkv_proj.weight", + "decoder.layers.44.self_attn.qkv_proj.bias", + "decoder.layers.44.self_attn.out_proj.weight", + "decoder.layers.44.self_attn.out_proj.bias", + "decoder.layers.44.self_attn_layer_norm.weight", + "decoder.layers.44.self_attn_layer_norm.bias", + "decoder.layers.44.fc1.weight", + "decoder.layers.44.fc1.bias", + "decoder.layers.44.fc2.weight", + "decoder.layers.44.fc2.bias", + "decoder.layers.44.final_layer_norm.weight", + "decoder.layers.44.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.45.flat_param_0": { + "names": [ + "decoder.layers.45.self_attn.qkv_proj.weight", + "decoder.layers.45.self_attn.qkv_proj.bias", + "decoder.layers.45.self_attn.out_proj.weight", + "decoder.layers.45.self_attn.out_proj.bias", + "decoder.layers.45.self_attn_layer_norm.weight", + "decoder.layers.45.self_attn_layer_norm.bias", + "decoder.layers.45.fc1.weight", + "decoder.layers.45.fc1.bias", + "decoder.layers.45.fc2.weight", + "decoder.layers.45.fc2.bias", + "decoder.layers.45.final_layer_norm.weight", + "decoder.layers.45.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.46.flat_param_0": { + "names": [ + "decoder.layers.46.self_attn.qkv_proj.weight", + "decoder.layers.46.self_attn.qkv_proj.bias", + "decoder.layers.46.self_attn.out_proj.weight", + "decoder.layers.46.self_attn.out_proj.bias", + "decoder.layers.46.self_attn_layer_norm.weight", + "decoder.layers.46.self_attn_layer_norm.bias", + "decoder.layers.46.fc1.weight", + "decoder.layers.46.fc1.bias", + "decoder.layers.46.fc2.weight", + "decoder.layers.46.fc2.bias", + "decoder.layers.46.final_layer_norm.weight", + "decoder.layers.46.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.47.flat_param_0": { + "names": [ + "decoder.layers.47.self_attn.qkv_proj.weight", + "decoder.layers.47.self_attn.qkv_proj.bias", + "decoder.layers.47.self_attn.out_proj.weight", + "decoder.layers.47.self_attn.out_proj.bias", + "decoder.layers.47.self_attn_layer_norm.weight", + "decoder.layers.47.self_attn_layer_norm.bias", + "decoder.layers.47.fc1.weight", + "decoder.layers.47.fc1.bias", + "decoder.layers.47.fc2.weight", + "decoder.layers.47.fc2.bias", + "decoder.layers.47.final_layer_norm.weight", + "decoder.layers.47.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.48.flat_param_0": { + "names": [ + "decoder.layers.48.self_attn.qkv_proj.weight", + "decoder.layers.48.self_attn.qkv_proj.bias", + "decoder.layers.48.self_attn.out_proj.weight", + "decoder.layers.48.self_attn.out_proj.bias", + "decoder.layers.48.self_attn_layer_norm.weight", + "decoder.layers.48.self_attn_layer_norm.bias", + "decoder.layers.48.fc1.weight", + "decoder.layers.48.fc1.bias", + "decoder.layers.48.fc2.weight", + "decoder.layers.48.fc2.bias", + "decoder.layers.48.final_layer_norm.weight", + "decoder.layers.48.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.49.flat_param_0": { + "names": [ + "decoder.layers.49.self_attn.qkv_proj.weight", + "decoder.layers.49.self_attn.qkv_proj.bias", + "decoder.layers.49.self_attn.out_proj.weight", + "decoder.layers.49.self_attn.out_proj.bias", + "decoder.layers.49.self_attn_layer_norm.weight", + "decoder.layers.49.self_attn_layer_norm.bias", + "decoder.layers.49.fc1.weight", + "decoder.layers.49.fc1.bias", + "decoder.layers.49.fc2.weight", + "decoder.layers.49.fc2.bias", + "decoder.layers.49.final_layer_norm.weight", + "decoder.layers.49.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.50.flat_param_0": { + "names": [ + "decoder.layers.50.self_attn.qkv_proj.weight", + "decoder.layers.50.self_attn.qkv_proj.bias", + "decoder.layers.50.self_attn.out_proj.weight", + "decoder.layers.50.self_attn.out_proj.bias", + "decoder.layers.50.self_attn_layer_norm.weight", + "decoder.layers.50.self_attn_layer_norm.bias", + "decoder.layers.50.fc1.weight", + "decoder.layers.50.fc1.bias", + "decoder.layers.50.fc2.weight", + "decoder.layers.50.fc2.bias", + "decoder.layers.50.final_layer_norm.weight", + "decoder.layers.50.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.51.flat_param_0": { + "names": [ + "decoder.layers.51.self_attn.qkv_proj.weight", + "decoder.layers.51.self_attn.qkv_proj.bias", + "decoder.layers.51.self_attn.out_proj.weight", + "decoder.layers.51.self_attn.out_proj.bias", + "decoder.layers.51.self_attn_layer_norm.weight", + "decoder.layers.51.self_attn_layer_norm.bias", + "decoder.layers.51.fc1.weight", + "decoder.layers.51.fc1.bias", + "decoder.layers.51.fc2.weight", + "decoder.layers.51.fc2.bias", + "decoder.layers.51.final_layer_norm.weight", + "decoder.layers.51.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.52.flat_param_0": { + "names": [ + "decoder.layers.52.self_attn.qkv_proj.weight", + "decoder.layers.52.self_attn.qkv_proj.bias", + "decoder.layers.52.self_attn.out_proj.weight", + "decoder.layers.52.self_attn.out_proj.bias", + "decoder.layers.52.self_attn_layer_norm.weight", + "decoder.layers.52.self_attn_layer_norm.bias", + "decoder.layers.52.fc1.weight", + "decoder.layers.52.fc1.bias", + "decoder.layers.52.fc2.weight", + "decoder.layers.52.fc2.bias", + "decoder.layers.52.final_layer_norm.weight", + "decoder.layers.52.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.53.flat_param_0": { + "names": [ + "decoder.layers.53.self_attn.qkv_proj.weight", + "decoder.layers.53.self_attn.qkv_proj.bias", + "decoder.layers.53.self_attn.out_proj.weight", + "decoder.layers.53.self_attn.out_proj.bias", + "decoder.layers.53.self_attn_layer_norm.weight", + "decoder.layers.53.self_attn_layer_norm.bias", + "decoder.layers.53.fc1.weight", + "decoder.layers.53.fc1.bias", + "decoder.layers.53.fc2.weight", + "decoder.layers.53.fc2.bias", + "decoder.layers.53.final_layer_norm.weight", + "decoder.layers.53.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.54.flat_param_0": { + "names": [ + "decoder.layers.54.self_attn.qkv_proj.weight", + "decoder.layers.54.self_attn.qkv_proj.bias", + "decoder.layers.54.self_attn.out_proj.weight", + "decoder.layers.54.self_attn.out_proj.bias", + "decoder.layers.54.self_attn_layer_norm.weight", + "decoder.layers.54.self_attn_layer_norm.bias", + "decoder.layers.54.fc1.weight", + "decoder.layers.54.fc1.bias", + "decoder.layers.54.fc2.weight", + "decoder.layers.54.fc2.bias", + "decoder.layers.54.final_layer_norm.weight", + "decoder.layers.54.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.55.flat_param_0": { + "names": [ + "decoder.layers.55.self_attn.qkv_proj.weight", + "decoder.layers.55.self_attn.qkv_proj.bias", + "decoder.layers.55.self_attn.out_proj.weight", + "decoder.layers.55.self_attn.out_proj.bias", + "decoder.layers.55.self_attn_layer_norm.weight", + "decoder.layers.55.self_attn_layer_norm.bias", + "decoder.layers.55.fc1.weight", + "decoder.layers.55.fc1.bias", + "decoder.layers.55.fc2.weight", + "decoder.layers.55.fc2.bias", + "decoder.layers.55.final_layer_norm.weight", + "decoder.layers.55.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.56.flat_param_0": { + "names": [ + "decoder.layers.56.self_attn.qkv_proj.weight", + "decoder.layers.56.self_attn.qkv_proj.bias", + "decoder.layers.56.self_attn.out_proj.weight", + "decoder.layers.56.self_attn.out_proj.bias", + "decoder.layers.56.self_attn_layer_norm.weight", + "decoder.layers.56.self_attn_layer_norm.bias", + "decoder.layers.56.fc1.weight", + "decoder.layers.56.fc1.bias", + "decoder.layers.56.fc2.weight", + "decoder.layers.56.fc2.bias", + "decoder.layers.56.final_layer_norm.weight", + "decoder.layers.56.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.57.flat_param_0": { + "names": [ + "decoder.layers.57.self_attn.qkv_proj.weight", + "decoder.layers.57.self_attn.qkv_proj.bias", + "decoder.layers.57.self_attn.out_proj.weight", + "decoder.layers.57.self_attn.out_proj.bias", + "decoder.layers.57.self_attn_layer_norm.weight", + "decoder.layers.57.self_attn_layer_norm.bias", + "decoder.layers.57.fc1.weight", + "decoder.layers.57.fc1.bias", + "decoder.layers.57.fc2.weight", + "decoder.layers.57.fc2.bias", + "decoder.layers.57.final_layer_norm.weight", + "decoder.layers.57.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.58.flat_param_0": { + "names": [ + "decoder.layers.58.self_attn.qkv_proj.weight", + "decoder.layers.58.self_attn.qkv_proj.bias", + "decoder.layers.58.self_attn.out_proj.weight", + "decoder.layers.58.self_attn.out_proj.bias", + "decoder.layers.58.self_attn_layer_norm.weight", + "decoder.layers.58.self_attn_layer_norm.bias", + "decoder.layers.58.fc1.weight", + "decoder.layers.58.fc1.bias", + "decoder.layers.58.fc2.weight", + "decoder.layers.58.fc2.bias", + "decoder.layers.58.final_layer_norm.weight", + "decoder.layers.58.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.59.flat_param_0": { + "names": [ + "decoder.layers.59.self_attn.qkv_proj.weight", + "decoder.layers.59.self_attn.qkv_proj.bias", + "decoder.layers.59.self_attn.out_proj.weight", + "decoder.layers.59.self_attn.out_proj.bias", + "decoder.layers.59.self_attn_layer_norm.weight", + "decoder.layers.59.self_attn_layer_norm.bias", + "decoder.layers.59.fc1.weight", + "decoder.layers.59.fc1.bias", + "decoder.layers.59.fc2.weight", + "decoder.layers.59.fc2.bias", + "decoder.layers.59.final_layer_norm.weight", + "decoder.layers.59.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.60.flat_param_0": { + "names": [ + "decoder.layers.60.self_attn.qkv_proj.weight", + "decoder.layers.60.self_attn.qkv_proj.bias", + "decoder.layers.60.self_attn.out_proj.weight", + "decoder.layers.60.self_attn.out_proj.bias", + "decoder.layers.60.self_attn_layer_norm.weight", + "decoder.layers.60.self_attn_layer_norm.bias", + "decoder.layers.60.fc1.weight", + "decoder.layers.60.fc1.bias", + "decoder.layers.60.fc2.weight", + "decoder.layers.60.fc2.bias", + "decoder.layers.60.final_layer_norm.weight", + "decoder.layers.60.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.61.flat_param_0": { + "names": [ + "decoder.layers.61.self_attn.qkv_proj.weight", + "decoder.layers.61.self_attn.qkv_proj.bias", + "decoder.layers.61.self_attn.out_proj.weight", + "decoder.layers.61.self_attn.out_proj.bias", + "decoder.layers.61.self_attn_layer_norm.weight", + "decoder.layers.61.self_attn_layer_norm.bias", + "decoder.layers.61.fc1.weight", + "decoder.layers.61.fc1.bias", + "decoder.layers.61.fc2.weight", + "decoder.layers.61.fc2.bias", + "decoder.layers.61.final_layer_norm.weight", + "decoder.layers.61.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.62.flat_param_0": { + "names": [ + "decoder.layers.62.self_attn.qkv_proj.weight", + "decoder.layers.62.self_attn.qkv_proj.bias", + "decoder.layers.62.self_attn.out_proj.weight", + "decoder.layers.62.self_attn.out_proj.bias", + "decoder.layers.62.self_attn_layer_norm.weight", + "decoder.layers.62.self_attn_layer_norm.bias", + "decoder.layers.62.fc1.weight", + "decoder.layers.62.fc1.bias", + "decoder.layers.62.fc2.weight", + "decoder.layers.62.fc2.bias", + "decoder.layers.62.final_layer_norm.weight", + "decoder.layers.62.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.63.flat_param_0": { + "names": [ + "decoder.layers.63.self_attn.qkv_proj.weight", + "decoder.layers.63.self_attn.qkv_proj.bias", + "decoder.layers.63.self_attn.out_proj.weight", + "decoder.layers.63.self_attn.out_proj.bias", + "decoder.layers.63.self_attn_layer_norm.weight", + "decoder.layers.63.self_attn_layer_norm.bias", + "decoder.layers.63.fc1.weight", + "decoder.layers.63.fc1.bias", + "decoder.layers.63.fc2.weight", + "decoder.layers.63.fc2.bias", + "decoder.layers.63.final_layer_norm.weight", + "decoder.layers.63.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.64.flat_param_0": { + "names": [ + "decoder.layers.64.self_attn.qkv_proj.weight", + "decoder.layers.64.self_attn.qkv_proj.bias", + "decoder.layers.64.self_attn.out_proj.weight", + "decoder.layers.64.self_attn.out_proj.bias", + "decoder.layers.64.self_attn_layer_norm.weight", + "decoder.layers.64.self_attn_layer_norm.bias", + "decoder.layers.64.fc1.weight", + "decoder.layers.64.fc1.bias", + "decoder.layers.64.fc2.weight", + "decoder.layers.64.fc2.bias", + "decoder.layers.64.final_layer_norm.weight", + "decoder.layers.64.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.65.flat_param_0": { + "names": [ + "decoder.layers.65.self_attn.qkv_proj.weight", + "decoder.layers.65.self_attn.qkv_proj.bias", + "decoder.layers.65.self_attn.out_proj.weight", + "decoder.layers.65.self_attn.out_proj.bias", + "decoder.layers.65.self_attn_layer_norm.weight", + "decoder.layers.65.self_attn_layer_norm.bias", + "decoder.layers.65.fc1.weight", + "decoder.layers.65.fc1.bias", + "decoder.layers.65.fc2.weight", + "decoder.layers.65.fc2.bias", + "decoder.layers.65.final_layer_norm.weight", + "decoder.layers.65.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.66.flat_param_0": { + "names": [ + "decoder.layers.66.self_attn.qkv_proj.weight", + "decoder.layers.66.self_attn.qkv_proj.bias", + "decoder.layers.66.self_attn.out_proj.weight", + "decoder.layers.66.self_attn.out_proj.bias", + "decoder.layers.66.self_attn_layer_norm.weight", + "decoder.layers.66.self_attn_layer_norm.bias", + "decoder.layers.66.fc1.weight", + "decoder.layers.66.fc1.bias", + "decoder.layers.66.fc2.weight", + "decoder.layers.66.fc2.bias", + "decoder.layers.66.final_layer_norm.weight", + "decoder.layers.66.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.67.flat_param_0": { + "names": [ + "decoder.layers.67.self_attn.qkv_proj.weight", + "decoder.layers.67.self_attn.qkv_proj.bias", + "decoder.layers.67.self_attn.out_proj.weight", + "decoder.layers.67.self_attn.out_proj.bias", + "decoder.layers.67.self_attn_layer_norm.weight", + "decoder.layers.67.self_attn_layer_norm.bias", + "decoder.layers.67.fc1.weight", + "decoder.layers.67.fc1.bias", + "decoder.layers.67.fc2.weight", + "decoder.layers.67.fc2.bias", + "decoder.layers.67.final_layer_norm.weight", + "decoder.layers.67.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.68.flat_param_0": { + "names": [ + "decoder.layers.68.self_attn.qkv_proj.weight", + "decoder.layers.68.self_attn.qkv_proj.bias", + "decoder.layers.68.self_attn.out_proj.weight", + "decoder.layers.68.self_attn.out_proj.bias", + "decoder.layers.68.self_attn_layer_norm.weight", + "decoder.layers.68.self_attn_layer_norm.bias", + "decoder.layers.68.fc1.weight", + "decoder.layers.68.fc1.bias", + "decoder.layers.68.fc2.weight", + "decoder.layers.68.fc2.bias", + "decoder.layers.68.final_layer_norm.weight", + "decoder.layers.68.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.69.flat_param_0": { + "names": [ + "decoder.layers.69.self_attn.qkv_proj.weight", + "decoder.layers.69.self_attn.qkv_proj.bias", + "decoder.layers.69.self_attn.out_proj.weight", + "decoder.layers.69.self_attn.out_proj.bias", + "decoder.layers.69.self_attn_layer_norm.weight", + "decoder.layers.69.self_attn_layer_norm.bias", + "decoder.layers.69.fc1.weight", + "decoder.layers.69.fc1.bias", + "decoder.layers.69.fc2.weight", + "decoder.layers.69.fc2.bias", + "decoder.layers.69.final_layer_norm.weight", + "decoder.layers.69.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.70.flat_param_0": { + "names": [ + "decoder.layers.70.self_attn.qkv_proj.weight", + "decoder.layers.70.self_attn.qkv_proj.bias", + "decoder.layers.70.self_attn.out_proj.weight", + "decoder.layers.70.self_attn.out_proj.bias", + "decoder.layers.70.self_attn_layer_norm.weight", + "decoder.layers.70.self_attn_layer_norm.bias", + "decoder.layers.70.fc1.weight", + "decoder.layers.70.fc1.bias", + "decoder.layers.70.fc2.weight", + "decoder.layers.70.fc2.bias", + "decoder.layers.70.final_layer_norm.weight", + "decoder.layers.70.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.71.flat_param_0": { + "names": [ + "decoder.layers.71.self_attn.qkv_proj.weight", + "decoder.layers.71.self_attn.qkv_proj.bias", + "decoder.layers.71.self_attn.out_proj.weight", + "decoder.layers.71.self_attn.out_proj.bias", + "decoder.layers.71.self_attn_layer_norm.weight", + "decoder.layers.71.self_attn_layer_norm.bias", + "decoder.layers.71.fc1.weight", + "decoder.layers.71.fc1.bias", + "decoder.layers.71.fc2.weight", + "decoder.layers.71.fc2.bias", + "decoder.layers.71.final_layer_norm.weight", + "decoder.layers.71.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.72.flat_param_0": { + "names": [ + "decoder.layers.72.self_attn.qkv_proj.weight", + "decoder.layers.72.self_attn.qkv_proj.bias", + "decoder.layers.72.self_attn.out_proj.weight", + "decoder.layers.72.self_attn.out_proj.bias", + "decoder.layers.72.self_attn_layer_norm.weight", + "decoder.layers.72.self_attn_layer_norm.bias", + "decoder.layers.72.fc1.weight", + "decoder.layers.72.fc1.bias", + "decoder.layers.72.fc2.weight", + "decoder.layers.72.fc2.bias", + "decoder.layers.72.final_layer_norm.weight", + "decoder.layers.72.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.73.flat_param_0": { + "names": [ + "decoder.layers.73.self_attn.qkv_proj.weight", + "decoder.layers.73.self_attn.qkv_proj.bias", + "decoder.layers.73.self_attn.out_proj.weight", + "decoder.layers.73.self_attn.out_proj.bias", + "decoder.layers.73.self_attn_layer_norm.weight", + "decoder.layers.73.self_attn_layer_norm.bias", + "decoder.layers.73.fc1.weight", + "decoder.layers.73.fc1.bias", + "decoder.layers.73.fc2.weight", + "decoder.layers.73.fc2.bias", + "decoder.layers.73.final_layer_norm.weight", + "decoder.layers.73.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.74.flat_param_0": { + "names": [ + "decoder.layers.74.self_attn.qkv_proj.weight", + "decoder.layers.74.self_attn.qkv_proj.bias", + "decoder.layers.74.self_attn.out_proj.weight", + "decoder.layers.74.self_attn.out_proj.bias", + "decoder.layers.74.self_attn_layer_norm.weight", + "decoder.layers.74.self_attn_layer_norm.bias", + "decoder.layers.74.fc1.weight", + "decoder.layers.74.fc1.bias", + "decoder.layers.74.fc2.weight", + "decoder.layers.74.fc2.bias", + "decoder.layers.74.final_layer_norm.weight", + "decoder.layers.74.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.75.flat_param_0": { + "names": [ + "decoder.layers.75.self_attn.qkv_proj.weight", + "decoder.layers.75.self_attn.qkv_proj.bias", + "decoder.layers.75.self_attn.out_proj.weight", + "decoder.layers.75.self_attn.out_proj.bias", + "decoder.layers.75.self_attn_layer_norm.weight", + "decoder.layers.75.self_attn_layer_norm.bias", + "decoder.layers.75.fc1.weight", + "decoder.layers.75.fc1.bias", + "decoder.layers.75.fc2.weight", + "decoder.layers.75.fc2.bias", + "decoder.layers.75.final_layer_norm.weight", + "decoder.layers.75.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.76.flat_param_0": { + "names": [ + "decoder.layers.76.self_attn.qkv_proj.weight", + "decoder.layers.76.self_attn.qkv_proj.bias", + "decoder.layers.76.self_attn.out_proj.weight", + "decoder.layers.76.self_attn.out_proj.bias", + "decoder.layers.76.self_attn_layer_norm.weight", + "decoder.layers.76.self_attn_layer_norm.bias", + "decoder.layers.76.fc1.weight", + "decoder.layers.76.fc1.bias", + "decoder.layers.76.fc2.weight", + "decoder.layers.76.fc2.bias", + "decoder.layers.76.final_layer_norm.weight", + "decoder.layers.76.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.77.flat_param_0": { + "names": [ + "decoder.layers.77.self_attn.qkv_proj.weight", + "decoder.layers.77.self_attn.qkv_proj.bias", + "decoder.layers.77.self_attn.out_proj.weight", + "decoder.layers.77.self_attn.out_proj.bias", + "decoder.layers.77.self_attn_layer_norm.weight", + "decoder.layers.77.self_attn_layer_norm.bias", + "decoder.layers.77.fc1.weight", + "decoder.layers.77.fc1.bias", + "decoder.layers.77.fc2.weight", + "decoder.layers.77.fc2.bias", + "decoder.layers.77.final_layer_norm.weight", + "decoder.layers.77.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.78.flat_param_0": { + "names": [ + "decoder.layers.78.self_attn.qkv_proj.weight", + "decoder.layers.78.self_attn.qkv_proj.bias", + "decoder.layers.78.self_attn.out_proj.weight", + "decoder.layers.78.self_attn.out_proj.bias", + "decoder.layers.78.self_attn_layer_norm.weight", + "decoder.layers.78.self_attn_layer_norm.bias", + "decoder.layers.78.fc1.weight", + "decoder.layers.78.fc1.bias", + "decoder.layers.78.fc2.weight", + "decoder.layers.78.fc2.bias", + "decoder.layers.78.final_layer_norm.weight", + "decoder.layers.78.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.79.flat_param_0": { + "names": [ + "decoder.layers.79.self_attn.qkv_proj.weight", + "decoder.layers.79.self_attn.qkv_proj.bias", + "decoder.layers.79.self_attn.out_proj.weight", + "decoder.layers.79.self_attn.out_proj.bias", + "decoder.layers.79.self_attn_layer_norm.weight", + "decoder.layers.79.self_attn_layer_norm.bias", + "decoder.layers.79.fc1.weight", + "decoder.layers.79.fc1.bias", + "decoder.layers.79.fc2.weight", + "decoder.layers.79.fc2.bias", + "decoder.layers.79.final_layer_norm.weight", + "decoder.layers.79.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.80.flat_param_0": { + "names": [ + "decoder.layers.80.self_attn.qkv_proj.weight", + "decoder.layers.80.self_attn.qkv_proj.bias", + "decoder.layers.80.self_attn.out_proj.weight", + "decoder.layers.80.self_attn.out_proj.bias", + "decoder.layers.80.self_attn_layer_norm.weight", + "decoder.layers.80.self_attn_layer_norm.bias", + "decoder.layers.80.fc1.weight", + "decoder.layers.80.fc1.bias", + "decoder.layers.80.fc2.weight", + "decoder.layers.80.fc2.bias", + "decoder.layers.80.final_layer_norm.weight", + "decoder.layers.80.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.81.flat_param_0": { + "names": [ + "decoder.layers.81.self_attn.qkv_proj.weight", + "decoder.layers.81.self_attn.qkv_proj.bias", + "decoder.layers.81.self_attn.out_proj.weight", + "decoder.layers.81.self_attn.out_proj.bias", + "decoder.layers.81.self_attn_layer_norm.weight", + "decoder.layers.81.self_attn_layer_norm.bias", + "decoder.layers.81.fc1.weight", + "decoder.layers.81.fc1.bias", + "decoder.layers.81.fc2.weight", + "decoder.layers.81.fc2.bias", + "decoder.layers.81.final_layer_norm.weight", + "decoder.layers.81.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.82.flat_param_0": { + "names": [ + "decoder.layers.82.self_attn.qkv_proj.weight", + "decoder.layers.82.self_attn.qkv_proj.bias", + "decoder.layers.82.self_attn.out_proj.weight", + "decoder.layers.82.self_attn.out_proj.bias", + "decoder.layers.82.self_attn_layer_norm.weight", + "decoder.layers.82.self_attn_layer_norm.bias", + "decoder.layers.82.fc1.weight", + "decoder.layers.82.fc1.bias", + "decoder.layers.82.fc2.weight", + "decoder.layers.82.fc2.bias", + "decoder.layers.82.final_layer_norm.weight", + "decoder.layers.82.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.83.flat_param_0": { + "names": [ + "decoder.layers.83.self_attn.qkv_proj.weight", + "decoder.layers.83.self_attn.qkv_proj.bias", + "decoder.layers.83.self_attn.out_proj.weight", + "decoder.layers.83.self_attn.out_proj.bias", + "decoder.layers.83.self_attn_layer_norm.weight", + "decoder.layers.83.self_attn_layer_norm.bias", + "decoder.layers.83.fc1.weight", + "decoder.layers.83.fc1.bias", + "decoder.layers.83.fc2.weight", + "decoder.layers.83.fc2.bias", + "decoder.layers.83.final_layer_norm.weight", + "decoder.layers.83.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.84.flat_param_0": { + "names": [ + "decoder.layers.84.self_attn.qkv_proj.weight", + "decoder.layers.84.self_attn.qkv_proj.bias", + "decoder.layers.84.self_attn.out_proj.weight", + "decoder.layers.84.self_attn.out_proj.bias", + "decoder.layers.84.self_attn_layer_norm.weight", + "decoder.layers.84.self_attn_layer_norm.bias", + "decoder.layers.84.fc1.weight", + "decoder.layers.84.fc1.bias", + "decoder.layers.84.fc2.weight", + "decoder.layers.84.fc2.bias", + "decoder.layers.84.final_layer_norm.weight", + "decoder.layers.84.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.85.flat_param_0": { + "names": [ + "decoder.layers.85.self_attn.qkv_proj.weight", + "decoder.layers.85.self_attn.qkv_proj.bias", + "decoder.layers.85.self_attn.out_proj.weight", + "decoder.layers.85.self_attn.out_proj.bias", + "decoder.layers.85.self_attn_layer_norm.weight", + "decoder.layers.85.self_attn_layer_norm.bias", + "decoder.layers.85.fc1.weight", + "decoder.layers.85.fc1.bias", + "decoder.layers.85.fc2.weight", + "decoder.layers.85.fc2.bias", + "decoder.layers.85.final_layer_norm.weight", + "decoder.layers.85.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.86.flat_param_0": { + "names": [ + "decoder.layers.86.self_attn.qkv_proj.weight", + "decoder.layers.86.self_attn.qkv_proj.bias", + "decoder.layers.86.self_attn.out_proj.weight", + "decoder.layers.86.self_attn.out_proj.bias", + "decoder.layers.86.self_attn_layer_norm.weight", + "decoder.layers.86.self_attn_layer_norm.bias", + "decoder.layers.86.fc1.weight", + "decoder.layers.86.fc1.bias", + "decoder.layers.86.fc2.weight", + "decoder.layers.86.fc2.bias", + "decoder.layers.86.final_layer_norm.weight", + "decoder.layers.86.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.87.flat_param_0": { + "names": [ + "decoder.layers.87.self_attn.qkv_proj.weight", + "decoder.layers.87.self_attn.qkv_proj.bias", + "decoder.layers.87.self_attn.out_proj.weight", + "decoder.layers.87.self_attn.out_proj.bias", + "decoder.layers.87.self_attn_layer_norm.weight", + "decoder.layers.87.self_attn_layer_norm.bias", + "decoder.layers.87.fc1.weight", + "decoder.layers.87.fc1.bias", + "decoder.layers.87.fc2.weight", + "decoder.layers.87.fc2.bias", + "decoder.layers.87.final_layer_norm.weight", + "decoder.layers.87.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.88.flat_param_0": { + "names": [ + "decoder.layers.88.self_attn.qkv_proj.weight", + "decoder.layers.88.self_attn.qkv_proj.bias", + "decoder.layers.88.self_attn.out_proj.weight", + "decoder.layers.88.self_attn.out_proj.bias", + "decoder.layers.88.self_attn_layer_norm.weight", + "decoder.layers.88.self_attn_layer_norm.bias", + "decoder.layers.88.fc1.weight", + "decoder.layers.88.fc1.bias", + "decoder.layers.88.fc2.weight", + "decoder.layers.88.fc2.bias", + "decoder.layers.88.final_layer_norm.weight", + "decoder.layers.88.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.89.flat_param_0": { + "names": [ + "decoder.layers.89.self_attn.qkv_proj.weight", + "decoder.layers.89.self_attn.qkv_proj.bias", + "decoder.layers.89.self_attn.out_proj.weight", + "decoder.layers.89.self_attn.out_proj.bias", + "decoder.layers.89.self_attn_layer_norm.weight", + "decoder.layers.89.self_attn_layer_norm.bias", + "decoder.layers.89.fc1.weight", + "decoder.layers.89.fc1.bias", + "decoder.layers.89.fc2.weight", + "decoder.layers.89.fc2.bias", + "decoder.layers.89.final_layer_norm.weight", + "decoder.layers.89.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.90.flat_param_0": { + "names": [ + "decoder.layers.90.self_attn.qkv_proj.weight", + "decoder.layers.90.self_attn.qkv_proj.bias", + "decoder.layers.90.self_attn.out_proj.weight", + "decoder.layers.90.self_attn.out_proj.bias", + "decoder.layers.90.self_attn_layer_norm.weight", + "decoder.layers.90.self_attn_layer_norm.bias", + "decoder.layers.90.fc1.weight", + "decoder.layers.90.fc1.bias", + "decoder.layers.90.fc2.weight", + "decoder.layers.90.fc2.bias", + "decoder.layers.90.final_layer_norm.weight", + "decoder.layers.90.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.91.flat_param_0": { + "names": [ + "decoder.layers.91.self_attn.qkv_proj.weight", + "decoder.layers.91.self_attn.qkv_proj.bias", + "decoder.layers.91.self_attn.out_proj.weight", + "decoder.layers.91.self_attn.out_proj.bias", + "decoder.layers.91.self_attn_layer_norm.weight", + "decoder.layers.91.self_attn_layer_norm.bias", + "decoder.layers.91.fc1.weight", + "decoder.layers.91.fc1.bias", + "decoder.layers.91.fc2.weight", + "decoder.layers.91.fc2.bias", + "decoder.layers.91.final_layer_norm.weight", + "decoder.layers.91.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.92.flat_param_0": { + "names": [ + "decoder.layers.92.self_attn.qkv_proj.weight", + "decoder.layers.92.self_attn.qkv_proj.bias", + "decoder.layers.92.self_attn.out_proj.weight", + "decoder.layers.92.self_attn.out_proj.bias", + "decoder.layers.92.self_attn_layer_norm.weight", + "decoder.layers.92.self_attn_layer_norm.bias", + "decoder.layers.92.fc1.weight", + "decoder.layers.92.fc1.bias", + "decoder.layers.92.fc2.weight", + "decoder.layers.92.fc2.bias", + "decoder.layers.92.final_layer_norm.weight", + "decoder.layers.92.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.93.flat_param_0": { + "names": [ + "decoder.layers.93.self_attn.qkv_proj.weight", + "decoder.layers.93.self_attn.qkv_proj.bias", + "decoder.layers.93.self_attn.out_proj.weight", + "decoder.layers.93.self_attn.out_proj.bias", + "decoder.layers.93.self_attn_layer_norm.weight", + "decoder.layers.93.self_attn_layer_norm.bias", + "decoder.layers.93.fc1.weight", + "decoder.layers.93.fc1.bias", + "decoder.layers.93.fc2.weight", + "decoder.layers.93.fc2.bias", + "decoder.layers.93.final_layer_norm.weight", + "decoder.layers.93.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.94.flat_param_0": { + "names": [ + "decoder.layers.94.self_attn.qkv_proj.weight", + "decoder.layers.94.self_attn.qkv_proj.bias", + "decoder.layers.94.self_attn.out_proj.weight", + "decoder.layers.94.self_attn.out_proj.bias", + "decoder.layers.94.self_attn_layer_norm.weight", + "decoder.layers.94.self_attn_layer_norm.bias", + "decoder.layers.94.fc1.weight", + "decoder.layers.94.fc1.bias", + "decoder.layers.94.fc2.weight", + "decoder.layers.94.fc2.bias", + "decoder.layers.94.final_layer_norm.weight", + "decoder.layers.94.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.95.flat_param_0": { + "names": [ + "decoder.layers.95.self_attn.qkv_proj.weight", + "decoder.layers.95.self_attn.qkv_proj.bias", + "decoder.layers.95.self_attn.out_proj.weight", + "decoder.layers.95.self_attn.out_proj.bias", + "decoder.layers.95.self_attn_layer_norm.weight", + "decoder.layers.95.self_attn_layer_norm.bias", + "decoder.layers.95.fc1.weight", + "decoder.layers.95.fc1.bias", + "decoder.layers.95.fc2.weight", + "decoder.layers.95.fc2.bias", + "decoder.layers.95.final_layer_norm.weight", + "decoder.layers.95.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + } +} diff --git a/examples/tutorial/opt/inference/script/processing_ckpt_66b.py b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py index 0494647d7bcce31a6ba1e3decf0943043730c2fa..576daacdb47170d90b026963221fd2a6d7ee43ae 100644 --- a/examples/tutorial/opt/inference/script/processing_ckpt_66b.py +++ b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py @@ -1,7 +1,8 @@ import os -import torch from multiprocessing import Pool +import torch + # download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main # you can use whether wget or git lfs @@ -20,14 +21,14 @@ with Pool(14) as pool: restored = {} for ckpt in ckpts: - for k,v in ckpt.items(): - if(k[0] == 'm'): - k = k[6:] - if(k == "lm_head.weight"): + for k, v in ckpt.items(): + if k[0] == "m": + k = k[6:] + if k == "lm_head.weight": k = "head.dense.weight" - if(k == "decoder.final_layer_norm.weight"): + if k == "decoder.final_layer_norm.weight": k = "decoder.layer_norm.weight" - if(k == "decoder.final_layer_norm.bias"): + if k == "decoder.final_layer_norm.bias": k = "decoder.layer_norm.bias" restored[k] = v restored["decoder.version"] = "0.0" @@ -37,11 +38,11 @@ split_num = len(restored.keys()) // 60 count = 0 file_count = 1 tmp = {} -for k,v in restored.items(): +for k, v in restored.items(): print(k) tmp[k] = v - count = count + 1 - if(count == split_num): + count = count + 1 + if count == split_num: filename = str(file_count) + "-restored.pt" torch.save(tmp, os.path.join(new_path, filename)) file_count = file_count + 1 @@ -50,6 +51,3 @@ for k,v in restored.items(): filename = str(file_count) + "-restored.pt" torch.save(tmp, os.path.join(new_path, filename)) - - - diff --git a/examples/tutorial/opt/opt/colossalai_zero.py b/examples/tutorial/opt/opt/colossalai_zero.py index 7c2c152450c5563a19a80c3e14b8463fa8423c9a..75516bba560f7370f88000846738d7e4af2aae2a 100644 --- a/examples/tutorial/opt/opt/colossalai_zero.py +++ b/examples/tutorial/opt/opt/colossalai_zero.py @@ -2,9 +2,9 @@ try: from colossalai.zero.shard_utils import TensorShardStrategy except ImportError: # colossalai > 0.2.8 - from colossalai.zero.legacy import TensorShardStrategy + from colossalai.legacy.zero import TensorShardStrategy -zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), - tensor_placement_policy="auto", - reuse_fp16_shard=True), - optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384)) +zero = dict( + model_config=dict(shard_strategy=TensorShardStrategy(), tensor_placement_policy="auto", reuse_fp16_shard=True), + optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384), +) diff --git a/examples/tutorial/opt/opt/context.py b/examples/tutorial/opt/opt/context.py index 95f0abf1d8c92ed5766e5f0fa2c70618be7827c5..7172408f3cbc4ce0a9136d7aff26cf10e54c99fd 100644 --- a/examples/tutorial/opt/opt/context.py +++ b/examples/tutorial/opt/opt/context.py @@ -1,10 +1,10 @@ import torch.distributed as dist -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc -class barrier_context(): +class barrier_context: """ This context manager is used to allow one process to execute while blocking all other processes in the same process group. This is often useful when downloading is required diff --git a/examples/tutorial/opt/opt/requirements.txt b/examples/tutorial/opt/opt/requirements.txt index d0ed2c717aee74916ee54a29ebcf69a6d41ca2f5..f2df112fa6ba3b6c8fedfab0b0129a15fb89b5b2 100644 --- a/examples/tutorial/opt/opt/requirements.txt +++ b/examples/tutorial/opt/opt/requirements.txt @@ -3,5 +3,5 @@ torch >= 1.8.1 datasets >= 1.8.0 sentencepiece != 0.1.92 protobuf -accelerate == 0.13.2 +accelerate >= 0.20.3 transformers diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index fdc86adab66578156809231ede5c16064e57ad71..9bd23ffc8aba5ec9871c422bc89457f72cf76ca6 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -30,7 +30,7 @@ from itertools import chain import datasets import torch import torch.distributed as dist -import transformers +import transformers.utils.logging as logging from accelerate.utils import set_seed from context import barrier_context from datasets import load_dataset @@ -51,13 +51,14 @@ from transformers import ( from transformers.utils.versions import require_version import colossalai -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.tensor import ProcessGroup +from colossalai.legacy.utils import get_dataloader from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ProcessGroup -from colossalai.utils import get_current_device, get_dataloader -from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer +from colossalai.utils import get_current_device +from colossalai.zero import GeminiOptimizer require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") @@ -85,14 +86,12 @@ def parse_args(): default=None, help="The configuration name of the dataset to use (via the datasets library).", ) - parser.add_argument("--train_file", - type=str, - default=None, - help="A csv or a json file containing the training data.") - parser.add_argument("--validation_file", - type=str, - default=None, - help="A csv or a json file containing the validation data.") + parser.add_argument( + "--train_file", type=str, default=None, help="A csv or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." + ) parser.add_argument( "--validation_split_percentage", default=5, @@ -160,10 +159,9 @@ def parse_args(): help="The scheduler type to use.", choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], ) - parser.add_argument("--num_warmup_steps", - type=int, - default=0, - help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( @@ -177,9 +175,11 @@ def parse_args(): "--block_size", type=int, default=None, - help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of" - " this size for training. Default to the model max input length for single sentence inputs (take into" - " account special tokens)."), + help=( + "Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)." + ), ) parser.add_argument( "--preprocessing_num_workers", @@ -187,17 +187,16 @@ def parse_args(): default=None, help="The number of processes to use for the preprocessing.", ) - parser.add_argument("--overwrite_cache", - type=bool, - default=False, - help="Overwrite the cached training and evaluation sets") - parser.add_argument("--no_keep_linebreaks", - action="store_true", - help="Do not keep line breaks when using TXT files.") + parser.add_argument( + "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files." + ) parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_model_id", - type=str, - help="The name of the repository to keep in sync with the local `output_dir`.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") parser.add_argument( "--checkpointing_steps", @@ -220,13 +219,15 @@ def parse_args(): "--report_to", type=str, default="all", - help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' - ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' - "Only applicable when `--with_tracking` is passed."), + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." + ), ) parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") - parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") + parser.add_argument("--init_in_cpu", action="store_true", default=False, help="init training model in cpu") args = parser.parse_args() # Sanity checks @@ -249,6 +250,7 @@ def parse_args(): def colo_memory_cap(size_in_GB): from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) @@ -256,7 +258,6 @@ def colo_memory_cap(size_in_GB): class DummyDataloader: - def __init__(self, length, batch_size, seq_len, vocab_size): self.length = length self.batch_size = batch_size @@ -292,10 +293,10 @@ def main(): if is_main_process: datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() + logging.set_verbosity_info() else: datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() + logging.set_verbosity_error() if args.mem_cap > 0: colo_memory_cap(args.mem_cap) @@ -379,40 +380,60 @@ def main(): logger.warning("You are instantiating a new config instance from scratch.") logger.info("Model config has been created", ranks=[0]) - if args.model_name_or_path == 'facebook/opt-13b': + if args.model_name_or_path == "facebook/opt-13b": tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) else: - print(f'load model from {args.model_name_or_path}') + print(f"load model from {args.model_name_or_path}") tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0]) if args.init_in_cpu: - init_dev = torch.device('cpu') + init_dev = torch.device("cpu") else: init_dev = get_current_device() + cai_version = colossalai.__version__ + logger.info(f"using Colossal-AI version {cai_version}") # build model - if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': + if version.parse(cai_version) >= version.parse("0.3.1"): + from contextlib import nullcontext + + from colossalai.lazy import LazyInitContext + + ctx = ( + LazyInitContext(default_device=init_dev) + if args.model_name_or_path is None or args.model_name_or_path == "facebook/opt-13b" + else nullcontext() + ) + else: + from colossalai.zero import ColoInitContext + + ctx = ColoInitContext(device=init_dev) + if args.model_name_or_path is None or args.model_name_or_path == "facebook/opt-13b": # currently, there has a bug in pretrained opt-13b # we can not import it until huggingface fix it logger.info("Train a new model from scratch", ranks=[0]) - with ColoInitContext(device=init_dev): + with ctx: model = OPTForCausalLM(config) else: logger.info("Finetune a pre-trained model", ranks=[0]) - with ColoInitContext(device=init_dev): - model = OPTForCausalLM.from_pretrained(args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - local_files_only=False) + with ctx: + model = OPTForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + local_files_only=False, + ) # enable graident checkpointing model.gradient_checkpointing_enable() - PLACEMENT_POLICY = 'auto' - cai_version = colossalai.__version__ - logger.info(f'using Colossal-AI version {cai_version}') - if version.parse(cai_version) > version.parse("0.1.10"): + PLACEMENT_POLICY = "auto" + if version.parse(cai_version) >= version.parse("0.3.1"): + from colossalai.zero import GeminiDDP + + model = GeminiDDP(model, offload_optim_frac=1.0, pin_memory=True) + elif version.parse(cai_version) > version.parse("0.1.10"): try: from colossalai.nn.parallel import GeminiDDP except ImportError: @@ -421,16 +442,19 @@ def main(): model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): from colossalai.gemini import ChunkManager, GeminiManager + pg = ProcessGroup() chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) - chunk_manager = ChunkManager(chunk_size, - pg, - enable_distributed_storage=True, - init_device=GeminiManager.get_default_device(PLACEMENT_POLICY)) + chunk_manager = ChunkManager( + chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(PLACEMENT_POLICY), + ) gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager) model = ZeroDDP(model, gemini_manager) - logger.info(f'{model.__class__.__name__} has been created', ranks=[0]) + logger.info(f"{model.__class__.__name__} has been created", ranks=[0]) if not args.synthetic: # Preprocessing the datasets. @@ -456,12 +480,15 @@ def main(): if block_size > 1024: logger.warning( f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can change that default value by passing --block_size xxx.") + "Picking 1024 instead. You can change that default value by passing --block_size xxx." + ) block_size = 1024 else: if args.block_size > tokenizer.model_max_length: - logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" - f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.") + logger.warning( + f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) block_size = min(args.block_size, tokenizer.model_max_length) # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. @@ -475,8 +502,8 @@ def main(): total_length = (total_length // block_size) * block_size # Split by chunks of max_len. result = { - k: [t[i:i + block_size] for i in range(0, total_length, block_size) - ] for k, t in concatenated_examples.items() + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() return result @@ -506,19 +533,23 @@ def main(): # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # DataLoaders creation: - train_dataloader = get_dataloader(train_dataset, - shuffle=True, - add_sampler=True, - collate_fn=default_data_collator, - batch_size=args.per_device_train_batch_size) - eval_dataloader = DataLoader(eval_dataset, - collate_fn=default_data_collator, - batch_size=args.per_device_eval_batch_size) + train_dataloader = get_dataloader( + train_dataset, + shuffle=True, + add_sampler=True, + collate_fn=default_data_collator, + batch_size=args.per_device_train_batch_size, + ) + eval_dataloader = DataLoader( + eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size + ) else: - train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings, - config.vocab_size) - eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings, - config.vocab_size) + train_dataloader = DummyDataloader( + 30, args.per_device_train_batch_size, config.max_position_embeddings, config.vocab_size + ) + eval_dataloader = DummyDataloader( + 10, args.per_device_train_batch_size, config.max_position_embeddings, config.vocab_size + ) logger.info("Dataloaders have been created", ranks=[0]) # Optimizer @@ -536,7 +567,6 @@ def main(): ] optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate) - optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -551,6 +581,7 @@ def main(): num_warmup_steps=args.num_warmup_steps, num_training_steps=args.max_train_steps, ) + optimizer = GeminiOptimizer(optimizer, model, initial_scale=2**14) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -579,7 +610,6 @@ def main(): global_step = 0 for epoch in range(starting_epoch, args.num_train_epochs): - if completed_steps >= args.max_train_steps: break @@ -587,7 +617,7 @@ def main(): for step, batch in enumerate(train_dataloader): batch = {k: v.cuda() for k, v in batch.items()} outputs = model(use_cache=False, **batch) - loss = outputs['loss'] + loss = outputs["loss"] optimizer.backward(loss) if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: @@ -610,7 +640,7 @@ def main(): batch = {k: v.cuda() for k, v in batch.items()} outputs = model(**batch) - loss = outputs['loss'].unsqueeze(0) + loss = outputs["loss"].unsqueeze(0) losses.append(loss) losses = torch.cat(losses) @@ -626,7 +656,7 @@ def main(): if args.output_dir is not None: model_state = model.state_dict() if is_main_process: - torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + torch.save(model_state, args.output_dir + "/epoch_{}_model.pth".format(completed_steps)) dist.barrier() # load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) # model.load_state_dict(load_state, strict=False) diff --git a/examples/tutorial/opt/opt/test_ci.sh b/examples/tutorial/opt/opt/test_ci.sh index e505da1364de04c76f60177824466539af1af416..9cbc49c7b0017b9940db912e5b5180449da3896c 100755 --- a/examples/tutorial/opt/opt/test_ci.sh +++ b/examples/tutorial/opt/opt/test_ci.sh @@ -1,21 +1,21 @@ #!/bin/bash set -xue +echo "this test is outdated" +# pip install -r requirements.txt -pip install -r requirements.txt +# BS=4 +# MEMCAP=0 +# GPUNUM=4 +# MODLE="facebook/opt-125m" -BS=8 -MEMCAP=0 -GPUNUM=2 -MODLE="facebook/opt-125m" - -torchrun \ - --nproc_per_node ${GPUNUM} \ - --master_port 19198 \ - run_clm.py \ - -s \ - --output_dir $PWD \ - --mem_cap ${MEMCAP} \ - --model_name_or_path ${MODLE} \ - --per_device_train_batch_size ${BS} \ - --num_train_epochs 1 +# torchrun \ +# --nproc_per_node ${GPUNUM} \ +# --master_port 19198 \ +# run_clm.py \ +# -s \ +# --output_dir $PWD \ +# --mem_cap ${MEMCAP} \ +# --model_name_or_path ${MODLE} \ +# --per_device_train_batch_size ${BS} \ +# --num_train_epochs 1 diff --git a/examples/tutorial/sequence_parallel/config.py b/examples/tutorial/sequence_parallel/config.py index 6edf9cc2c7e5a22d836839bdcc40959ac75b0296..859f6e25e845d7f29131c92cd650c9ee8211212a 100644 --- a/examples/tutorial/sequence_parallel/config.py +++ b/examples/tutorial/sequence_parallel/config.py @@ -1,10 +1,10 @@ -from colossalai.amp import AMP_TYPE +from colossalai.legacy.amp import AMP_TYPE # hyper-parameters TRAIN_ITERS = 10 DECAY_ITERS = 4 WARMUP_FRACTION = 0.01 -GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU +GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU EVAL_ITERS = 10 EVAL_INTERVAL = 10 LR = 0.0001 @@ -28,8 +28,8 @@ SEED = 1234 NUM_MICRO_BATCHES = 4 # colossalai config -parallel = dict(pipeline=1, tensor=dict(size=2, mode='sequence')) +parallel = dict(pipeline=1, tensor=dict(size=2, mode="sequence")) fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True) -gradient_handler = [dict(type='SequenceParallelGradientHandler')] +gradient_handler = [dict(type="SequenceParallelGradientHandler")] diff --git a/examples/tutorial/sequence_parallel/data/__init__.py b/examples/tutorial/sequence_parallel/data/__init__.py index 1ef2d999389fe001b01342e66942c69455327efb..137f3cf0267b7813a44b83b7d6731cc058a91ebd 100644 --- a/examples/tutorial/sequence_parallel/data/__init__.py +++ b/examples/tutorial/sequence_parallel/data/__init__.py @@ -1,10 +1,12 @@ -from colossalai.context.parallel_context import ParallelContext -from colossalai.core import global_context as gpc +import torch + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.context.parallel_context import ParallelContext +from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.context import ParallelMode -from .datasets.data_samplers import build_pretraining_data_loader + from .datasets.builder import build_train_valid_test_datasets -import torch +from .datasets.data_samplers import build_pretraining_data_loader def cyclic_iter(iter): @@ -13,17 +15,13 @@ def cyclic_iter(iter): yield x -def build_train_valid_test_data_iterators(train_iters, - global_batch_size, - eval_interval, - eval_iters, - dataloader_type='single', - **kwargs - ): +def build_train_valid_test_data_iterators( + train_iters, global_batch_size, eval_interval, eval_iters, dataloader_type="single", **kwargs +): (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) logger = get_dist_logger() - logger.info('> building train, validation, and test datasets ...', ranks=[0]) + logger.info("> building train, validation, and test datasets ...", ranks=[0]) # Backward compatibility, assume fixed batch size. # if iteration > 0 and consumed_train_samples == 0: @@ -37,65 +35,61 @@ def build_train_valid_test_data_iterators(train_iters, # Data loader only on rank 0 of each model parallel group. if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # Number of train/valid/test samples. train_samples = train_iters * global_batch_size eval_iters_ = (train_iters // eval_interval + 1) * eval_iters test_iters = eval_iters - train_val_test_num_samples = [train_samples, - eval_iters_ * global_batch_size, - test_iters * global_batch_size] - logger.info(' > datasets target sizes (minimum size):') - logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0]) - logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0]) - logger.info(' test: {}'.format(train_val_test_num_samples[2]), ranks=[0]) + train_val_test_num_samples = [train_samples, eval_iters_ * global_batch_size, test_iters * global_batch_size] + logger.info(" > datasets target sizes (minimum size):") + logger.info(" train: {}".format(train_val_test_num_samples[0]), ranks=[0]) + logger.info(" validation: {}".format(train_val_test_num_samples[1]), ranks=[0]) + logger.info(" test: {}".format(train_val_test_num_samples[2]), ranks=[0]) # Build the datasets. train_ds, valid_ds, test_ds = build_train_valid_test_datasets( - train_valid_test_num_samples=train_val_test_num_samples, **kwargs) + train_valid_test_num_samples=train_val_test_num_samples, **kwargs + ) # Build dataloaders. dp_size = gpc.get_world_size(ParallelMode.DATA) train_dataloader = build_pretraining_data_loader( - train_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) + train_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size + ) valid_dataloader = build_pretraining_data_loader( - valid_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) - test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size//dp_size) + valid_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size + ) + test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size // dp_size) # Flags to know if we need to do training/validation/testing. do_train = train_dataloader is not None and train_iters > 0 do_valid = valid_dataloader is not None and eval_iters > 0 do_test = test_dataloader is not None and eval_iters > 0 # Need to broadcast num_tokens and num_type_tokens. - flags = torch.cuda.LongTensor( - [int(do_train), int(do_valid), int(do_test)]) + flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)]) else: flags = torch.cuda.LongTensor([0, 0, 0]) # Broadcast num tokens. - torch.distributed.broadcast(flags, - gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], - group=gpc.get_group(ParallelMode.TENSOR)) + torch.distributed.broadcast( + flags, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) + ) # Build iterators. dl_type = dataloader_type - assert dl_type in ['single', 'cyclic'] + assert dl_type in ["single", "cyclic"] if train_dataloader is not None: - train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(train_dataloader)) + train_data_iterator = iter(train_dataloader) if dl_type == "single" else iter(cyclic_iter(train_dataloader)) else: train_data_iterator = None if valid_dataloader is not None: - valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(valid_dataloader)) + valid_data_iterator = iter(valid_dataloader) if dl_type == "single" else iter(cyclic_iter(valid_dataloader)) else: valid_data_iterator = None if test_dataloader is not None: - test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(test_dataloader)) + test_data_iterator = iter(test_dataloader) if dl_type == "single" else iter(cyclic_iter(test_dataloader)) else: test_data_iterator = None diff --git a/examples/tutorial/sequence_parallel/data/bert_helper.py b/examples/tutorial/sequence_parallel/data/bert_helper.py index d092db3e7dd8d545253e3a36c6203ace3d0eec9d..471be19bb12357d0d7087d426a525b234cd6a42a 100644 --- a/examples/tutorial/sequence_parallel/data/bert_helper.py +++ b/examples/tutorial/sequence_parallel/data/bert_helper.py @@ -1,7 +1,8 @@ -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode import torch +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc + _MAX_DATA_DIM = 5 @@ -14,7 +15,7 @@ def _build_key_size_numel_dictionaries(keys, data): if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: offset = 0 for key in keys: - assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' + assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM" size = data[key].size() for i, s in enumerate(size): sizes[i + offset] = s @@ -22,8 +23,9 @@ def _build_key_size_numel_dictionaries(keys, data): # Move to GPU and broadcast. sizes_cuda = torch.cuda.LongTensor(sizes) - torch.distributed.broadcast(sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], - group=gpc.get_group(ParallelMode.TENSOR)) + torch.distributed.broadcast( + sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) + ) # Move back to cpu and unpack. sizes_cpu = sizes_cuda.cpu() @@ -60,24 +62,20 @@ def broadcast_data(keys, data, datatype): """ # Build (key, size) and (key, number of elements) dictionaries along # with the total number of elements on all ranks. - key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, - data) + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) # Pack on rank zero. if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: # Check that all keys have the same data type. # Flatten the data associated with the keys - flatten_data = torch.cat( - [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda() else: - flatten_data = torch.empty(total_numel, - device=torch.cuda.current_device(), - dtype=datatype) + flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype) # Broadcast - torch.distributed.broadcast(flatten_data, - gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], - group=gpc.get_group(ParallelMode.TENSOR)) + torch.distributed.broadcast( + flatten_data, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) + ) # Unpack output = {} @@ -95,7 +93,7 @@ def get_batch(data_iterator): """Build the batch.""" # Items and their type. - keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"] datatype = torch.int64 # Broadcast data. @@ -106,12 +104,12 @@ def get_batch(data_iterator): data_b = broadcast_data(keys, data, datatype) # Unpack. - tokens = data_b['text'].long() - types = data_b['types'].long() - sentence_order = data_b['is_random'].long() - loss_mask = data_b['loss_mask'].float() - lm_labels = data_b['labels'].long() - padding_mask = data_b['padding_mask'].long() + tokens = data_b["text"].long() + types = data_b["types"].long() + sentence_order = data_b["is_random"].long() + loss_mask = data_b["loss_mask"].float() + lm_labels = data_b["labels"].long() + padding_mask = data_b["padding_mask"].long() return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask @@ -120,7 +118,7 @@ def get_batch_for_sequence_parallel(data_iterator): """Build the batch.""" # Items and their type. - keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"] datatype = torch.int64 # Broadcast data. @@ -136,30 +134,28 @@ def get_batch_for_sequence_parallel(data_iterator): global_rank = torch.distributed.get_rank() local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR) local_rank = global_rank % local_world_size - seq_length = data_b['text'].size(1) + seq_length = data_b["text"].size(1) sub_seq_length = seq_length // local_world_size sub_seq_start = local_rank * sub_seq_length - sub_seq_end = (local_rank+1) * sub_seq_length + sub_seq_end = (local_rank + 1) * sub_seq_length # # # Unpack. - tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long() - types = data_b['types'][:, sub_seq_start:sub_seq_end].long() - sentence_order = data_b['is_random'].long() - loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float() - lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long() - padding_mask = data_b['padding_mask'].long() + tokens = data_b["text"][:, sub_seq_start:sub_seq_end].long() + types = data_b["types"][:, sub_seq_start:sub_seq_end].long() + sentence_order = data_b["is_random"].long() + loss_mask = data_b["loss_mask"][:, sub_seq_start:sub_seq_end].float() + lm_labels = data_b["labels"][:, sub_seq_start:sub_seq_end].long() + padding_mask = data_b["padding_mask"].long() return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask class SequenceParallelDataIterator: - def __init__(self, data_iter): self.data_iter = data_iter - def __iter__(self): return self.data_iter def __next__(self): - return get_batch_for_sequence_parallel(self.data_iter) \ No newline at end of file + return get_batch_for_sequence_parallel(self.data_iter) diff --git a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py index d6388bd9f8e427575f345c20f38aa276a97be049..afab202e0927112890e23965815d23932d95c793 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py @@ -21,8 +21,8 @@ import numpy as np import torch from torch.utils.data import Dataset -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger from ..tokenizer import get_tokenizer @@ -41,10 +41,19 @@ except: class BertDataset(Dataset): - - def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length, - short_seq_prob, seed, binary_head): - + def __init__( + self, + name, + indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + masked_lm_prob, + max_seq_length, + short_seq_prob, + seed, + binary_head, + ): # Params to store. self.name = name self.seed = seed @@ -61,11 +70,12 @@ class BertDataset(Dataset): data_prefix, num_epochs, max_num_samples, - self.max_seq_length - 3, # account for added tokens, + self.max_seq_length - 3, # account for added tokens, short_seq_prob, self.seed, self.name, - self.binary_head) + self.binary_head, + ) # Vocab stuff. tokenizer = get_tokenizer() @@ -89,7 +99,7 @@ class BertDataset(Dataset): return build_training_sample( sample, seq_length, - self.max_seq_length, # needed for padding + self.max_seq_length, # needed for padding self.vocab_id_list, self.vocab_id_to_token_dict, self.cls_id, @@ -98,37 +108,39 @@ class BertDataset(Dataset): self.pad_id, self.masked_lm_prob, np_rng, - self.binary_head) + self.binary_head, + ) -def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, - seed, name, binary_head): +def get_samples_mapping_( + indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head +): logger = get_dist_logger() if not num_epochs: if not max_num_samples: - raise ValueError("Need to specify either max_num_samples " - "or num_epochs") + raise ValueError("Need to specify either max_num_samples " "or num_epochs") num_epochs = np.iinfo(np.int32).max - 1 if not max_num_samples: max_num_samples = np.iinfo(np.int64).max - 1 # Filename of the index mapping indexmap_filename = data_prefix - indexmap_filename += '_{}_indexmap'.format(name) + indexmap_filename += "_{}_indexmap".format(name) if num_epochs != (np.iinfo(np.int32).max - 1): - indexmap_filename += '_{}ep'.format(num_epochs) + indexmap_filename += "_{}ep".format(num_epochs) if max_num_samples != (np.iinfo(np.int64).max - 1): - indexmap_filename += '_{}mns'.format(max_num_samples) - indexmap_filename += '_{}msl'.format(max_seq_length) - indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) - indexmap_filename += '_{}s'.format(seed) - indexmap_filename += '.npy' + indexmap_filename += "_{}mns".format(max_num_samples) + indexmap_filename += "_{}msl".format(max_seq_length) + indexmap_filename += "_{:0.2f}ssp".format(short_seq_prob) + indexmap_filename += "_{}s".format(seed) + indexmap_filename += ".npy" # Build the indexed mapping if not exist. - if torch.distributed.get_rank() == 0 and \ - not os.path.isfile(indexmap_filename): - print(' > WARNING: could not find index map file {}, building ' - 'the indices on rank 0 ...'.format(indexmap_filename)) + if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename): + print( + " > WARNING: could not find index map file {}, building " + "the indices on rank 0 ...".format(indexmap_filename) + ) # Make sure the types match the helpers input types. assert indexed_dataset.doc_idx.dtype == np.int64 @@ -137,18 +149,27 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl # Build samples mapping verbose = torch.distributed.get_rank() == 0 start_time = time.time() - logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0]) + logger.info("\n > building samples index mapping for {} ...".format(name), ranks=[0]) # First compile and then import. - samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs, - max_num_samples, max_seq_length, short_seq_prob, seed, verbose, - 2 if binary_head else 1) - logger.info('\n > done building samples index maping', ranks=[0]) + samples_mapping = helpers.build_mapping( + indexed_dataset.doc_idx, + indexed_dataset.sizes, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + verbose, + 2 if binary_head else 1, + ) + logger.info("\n > done building samples index maping", ranks=[0]) np.save(indexmap_filename, samples_mapping, allow_pickle=True) - logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0]) + logger.info("\n > saved the index mapping in {}".format(indexmap_filename), ranks=[0]) # Make sure all the ranks have built the mapping - logger.info('\n > elapsed time to build and save samples mapping ' - '(seconds): {:4f}'.format(time.time() - start_time), - ranks=[0]) + logger.info( + "\n > elapsed time to build and save samples mapping " "(seconds): {:4f}".format(time.time() - start_time), + ranks=[0], + ) # This should be a barrier but nccl barrier assumes # device_index=rank which is not the case for model # parallel case @@ -156,22 +177,38 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA)) if gpc.is_initialized(ParallelMode.PIPELINE): torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE)) - assert counts[0].item() == (torch.distributed.get_world_size() // - torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE))) + assert counts[0].item() == ( + torch.distributed.get_world_size() + // torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE)) + ) # Load indexed dataset. start_time = time.time() - samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') - logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) + - '\n loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) + - '\n total number of samples: {}'.format(samples_mapping.shape[0]), - ranks=[0]) + samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode="r") + logger.info( + "\n > loading indexed mapping from {}".format(indexmap_filename) + + "\n loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + + "\n total number of samples: {}".format(samples_mapping.shape[0]), + ranks=[0], + ) return samples_mapping -def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id, - sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head): +def build_training_sample( + sample, + target_seq_length, + max_seq_length, + vocab_id_list, + vocab_id_to_token_dict, + cls_id, + sep_id, + mask_id, + pad_id, + masked_lm_prob, + np_rng, + binary_head, +): """Build training sample. Arguments: @@ -215,22 +252,30 @@ def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_li # Masking. max_predictions_per_seq = masked_lm_prob * max_num_tokens - (tokens, masked_positions, masked_labels, - _) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id, - mask_id, max_predictions_per_seq, np_rng) + (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions( + tokens, + vocab_id_list, + vocab_id_to_token_dict, + masked_lm_prob, + cls_id, + sep_id, + mask_id, + max_predictions_per_seq, + np_rng, + ) # Padding. - tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ - = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length) + tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy( + tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length + ) train_sample = { - 'text': tokens_np, - 'types': tokentypes_np, - 'labels': labels_np, - 'is_random': int(is_next_random), - 'loss_mask': loss_mask_np, - 'padding_mask': padding_mask_np, - 'truncated': int(truncated) + "text": tokens_np, + "types": tokentypes_np, + "labels": labels_np, + "is_random": int(is_next_random), + "loss_mask": loss_mask_np, + "padding_mask": padding_mask_np, + "truncated": int(truncated), } return train_sample diff --git a/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py index 6a06c869d8c808af92ed3a6993a57cca9ca78a8b..1fa9c85fce0a76a3a14990f7f34251d36c6b2b17 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py @@ -22,9 +22,7 @@ import torch class BlendableDataset(torch.utils.data.Dataset): - def __init__(self, datasets, weights): - self.datasets = datasets num_datasets = len(datasets) assert num_datasets == len(weights) @@ -46,12 +44,16 @@ class BlendableDataset(torch.utils.data.Dataset): self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) from . import helpers - helpers.build_blending_indices(self.dataset_index, - self.dataset_sample_index, - weights, num_datasets, self.size, - torch.distributed.get_rank() == 0) - print('> elapsed time for building blendable dataset indices: ' - '{:.2f} (sec)'.format(time.time() - start_time)) + + helpers.build_blending_indices( + self.dataset_index, + self.dataset_sample_index, + weights, + num_datasets, + self.size, + torch.distributed.get_rank() == 0, + ) + print("> elapsed time for building blendable dataset indices: " "{:.2f} (sec)".format(time.time() - start_time)) def __len__(self): return self.size diff --git a/examples/tutorial/sequence_parallel/data/datasets/builder.py b/examples/tutorial/sequence_parallel/data/datasets/builder.py index 6106f833b4628a0763cae82d0a7b6073f5c0548d..edf4c3d70cbf0de72e2b6e5943f8e15ab9f2989d 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/builder.py +++ b/examples/tutorial/sequence_parallel/data/datasets/builder.py @@ -1,29 +1,34 @@ +from colossalai.logging import get_dist_logger + +from .bert_dataset import BertDataset from .blendable_dataset import BlendableDataset from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_ -from .bert_dataset import BertDataset -from colossalai.logging import get_dist_logger -DSET_TYPE_BERT = 'standard_bert' -DSET_TYPE_ICT = 'ict' -DSET_TYPE_T5 = 't5' +DSET_TYPE_BERT = "standard_bert" +DSET_TYPE_ICT = "ict" +DSET_TYPE_T5 = "t5" DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] -def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): - +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): if dataset_type not in DSET_TYPES: raise ValueError("Invalid dataset_type: ", dataset_type) # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) # Get start and end indices of train/valid/train into doc-idx # Note that doc-idx is designed to be num-docs + 1 so we can @@ -34,22 +39,25 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, logger = get_dist_logger() # Print stats about the splits. - logger.info('\n > dataset split:', ranks=[0]) + logger.info("\n > dataset split:", ranks=[0]) def print_split_stats(name, index): start_index = indexed_dataset.doc_idx[splits[index]] end_index = indexed_dataset.doc_idx[splits[index + 1]] - logger.info('\n {}:'.format(name) + - '\n document indices in [{}, {}) total of {} documents'.format( - splits[index], splits[index + 1], - splits[index + 1] - splits[index]) + - '\n sentence indices in [{}, {}) total of {} sentences'.format( - start_index, end_index, - end_index - start_index), - ranks=[0]) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + logger.info( + "\n {}:".format(name) + + "\n document indices in [{}, {}) total of {} documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + + "\n sentence indices in [{}, {}) total of {} sentences".format( + start_index, end_index, end_index - start_index + ), + ranks=[0], + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) def build_dataset(index, name): dataset = None @@ -80,44 +88,53 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, masked_lm_prob=masked_lm_prob, short_seq_prob=short_seq_prob, binary_head=binary_head, - **kwargs + **kwargs, ) # Set the original pointer so dataset remains the main dataset. indexed_dataset.set_doc_idx(doc_idx_ptr) # Checks. assert indexed_dataset.doc_idx[0] == 0 - assert indexed_dataset.doc_idx.shape[0] == \ - (total_num_of_documents + 1) + assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) return dataset - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") return (train_dataset, valid_dataset, test_dataset) -def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): - +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): if len(data_prefix) == 1: - return _build_train_valid_test_datasets(data_prefix[0], - data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, - skip_warmup, - binary_head, - dataset_type=dataset_type) + return _build_train_valid_test_datasets( + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) # Blending dataset. # Parse the values. - output = get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) prefixes, weights, datasets_train_valid_test_num_samples = output # Build individual datasets. @@ -126,10 +143,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, test_datasets = [] for i in range(len(prefixes)): train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( - prefixes[i], data_impl, splits_string, + prefixes[i], + data_impl, + splits_string, datasets_train_valid_test_num_samples[i], - max_seq_length, masked_lm_prob, short_seq_prob, - seed, skip_warmup, binary_head, dataset_type=dataset_type) + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) if train_ds: train_datasets.append(train_ds) if valid_ds: @@ -148,5 +173,4 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, if test_datasets: blending_test_dataset = BlendableDataset(test_datasets, weights) - return (blending_train_dataset, blending_valid_dataset, - blending_test_dataset) + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) diff --git a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py index cf547ad9755815dfc0e9de449d147fe928a82948..8ba598529ebce93689c1a308ae3735d9638588b3 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py +++ b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py @@ -14,67 +14,67 @@ # limitations under the License. """Dataloaders.""" + import torch -import random -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc -def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0): +def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type="single", num_workers=0): """Build dataloader given an input dataset.""" if dataset is None: return None # Megatron sampler - if dataloader_type == 'single': - batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=micro_batch_size, - data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), - data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) - elif dataloader_type == 'cyclic': - batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=micro_batch_size, - data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), - data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) + if dataloader_type == "single": + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA), + ) + elif dataloader_type == "cyclic": + batch_sampler = MegatronPretrainingRandomSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA), + ) else: - raise Exception('{} dataloader type is not supported.'.format(dataloader_type)) + raise Exception("{} dataloader type is not supported.".format(dataloader_type)) # Torch dataloader. return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True) class MegatronPretrainingSampler: - - def __init__(self, - total_samples, - consumed_samples, - micro_batch_size, - data_parallel_rank, - data_parallel_size, - drop_last=True): + def __init__( + self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True + ): # Keep a copy of input params for later use. self.total_samples = total_samples self.consumed_samples = consumed_samples self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank - self.micro_batch_times_data_parallel_size = \ - self.micro_batch_size * data_parallel_size + self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size self.drop_last = drop_last # Sanity checks. - assert self.total_samples > 0, \ - 'no sample to consume: {}'.format(self.total_samples) - assert self.consumed_samples < self.total_samples, \ - 'no samples left to consume: {}, {}'.format(self.consumed_samples, - self.total_samples) + assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples) + assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format( + self.consumed_samples, self.total_samples + ) assert self.micro_batch_size > 0 assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, \ - 'data_parallel_rank should be smaller than data size: {}, ' \ - '{}'.format(self.data_parallel_rank, data_parallel_size) + assert ( + self.data_parallel_rank < data_parallel_size + ), "data_parallel_rank should be smaller than data size: {}, " "{}".format( + self.data_parallel_rank, data_parallel_size + ) def __len__(self): return self.total_samples @@ -101,7 +101,6 @@ class MegatronPretrainingSampler: class MegatronPretrainingRandomSampler: - def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): # Keep a copy of input params for later use. self.total_samples = total_samples @@ -109,19 +108,18 @@ class MegatronPretrainingRandomSampler: self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank self.data_parallel_size = data_parallel_size - self.micro_batch_times_data_parallel_size = \ - self.micro_batch_size * data_parallel_size - self.last_batch_size = \ - self.total_samples % self.micro_batch_times_data_parallel_size + self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size + self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size # Sanity checks. - assert self.total_samples > 0, \ - 'no sample to consume: {}'.format(self.total_samples) + assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples) assert self.micro_batch_size > 0 assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, \ - 'data_parallel_rank should be smaller than data size: {}, ' \ - '{}'.format(self.data_parallel_rank, data_parallel_size) + assert ( + self.data_parallel_rank < data_parallel_size + ), "data_parallel_rank should be smaller than data size: {}, " "{}".format( + self.data_parallel_rank, data_parallel_size + ) def __len__(self): return self.total_samples @@ -133,8 +131,7 @@ class MegatronPretrainingRandomSampler: assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 # data sharding and random sampling - bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ - * self.micro_batch_size + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size bucket_offset = current_epoch_samples // self.data_parallel_size start_idx = self.data_parallel_rank * bucket_size diff --git a/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py index cf4e4763fc107ca9da3f063037bf08f4efc67cdd..3e197ff96c0c0f718cc67effb9f27366b9b67c65 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py +++ b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py @@ -18,32 +18,33 @@ # https://github.com/google-research/albert/blob/master/create_pretraining_data.py # with some modifications. +import collections import math import time -import collections -from colossalai.logging import get_dist_logger + import numpy as np + +from colossalai.logging import get_dist_logger + from .blendable_dataset import BlendableDataset from .indexed_dataset import make_dataset as make_indexed_dataset -DSET_TYPE_STD = 'standard_bert' -DSET_TYPE_ICT = 'ict' +DSET_TYPE_STD = "standard_bert" +DSET_TYPE_ICT = "ict" DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD] -def get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples): - +def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples): # The data prefix should be in the format of: # weight-1, data-prefix-1, weight-2, data-prefix-2, .. assert len(data_prefix) % 2 == 0 num_datasets = len(data_prefix) // 2 - weights = [0]*num_datasets - prefixes = [0]*num_datasets + weights = [0] * num_datasets + prefixes = [0] * num_datasets for i in range(num_datasets): - weights[i] = float(data_prefix[2*i]) - prefixes[i] = (data_prefix[2*i+1]).strip() + weights[i] = float(data_prefix[2 * i]) + prefixes[i] = (data_prefix[2 * i + 1]).strip() # Normalize weights weight_sum = 0.0 for weight in weights: @@ -57,8 +58,8 @@ def get_datasets_weights_and_num_samples(data_prefix, datasets_train_valid_test_num_samples = [] for weight in weights: datasets_train_valid_test_num_samples.append( - [int(math.ceil(val * weight * 1.005)) - for val in train_valid_test_num_samples]) + [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples] + ) return prefixes, weights, datasets_train_valid_test_num_samples @@ -68,11 +69,13 @@ def compile_helper(): is invoked on a single process.""" import os import subprocess + path = os.path.abspath(os.path.dirname(__file__)) - ret = subprocess.run(['make', '-C', path]) + ret = subprocess.run(["make", "-C", path]) if ret.returncode != 0: print("Making C++ dataset helpers module failed, exiting.") import sys + sys.exit(1) @@ -82,7 +85,7 @@ def get_a_and_b_segments(sample, np_rng): # Number of sentences in the sample. n_sentences = len(sample) # Make sure we always have two sentences. - assert n_sentences > 1, 'make sure each sample has at least two sentences.' + assert n_sentences > 1, "make sure each sample has at least two sentences." # First part: # `a_end` is how many sentences go into the `A`. @@ -110,7 +113,7 @@ def get_a_and_b_segments(sample, np_rng): def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): """Truncates a pair of sequences to a maximum sequence length.""" - #print(len_a, len_b, max_num_tokens) + # print(len_a, len_b, max_num_tokens) assert len_a > 0 if len_a + len_b <= max_num_tokens: return False @@ -155,8 +158,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): return tokens, tokentypes -MaskedLmInstance = collections.namedtuple("MaskedLmInstance", - ["index", "label"]) +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) def is_start_piece(piece): @@ -168,16 +170,21 @@ def is_start_piece(piece): return not piece.startswith("##") -def create_masked_lm_predictions(tokens, - vocab_id_list, vocab_id_to_token_dict, - masked_lm_prob, - cls_id, sep_id, mask_id, - max_predictions_per_seq, - np_rng, - max_ngrams=3, - do_whole_word_mask=True, - favor_longer_ngram=False, - do_permutation=False): +def create_masked_lm_predictions( + tokens, + vocab_id_list, + vocab_id_to_token_dict, + masked_lm_prob, + cls_id, + sep_id, + mask_id, + max_predictions_per_seq, + np_rng, + max_ngrams=3, + do_whole_word_mask=True, + favor_longer_ngram=False, + do_permutation=False, +): """Creates the predictions for the masked LM objective. Note: Tokens here are vocab ids and not text tokens.""" @@ -187,7 +194,7 @@ def create_masked_lm_predictions(tokens, # on-the-fly whole word masking is possible. token_boundary = [0] * len(tokens) - for (i, token) in enumerate(tokens): + for i, token in enumerate(tokens): if token == cls_id or token == sep_id: token_boundary[i] = 1 continue @@ -197,8 +204,7 @@ def create_masked_lm_predictions(tokens, # Note that Whole Word Masking does *not* change the training code # at all -- we still predict each WordPiece independently, softmaxed # over the entire vocabulary. - if (do_whole_word_mask and len(cand_indexes) >= 1 and - not is_start_piece(vocab_id_to_token_dict[token])): + if do_whole_word_mask and len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]): cand_indexes[-1].append(i) else: cand_indexes.append([i]) @@ -211,16 +217,14 @@ def create_masked_lm_predictions(tokens, masked_lm_labels = [] if masked_lm_prob == 0: - return (output_tokens, masked_lm_positions, - masked_lm_labels, token_boundary) + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) - num_to_predict = min(max_predictions_per_seq, - max(1, int(round(len(tokens) * masked_lm_prob)))) + num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))) # Note(mingdachen): # By default, we set the probabilities to favor shorter ngram sequences. ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) - pvals = 1. / np.arange(1, max_ngrams + 1) + pvals = 1.0 / np.arange(1, max_ngrams + 1) pvals /= pvals.sum(keepdims=True) if favor_longer_ngram: @@ -230,7 +234,7 @@ def create_masked_lm_predictions(tokens, for idx in range(len(cand_indexes)): ngram_index = [] for n in ngrams: - ngram_index.append(cand_indexes[idx:idx + n]) + ngram_index.append(cand_indexes[idx : idx + n]) ngram_indexes.append(ngram_index) np_rng.shuffle(ngram_indexes) @@ -249,9 +253,10 @@ def create_masked_lm_predictions(tokens, if index in covered_indexes: continue - n = np_rng.choice(ngrams[:len(cand_index_set)], - p=pvals[:len(cand_index_set)] / - pvals[:len(cand_index_set)].sum(keepdims=True)) + n = np_rng.choice( + ngrams[: len(cand_index_set)], + p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True), + ) index_set = sum(cand_index_set[n - 1], []) n -= 1 # Note(mingdachen): @@ -309,9 +314,10 @@ def create_masked_lm_predictions(tokens, if index in covered_indexes or index in select_indexes: continue - n = np.random.choice(ngrams[:len(cand_index_set)], - p=pvals[:len(cand_index_set)] / - pvals[:len(cand_index_set)].sum(keepdims=True)) + n = np.random.choice( + ngrams[: len(cand_index_set)], + p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True), + ) index_set = sum(cand_index_set[n - 1], []) n -= 1 @@ -353,8 +359,7 @@ def create_masked_lm_predictions(tokens, return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) -def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length): +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length): """Pad sequences and convert them to numpy.""" # Some checks. @@ -370,8 +375,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) # Padding mask. - padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, - dtype=np.int64) + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64) # Lables and loss mask. labels = [-1] * max_seq_length @@ -386,26 +390,36 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np -def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): - +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): if len(data_prefix) == 1: - return _build_train_valid_test_datasets(data_prefix[0], - data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, - skip_warmup, - binary_head, - dataset_type=dataset_type) + return _build_train_valid_test_datasets( + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) # Blending dataset. # Parse the values. - output = get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) prefixes, weights, datasets_train_valid_test_num_samples = output # Build individual datasets. @@ -414,10 +428,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, test_datasets = [] for i in range(len(prefixes)): train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( - prefixes[i], data_impl, splits_string, + prefixes[i], + data_impl, + splits_string, datasets_train_valid_test_num_samples[i], - max_seq_length, masked_lm_prob, short_seq_prob, - seed, skip_warmup, binary_head, dataset_type=dataset_type) + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) if train_ds: train_datasets.append(train_ds) if valid_ds: @@ -436,31 +458,33 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, if test_datasets: blending_test_dataset = BlendableDataset(test_datasets, weights) - return (blending_train_dataset, blending_valid_dataset, - blending_test_dataset) - - -def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) + + +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): logger = get_dist_logger() if dataset_type not in DSET_TYPES: raise ValueError("Invalid dataset_type: ", dataset_type) # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) if dataset_type == DSET_TYPE_ICT: args = get_args() - title_dataset = get_indexed_dataset_(args.titles_data_path, - data_impl, - skip_warmup) + title_dataset = get_indexed_dataset_(args.titles_data_path, data_impl, skip_warmup) # Get start and end indices of train/valid/train into doc-idx # Note that doc-idx is designed to be num-docs + 1 so we can @@ -469,27 +493,29 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, splits = get_train_valid_test_split_(splits_string, total_num_of_documents) # Print stats about the splits. - logger.info('\n > dataset split:') + logger.info("\n > dataset split:") def print_split_stats(name, index): start_index = indexed_dataset.doc_idx[splits[index]] end_index = indexed_dataset.doc_idx[splits[index + 1]] - logger.info('\n {}:'.format(name) + - '\n document indices in [{}, {}) total of {} documents'.format( - splits[index], - splits[index + 1], - splits[index + 1] - splits[index]) + - '\n sentence indices in [{}, {}) total of {} sentences'.format( - start_index, - end_index, - end_index - start_index), - ranks=[0]) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + logger.info( + "\n {}:".format(name) + + "\n document indices in [{}, {}) total of {} documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + + "\n sentence indices in [{}, {}) total of {} sentences".format( + start_index, end_index, end_index - start_index + ), + ranks=[0], + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) def build_dataset(index, name): from .bert_dataset import BertDataset + dataset = None if splits[index + 1] > splits[index]: # Get the pointer to the original doc-idx so we can set it later. @@ -508,7 +534,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, max_num_samples=train_valid_test_num_samples[index], max_seq_length=max_seq_length, seed=seed, - binary_head=binary_head + binary_head=binary_head, ) if dataset_type == DSET_TYPE_ICT: @@ -518,27 +544,26 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, title_dataset=title_dataset, query_in_block_prob=args.query_in_block_prob, use_one_sent_docs=args.use_one_sent_docs, - **kwargs + **kwargs, ) else: dataset = BertDataset( indexed_dataset=indexed_dataset, masked_lm_prob=masked_lm_prob, short_seq_prob=short_seq_prob, - **kwargs + **kwargs, ) # Set the original pointer so dataset remains the main dataset. indexed_dataset.set_doc_idx(doc_idx_ptr) # Checks. assert indexed_dataset.doc_idx[0] == 0 - assert indexed_dataset.doc_idx.shape[0] == \ - (total_num_of_documents + 1) + assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) return dataset - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") return (train_dataset, valid_dataset, test_dataset) @@ -546,44 +571,41 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): logger = get_dist_logger() start_time = time.time() - indexed_dataset = make_indexed_dataset(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] - logger.info('\n > building dataset index ...', ranks=[0]) - logger.info('\n > finished creating indexed dataset in {:4f} ' - 'seconds'.format(time.time() - start_time), ranks=[0]) - logger.info('\n > indexed dataset stats:' + - '\n number of documents: {}'.format( - indexed_dataset.doc_idx.shape[0] - 1) + - '\n number of sentences: {}'.format( - indexed_dataset.sizes.shape[0]), - ranks=[0] - ) + logger.info("\n > building dataset index ...", ranks=[0]) + logger.info( + "\n > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time), ranks=[0] + ) + logger.info( + "\n > indexed dataset stats:" + + "\n number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1) + + "\n number of sentences: {}".format(indexed_dataset.sizes.shape[0]), + ranks=[0], + ) return indexed_dataset def get_train_valid_test_split_(splits_string, size): - """ Get dataset splits from comma or '/' separated string list.""" + """Get dataset splits from comma or '/' separated string list.""" splits = [] - if splits_string.find(',') != -1: - splits = [float(s) for s in splits_string.split(',')] - elif splits_string.find('/') != -1: - splits = [float(s) for s in splits_string.split('/')] + if splits_string.find(",") != -1: + splits = [float(s) for s in splits_string.split(",")] + elif splits_string.find("/") != -1: + splits = [float(s) for s in splits_string.split("/")] else: splits = [float(splits_string)] while len(splits) < 3: - splits.append(0.) + splits.append(0.0) splits = splits[:3] splits_sum = sum(splits) assert splits_sum > 0.0 splits = [split / splits_sum for split in splits] splits_index = [0] for index, split in enumerate(splits): - splits_index.append(splits_index[index] + - int(round(split * float(size)))) + splits_index.append(splits_index[index] + int(round(split * float(size)))) diff = splits_index[-1] - size for index in range(1, len(splits_index)): splits_index[index] -= diff diff --git a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp index e45926a976961eb5094658ba478cb697a88c8000..52977e63181f20c87353e77b37349da76ce4504a 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp +++ b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp @@ -15,29 +15,28 @@ limitations under the License. */ - /* Helper methods for fast index mapping builds */ +#include +#include +#include + #include #include #include -#include -#include -#include -#include #include +#include namespace py = pybind11; using namespace std; const int32_t LONG_SENTENCE_LEN = 512; - void build_blending_indices(py::array_t& dataset_index, - py::array_t& dataset_sample_index, - const py::array_t& weights, - const int32_t num_datasets, - const int64_t size, const bool verbose) { + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, const int64_t size, + const bool verbose) { /* Given multiple datasets and a weighting array, build samples such that it follows those wieghts.*/ @@ -52,24 +51,23 @@ void build_blending_indices(py::array_t& dataset_index, // Initialize buffer for number of samples used for each dataset. int64_t current_samples[num_datasets]; - for(int64_t i = 0; i < num_datasets; ++i) { + for (int64_t i = 0; i < num_datasets; ++i) { current_samples[i] = 0; } // For each sample: - for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { - + for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { // Determine where the max error in sampling is happening. auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); int64_t max_error_index = 0; double max_error = weights_ptr[0] * sample_idx_double - - static_cast(current_samples[0]); + static_cast(current_samples[0]); for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { double error = weights_ptr[dataset_idx] * sample_idx_double - - static_cast(current_samples[dataset_idx]); + static_cast(current_samples[dataset_idx]); if (error > max_error) { - max_error = error; - max_error_index = dataset_idx; + max_error = error; + max_error_index = dataset_idx; } } @@ -79,7 +77,6 @@ void build_blending_indices(py::array_t& dataset_index, // Update the total samples. current_samples[max_error_index] += 1; - } // print info @@ -87,631 +84,607 @@ void build_blending_indices(py::array_t& dataset_index, std::cout << " > sample ratios:" << std::endl; for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { auto ratio = static_cast(current_samples[dataset_idx]) / - static_cast(size); - std::cout << " dataset " << dataset_idx << ", input: " << - weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + static_cast(size); + std::cout << " dataset " << dataset_idx + << ", input: " << weights_ptr[dataset_idx] + << ", achieved: " << ratio << std::endl; } } - } - py::array build_sample_idx(const py::array_t& sizes_, - const py::array_t& doc_idx_, - const int32_t seq_length, - const int32_t num_epochs, - const int64_t tokens_per_epoch) { - /* Sample index (sample_idx) is used for gpt2 like dataset for which - the documents are flattened and the samples are built based on this - 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] - where [..., 0] contains the index into `doc_idx` and [..., 1] is the - starting offset in that document.*/ - - // Consistency checks. - assert(seq_length > 1); - assert(num_epochs > 0); - assert(tokens_per_epoch > 1); - - // Remove bound checks. - auto sizes = sizes_.unchecked<1>(); - auto doc_idx = doc_idx_.unchecked<1>(); - - // Mapping and it's length (1D). - int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; - int32_t* sample_idx = new int32_t[2*(num_samples+1)]; - - cout << " using:" << endl << std::flush; - cout << " number of documents: " << - doc_idx_.shape(0) / num_epochs << endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " sequence length: " << seq_length << - endl << std::flush; - cout << " total number of samples: " << num_samples << - endl << std::flush; - - // Index into sample_idx. - int64_t sample_index = 0; - // Index into doc_idx. - int64_t doc_idx_index = 0; - // Begining offset for each document. - int32_t doc_offset = 0; - // Start with first document and no offset. + const py::array_t& doc_idx_, + const int32_t seq_length, const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2 * (num_samples + 1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << doc_idx_.shape(0) / num_epochs + << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " sequence length: " << seq_length << endl + << std::flush; + cout << " total number of samples: " << num_samples << endl + << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. sample_idx[2 * sample_index] = doc_idx_index; sample_idx[2 * sample_index + 1] = doc_offset; ++sample_index; + } - while (sample_index <= num_samples) { - // Start with a fresh sequence. - int32_t remaining_seq_length = seq_length + 1; - while (remaining_seq_length != 0) { - // Get the document length. - auto doc_id = doc_idx[doc_idx_index]; - auto doc_length = sizes[doc_id] - doc_offset; - // And add it to the current sequence. - remaining_seq_length -= doc_length; - // If we have more than a full sequence, adjust offset and set - // remaining length to zero so we return from the while loop. - // Note that -1 here is for the same reason we have -1 in - // `_num_epochs` calculations. - if (remaining_seq_length <= 0) { - doc_offset += (remaining_seq_length + doc_length - 1); - remaining_seq_length = 0; - } else { - // Otherwise, start from the begining of the next document. - ++doc_idx_index; - doc_offset = 0; - } - } - // Record the sequence. - sample_idx[2 * sample_index] = doc_idx_index; - sample_idx[2 * sample_index + 1] = doc_offset; - ++sample_index; - } - - // Method to deallocate memory. - py::capsule free_when_done(sample_idx, [](void *mem_) { - int32_t *mem = reinterpret_cast(mem_); - delete[] mem; - }); - - // Return the numpy array. - const auto byte_size = sizeof(int32_t); - return py::array(std::vector{num_samples+1, 2}, // shape - {2*byte_size, byte_size}, // C-style contiguous strides - sample_idx, // the data pointer - free_when_done); // numpy array references - + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void* mem_) { + int32_t* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples + 1, 2}, // shape + {2 * byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references } - inline int32_t get_target_sample_len(const int32_t short_seq_ratio, - const int32_t max_length, - std::mt19937& rand32_gen) { - /* Training sample length. */ - if (short_seq_ratio == 0) { - return max_length; - } - const auto random_number = rand32_gen(); - if ((random_number % short_seq_ratio) == 0) { - return 2 + random_number % (max_length - 1); - } + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; } - -template +template py::array build_mapping_impl(const py::array_t& docs_, const py::array_t& sizes_, const int32_t num_epochs, const uint64_t max_num_samples, const int32_t max_seq_length, - const double short_seq_prob, - const int32_t seed, - const bool verbose, - const int32_t min_num_sent) { - /* Build a mapping of (start-index, end-index, sequence-length) where - start and end index are the indices of the sentences in the sample - and sequence-length is the target sequence length. - */ - - // Consistency checks. - assert(num_epochs > 0); - assert(max_seq_length > 1); - assert(short_seq_prob >= 0.0); - assert(short_seq_prob <= 1.0); - assert(seed > 0); - - // Remove bound checks. - auto docs = docs_.unchecked<1>(); - auto sizes = sizes_.unchecked<1>(); - - // For efficiency, convert probability to ratio. Note: rand() generates int. - int32_t short_seq_ratio = 0; - if (short_seq_prob > 0) { - short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); - } + const double short_seq_prob, const int32_t seed, + const bool verbose, const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } - if (verbose) { - const auto sent_start_index = docs[0]; - const auto sent_end_index = docs[docs_.shape(0) - 1]; - const auto num_sentences = sent_end_index - sent_start_index; - cout << " using:" << endl << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << - endl << std::flush; - cout << " sentences range: [" << sent_start_index << - ", " << sent_end_index << ")" << endl << std::flush; - cout << " total number of sentences: " << num_sentences << - endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " maximum number of samples: " << max_num_samples << - endl << std::flush; - cout << " maximum sequence length: " << max_seq_length << - endl << std::flush; - cout << " short sequence probability: " << short_seq_prob << - endl << std::flush; - cout << " short sequence ration (1/prob): " << short_seq_ratio << - endl << std::flush; - cout << " seed: " << seed << endl << - std::flush; - } + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 + << endl + << std::flush; + cout << " sentences range: [" << sent_start_index << ", " + << sent_end_index << ")" << endl + << std::flush; + cout << " total number of sentences: " << num_sentences << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " maximum number of samples: " << max_num_samples << endl + << std::flush; + cout << " maximum sequence length: " << max_seq_length << endl + << std::flush; + cout << " short sequence probability: " << short_seq_prob << endl + << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << endl + << std::flush; + cout << " seed: " << seed << endl + << std::flush; + } - // Mapping and it's length (1D). - int64_t num_samples = -1; - DocIdx* maps = NULL; - - // Perform two iterations, in the first iteration get the size - // and allocate memory and in the second iteration populate the map. - bool second = false; - for (int32_t iteration=0; iteration<2; ++iteration) { - - // Set the seed so both iterations produce the same results. - std::mt19937 rand32_gen(seed); - - // Set the flag on second iteration. - second = (iteration == 1); - - // Counters: - uint64_t empty_docs = 0; - uint64_t one_sent_docs = 0; - uint64_t long_sent_docs = 0; - - // Current map index. - uint64_t map_index = 0; - - // For each epoch: - for (int32_t epoch=0; epoch= max_num_samples) { - if (verbose && (!second)) { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl << std::flush; - } - break; + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration = 0; iteration < 2; ++iteration) { + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch = 0; epoch < num_epochs; ++epoch) { + if (map_index >= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl + << std::flush; + } + break; + } + // For each document: + for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN) { + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; } - // For each document: - for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { - - // Document sentences are in [sent_index_first, sent_index_last) - const auto sent_index_first = docs[doc]; - const auto sent_index_last = docs[doc + 1]; - - // At the begining of the document previous index is the - // start index. - auto prev_start_index = sent_index_first; - - // Remaining documents. - auto num_remain_sent = sent_index_last - sent_index_first; - - // Some bookkeeping - if ((epoch == 0) && (!second)) { - if (num_remain_sent == 0) { - ++empty_docs; - } - if (num_remain_sent == 1) { - ++one_sent_docs; - } - } - - // Detect documents with long sentences. - bool contains_long_sentence = false; - if (num_remain_sent > 1) { - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - if (sizes[sent_index] > LONG_SENTENCE_LEN){ - if ((epoch == 0) && (!second)) { - ++long_sent_docs; - } - contains_long_sentence = true; - break; - } - } - } - - // If we have more than two sentences. - if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { - - // Set values. - auto seq_len = int32_t{0}; - auto num_sent = int32_t{0}; - auto target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); - - // Loop through sentences. - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - - // Add the size and number of sentences. - seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; - - // If we have reached the target length. - // and if not only one sentence is left in the document. - // and if we have at least two sentneces. - // and if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent > 1) && - (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { - - // Check for overflow. - if ((3 * map_index + 2) > - std::numeric_limits::max()) { - cout << "number of samples exceeded maximum " - << "allowed by type int64: " - << std::numeric_limits::max() - << endl; - throw std::overflow_error("Number of samples"); - } - - // Populate the map. - if (second) { - const auto map_index_0 = 3 * map_index; - maps[map_index_0] = static_cast(prev_start_index); - maps[map_index_0 + 1] = static_cast(sent_index + 1); - maps[map_index_0 + 2] = static_cast(target_seq_len); - } - - // Update indices / counters. - ++map_index; - prev_start_index = sent_index + 1; - target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); - seq_len = 0; - num_sent = 0; - } - - } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { - - if (!second) { - if (verbose) { - cout << " number of empty documents: " << empty_docs << - endl << std::flush; - cout << " number of documents with one sentence: " << - one_sent_docs << endl << std::flush; - cout << " number of documents with long sentences: " << - long_sent_docs << endl << std::flush; - cout << " will create mapping for " << map_index << - " samples" << endl << std::flush; - } - assert(maps == NULL); - assert(num_samples < 0); - maps = new DocIdx[3*map_index]; - num_samples = static_cast(map_index); + } } - } // for (int iteration=0; iteration < 2; ++iteration) { - - // Shuffle. - // We need a 64 bit random number generator as we might have more - // than 2 billion samples. - std::mt19937_64 rand64_gen(seed + 1); - for (auto i=(num_samples - 1); i > 0; --i) { - const auto j = static_cast(rand64_gen() % (i + 1)); - const auto i0 = 3 * i; - const auto j0 = 3 * j; - // Swap values. - swap(maps[i0], maps[j0]); - swap(maps[i0 + 1], maps[j0 + 1]); - swap(maps[i0 + 2], maps[j0 + 2]); - } + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len( + short_seq_ratio, max_seq_length, rand32_gen); + + // Loop through sentences. + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && (num_remain_sent > 1) && + (num_sent >= min_num_sent)) || + (num_remain_sent == 0)) { + // Check for overflow. + if ((3 * map_index + 2) > std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len( + short_seq_ratio, max_seq_length, rand32_gen); + seq_len = 0; + num_sent = 0; + } - // Method to deallocate memory. - py::capsule free_when_done(maps, [](void *mem_) { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; - }); + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << endl + << std::flush; + cout << " number of documents with one sentence: " << one_sent_docs + << endl + << std::flush; + cout << " number of documents with long sentences: " << long_sent_docs + << endl + << std::flush; + cout << " will create mapping for " << map_index << " samples" << endl + << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3 * map_index]; + num_samples = static_cast(map_index); + } - // Return the numpy array. - const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 3}, // shape - {3*byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i = (num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void* mem_) { + DocIdx* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3 * byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references } - py::array build_mapping(const py::array_t& docs_, - const py::array_t& sizes_, - const int num_epochs, + const py::array_t& sizes_, const int num_epochs, const uint64_t max_num_samples, - const int max_seq_length, - const double short_seq_prob, - const int seed, - const bool verbose, - const int32_t min_num_sent) { - - if (sizes_.size() > std::numeric_limits::max()) { - if (verbose) { - cout << " using uint64 for data mapping..." << endl << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose, - min_num_sent); - } else { - if (verbose) { - cout << " using uint32 for data mapping..." << endl << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose, - min_num_sent); + const int max_seq_length, const double short_seq_prob, + const int seed, const bool verbose, + const int32_t min_num_sent) { + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl( + docs_, sizes_, num_epochs, max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; } + return build_mapping_impl( + docs_, sizes_, num_epochs, max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, min_num_sent); + } } -template -py::array build_blocks_mapping_impl(const py::array_t& docs_, - const py::array_t& sizes_, - const py::array_t& titles_sizes_, - const int32_t num_epochs, - const uint64_t max_num_samples, - const int32_t max_seq_length, - const int32_t seed, - const bool verbose, - const bool use_one_sent_blocks) { - /* Build a mapping of (start-index, end-index, sequence-length) where - start and end index are the indices of the sentences in the sample - and sequence-length is the target sequence length. - */ - - // Consistency checks. - assert(num_epochs > 0); - assert(max_seq_length > 1); - assert(seed > 0); - - // Remove bound checks. - auto docs = docs_.unchecked<1>(); - auto sizes = sizes_.unchecked<1>(); - auto titles_sizes = titles_sizes_.unchecked<1>(); +template +py::array build_blocks_mapping_impl( + const py::array_t& docs_, const py::array_t& sizes_, + const py::array_t& titles_sizes_, const int32_t num_epochs, + const uint64_t max_num_samples, const int32_t max_seq_length, + const int32_t seed, const bool verbose, const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); - if (verbose) { - const auto sent_start_index = docs[0]; - const auto sent_end_index = docs[docs_.shape(0) - 1]; - const auto num_sentences = sent_end_index - sent_start_index; - cout << " using:" << endl << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << - endl << std::flush; - cout << " sentences range: [" << sent_start_index << - ", " << sent_end_index << ")" << endl << std::flush; - cout << " total number of sentences: " << num_sentences << - endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " maximum number of samples: " << max_num_samples << - endl << std::flush; - cout << " maximum sequence length: " << max_seq_length << - endl << std::flush; - cout << " seed: " << seed << endl << - std::flush; - } + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 + << endl + << std::flush; + cout << " sentences range: [" << sent_start_index << ", " + << sent_end_index << ")" << endl + << std::flush; + cout << " total number of sentences: " << num_sentences << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " maximum number of samples: " << max_num_samples << endl + << std::flush; + cout << " maximum sequence length: " << max_seq_length << endl + << std::flush; + cout << " seed: " << seed << endl + << std::flush; + } - // Mapping and its length (1D). - int64_t num_samples = -1; - DocIdx* maps = NULL; + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; - // Acceptable number of sentences per block. - int min_num_sent = 2; - if (use_one_sent_blocks) { - min_num_sent = 1; - } + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } - // Perform two iterations, in the first iteration get the size - // and allocate memory and in the second iteration populate the map. - bool second = false; - for (int32_t iteration=0; iteration<2; ++iteration) { - - // Set the flag on second iteration. - second = (iteration == 1); - - // Current map index. - uint64_t map_index = 0; - - uint64_t empty_docs = 0; - uint64_t one_sent_docs = 0; - uint64_t long_sent_docs = 0; - // For each epoch: - for (int32_t epoch=0; epoch= max_num_samples) { - if (verbose && (!second)) { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl << std::flush; - } - break; - } - // For each document: - for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { - - // Document sentences are in [sent_index_first, sent_index_last) - const auto sent_index_first = docs[doc]; - const auto sent_index_last = docs[doc + 1]; - const auto target_seq_len = max_seq_length - titles_sizes[doc]; - - // At the begining of the document previous index is the - // start index. - auto prev_start_index = sent_index_first; - - // Remaining documents. - auto num_remain_sent = sent_index_last - sent_index_first; - - // Some bookkeeping - if ((epoch == 0) && (!second)) { - if (num_remain_sent == 0) { - ++empty_docs; - } - if (num_remain_sent == 1) { - ++one_sent_docs; - } - } - // Detect documents with long sentences. - bool contains_long_sentence = false; - if (num_remain_sent >= min_num_sent) { - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - if (sizes[sent_index] > LONG_SENTENCE_LEN){ - if ((epoch == 0) && (!second)) { - ++long_sent_docs; - } - contains_long_sentence = true; - break; - } - } - } - // If we have enough sentences and no long sentences. - if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { - - // Set values. - auto seq_len = int32_t{0}; - auto num_sent = int32_t{0}; - - // Loop through sentences. - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - - // Add the size and number of sentences. - seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; - - // If we have reached the target length. - // and there are an acceptable number of sentences left - // and if we have at least the minimum number of sentences. - // or if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent >= min_num_sent) && - (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { - - // Populate the map. - if (second) { - const auto map_index_0 = 4 * map_index; - // Each sample has 4 items: the starting sentence index, ending sentence index, - // the index of the document from which the block comes (used for fetching titles) - // and the unique id of the block (used for creating block indexes) - - maps[map_index_0] = static_cast(prev_start_index); - maps[map_index_0 + 1] = static_cast(sent_index + 1); - maps[map_index_0 + 2] = static_cast(doc); - maps[map_index_0 + 3] = static_cast(block_id); - } - - // Update indices / counters. - ++map_index; - ++block_id; - prev_start_index = sent_index + 1; - seq_len = 0; - num_sent = 0; - } - } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { - - if (!second) { - if (verbose) { - cout << " number of empty documents: " << empty_docs << - endl << std::flush; - cout << " number of documents with one sentence: " << - one_sent_docs << endl << std::flush; - cout << " number of documents with long sentences: " << - long_sent_docs << endl << std::flush; - cout << " will create mapping for " << map_index << - " samples" << endl << std::flush; + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration = 0; iteration < 2; ++iteration) { + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch = 0; epoch < num_epochs; ++epoch) { + // assign every block a unique id + int32_t block_id = 0; + + if (map_index >= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl + << std::flush; + } + break; + } + // For each document: + for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN) { + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; } - assert(maps == NULL); - assert(num_samples < 0); - maps = new DocIdx[4*map_index]; - num_samples = static_cast(map_index); + } } - - } // for (int iteration=0; iteration < 2; ++iteration) { - - // Shuffle. - // We need a 64 bit random number generator as we might have more - // than 2 billion samples. - std::mt19937_64 rand64_gen(seed + 1); - for (auto i=(num_samples - 1); i > 0; --i) { - const auto j = static_cast(rand64_gen() % (i + 1)); - const auto i0 = 4 * i; - const auto j0 = 4 * j; - // Swap values. - swap(maps[i0], maps[j0]); - swap(maps[i0 + 1], maps[j0 + 1]); - swap(maps[i0 + 2], maps[j0 + 2]); - swap(maps[i0 + 3], maps[j0 + 3]); + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent)) || + (num_remain_sent == 0)) { + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending + // sentence index, the index of the document from which the + // block comes (used for fetching titles) and the unique id of + // the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << endl + << std::flush; + cout << " number of documents with one sentence: " << one_sent_docs + << endl + << std::flush; + cout << " number of documents with long sentences: " << long_sent_docs + << endl + << std::flush; + cout << " will create mapping for " << map_index << " samples" << endl + << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4 * map_index]; + num_samples = static_cast(map_index); } - // Method to deallocate memory. - py::capsule free_when_done(maps, [](void *mem_) { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; - }); - - // Return the numpy array. - const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 4}, // shape - {4*byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i = (num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void* mem_) { + DocIdx* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4 * byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references } -py::array build_blocks_mapping(const py::array_t& docs_, - const py::array_t& sizes_, - const py::array_t& titles_sizes_, - const int num_epochs, - const uint64_t max_num_samples, - const int max_seq_length, - const int seed, - const bool verbose, - const bool use_one_sent_blocks) { - - if (sizes_.size() > std::numeric_limits::max()) { - if (verbose) { - cout << " using uint64 for data mapping..." << endl << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); - } else { - if (verbose) { - cout << " using uint32 for data mapping..." << endl << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); +py::array build_blocks_mapping( + const py::array_t& docs_, const py::array_t& sizes_, + const py::array_t& titles_sizes_, const int num_epochs, + const uint64_t max_num_samples, const int max_seq_length, const int seed, + const bool verbose, const bool use_one_sent_blocks) { + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl( + docs_, sizes_, titles_sizes_, num_epochs, max_num_samples, + max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; } + return build_blocks_mapping_impl( + docs_, sizes_, titles_sizes_, num_epochs, max_num_samples, + max_seq_length, seed, verbose, use_one_sent_blocks); + } } PYBIND11_MODULE(helpers, m) { - m.def("build_mapping", &build_mapping); - m.def("build_blocks_mapping", &build_blocks_mapping); - m.def("build_sample_idx", &build_sample_idx); - m.def("build_blending_indices", &build_blending_indices); + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); } diff --git a/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py index 6dac35ff9d413898146df5c1cc8553719e142105..220099f9ba32d06e90022a6eed569b0f1ab9a705 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py @@ -2,12 +2,11 @@ import itertools import random import numpy as np -from torch.utils.data import Dataset - -from megatron import get_tokenizer -from megatron import get_args +from megatron import get_args, get_tokenizer from megatron.data.dataset_utils import get_indexed_dataset_ from megatron.data.realm_dataset_utils import get_block_samples_mapping +from torch.utils.data import Dataset + def make_attention_mask(source_block, target_block): """ @@ -20,16 +19,17 @@ def make_attention_mask(source_block, target_block): # (source_length, target_length) return mask + def get_ict_dataset(use_titles=True, query_in_block_prob=1): """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) rather than for training, since it is only built with a single epoch sample mapping. """ args = get_args() - block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) - titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) + block_dataset = get_indexed_dataset_(args.data_path, "mmap", True) + titles_dataset = get_indexed_dataset_(args.titles_data_path, "mmap", True) kwargs = dict( - name='full', + name="full", block_dataset=block_dataset, title_dataset=titles_dataset, data_prefix=args.data_path, @@ -39,7 +39,7 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1): seed=1, query_in_block_prob=query_in_block_prob, use_titles=use_titles, - use_one_sent_docs=args.use_one_sent_docs + use_one_sent_docs=args.use_one_sent_docs, ) dataset = ICTDataset(**kwargs) return dataset @@ -47,9 +47,22 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1): class ICTDataset(Dataset): """Dataset containing sentences and their blocks for an inverse cloze task.""" - def __init__(self, name, block_dataset, title_dataset, data_prefix, - num_epochs, max_num_samples, max_seq_length, query_in_block_prob, - seed, use_titles=True, use_one_sent_docs=False, binary_head=False): + + def __init__( + self, + name, + block_dataset, + title_dataset, + data_prefix, + num_epochs, + max_num_samples, + max_seq_length, + query_in_block_prob, + seed, + use_titles=True, + use_one_sent_docs=False, + binary_head=False, + ): self.name = name self.seed = seed self.max_seq_length = max_seq_length @@ -61,8 +74,16 @@ class ICTDataset(Dataset): self.use_one_sent_docs = use_one_sent_docs self.samples_mapping = get_block_samples_mapping( - block_dataset, title_dataset, data_prefix, num_epochs, - max_num_samples, max_seq_length, seed, name, use_one_sent_docs) + block_dataset, + title_dataset, + data_prefix, + num_epochs, + max_num_samples, + max_seq_length, + seed, + name, + use_one_sent_docs, + ) self.tokenizer = get_tokenizer() self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) self.vocab_id_to_token_list = self.tokenizer.inv_vocab @@ -99,8 +120,8 @@ class ICTDataset(Dataset): # still need to truncate because blocks are concluded when # the sentence lengths have exceeded max_seq_length. - query = query[:self.max_seq_length - 2] - block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] + query = query[: self.max_seq_length - 2] + block = list(itertools.chain(*block))[: self.max_seq_length - title_pad_offset] query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title) @@ -111,13 +132,13 @@ class ICTDataset(Dataset): block_data = sample_data.as_array() sample = { - 'query_tokens': query_tokens, - 'query_mask': query_mask, - 'query_pad_mask': query_pad_mask, - 'context_tokens': context_tokens, - 'context_mask': context_mask, - 'context_pad_mask': context_pad_mask, - 'block_data': block_data, + "query_tokens": query_tokens, + "query_mask": query_mask, + "query_pad_mask": query_pad_mask, + "context_tokens": context_tokens, + "context_mask": context_mask, + "context_pad_mask": context_pad_mask, + "block_data": block_data, } return sample @@ -127,7 +148,7 @@ class ICTDataset(Dataset): block = [self.block_dataset[i] for i in range(start_idx, end_idx)] title = self.title_dataset[int(doc_idx)] - block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))] + block = list(itertools.chain(*block))[: self.max_seq_length - (3 + len(title))] block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) return block_tokens, block_pad_mask diff --git a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py index b4febcd822e1d8640a69e2f31287caad0bba39f7..961a1650bd744a61c12dbac8a60018737364c33e 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py @@ -3,17 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - # copied from fairseq/fairseq/data/indexed_dataset.py # Removed IndexedRawTextDataset since it relied on Fairseq dictionary # other slight modifications to remove fairseq dependencies # Added document index to index file and made it accessible. # An empty sentence no longer separates documents. -from functools import lru_cache import os import shutil import struct +from functools import lru_cache from itertools import accumulate import numpy as np @@ -28,17 +27,17 @@ def __best_fitting_dtype(vocab_size=None): def get_available_dataset_impl(): - return ['lazy', 'cached', 'mmap'] + return ["lazy", "cached", "mmap"] def infer_dataset_impl(path): if IndexedDataset.exists(path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) if magic == IndexedDataset._HDR_MAGIC: - return 'cached' + return "cached" elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: - return 'mmap' + return "mmap" else: return None else: @@ -48,7 +47,7 @@ def infer_dataset_impl(path): def make_builder(out_file, impl, vocab_size=None): - if impl == 'mmap': + if impl == "mmap": return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) else: return IndexedDatasetBuilder(out_file) @@ -59,20 +58,20 @@ def make_dataset(path, impl, skip_warmup=False): print(f"Dataset does not exist: {path}") print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") return None - if impl == 'infer': + if impl == "infer": impl = infer_dataset_impl(path) - if impl == 'lazy' and IndexedDataset.exists(path): + if impl == "lazy" and IndexedDataset.exists(path): return IndexedDataset(path) - elif impl == 'cached' and IndexedDataset.exists(path): + elif impl == "cached" and IndexedDataset.exists(path): return IndexedCachedDataset(path) - elif impl == 'mmap' and MMapIndexedDataset.exists(path): + elif impl == "mmap" and MMapIndexedDataset.exists(path): return MMapIndexedDataset(path, skip_warmup) print(f"Unknown dataset implementation: {impl}") return None def dataset_exists(path, impl): - if impl == 'mmap': + if impl == "mmap": return MMapIndexedDataset.exists(path) else: return IndexedDataset.exists(path) @@ -88,16 +87,7 @@ def write_longs(f, a): f.write(np.array(a, dtype=np.int64)) -dtypes = { - 1: np.uint8, - 2: np.int8, - 3: np.int16, - 4: np.int32, - 5: np.int64, - 6: np.float, - 7: np.double, - 8: np.uint16 -} +dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: float, 7: np.double, 8: np.uint16} def code(dtype): @@ -108,11 +98,11 @@ def code(dtype): def index_file_path(prefix_path): - return prefix_path + '.idx' + return prefix_path + ".idx" def data_file_path(prefix_path): - return prefix_path + '.bin' + return prefix_path + ".bin" def create_doc_idx(sizes): @@ -125,7 +115,8 @@ def create_doc_idx(sizes): class IndexedDataset(torch.utils.data.Dataset): """Loader for IndexedDataset""" - _HDR_MAGIC = b'TNTIDX\x00\x00' + + _HDR_MAGIC = b"TNTIDX\x00\x00" def __init__(self, path): super().__init__() @@ -134,29 +125,28 @@ class IndexedDataset(torch.utils.data.Dataset): self.read_index(path) def read_index(self, path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) assert magic == self._HDR_MAGIC, ( - 'Index file doesn\'t match expected format. ' - 'Make sure that --dataset-impl is configured properly.' + "Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly." ) version = f.read(8) - assert struct.unpack('= self._len: - raise IndexError('index out of range') + raise IndexError("index out of range") def __del__(self): if self.data_file: @@ -169,7 +159,7 @@ class IndexedDataset(torch.utils.data.Dataset): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) @@ -178,7 +168,7 @@ class IndexedDataset(torch.utils.data.Dataset): start, stop, step = idx.indices(len(self)) if step != 1: raise ValueError("Slices into indexed_dataset must be contiguous") - sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] size = sum(sizes) a = np.empty(size, dtype=self.dtype) self.data_file.seek(self.data_offsets[start] * self.element_size) @@ -198,9 +188,7 @@ class IndexedDataset(torch.utils.data.Dataset): @staticmethod def exists(path): - return ( - os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) - ) + return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) @property def supports_prefetch(self): @@ -208,7 +196,6 @@ class IndexedDataset(torch.utils.data.Dataset): class IndexedCachedDataset(IndexedDataset): - def __init__(self, path): super().__init__(path) self.cache = None @@ -233,7 +220,7 @@ class IndexedCachedDataset(IndexedDataset): for i in indices: self.cache_index[i] = ptx size = self.data_offsets[i + 1] - self.data_offsets[i] - a = self.cache[ptx: ptx + size] + a = self.cache[ptx : ptx + size] self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) ptx += size @@ -247,10 +234,10 @@ class IndexedCachedDataset(IndexedDataset): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) ptx = self.cache_index[i] - np.copyto(a, self.cache[ptx: ptx + a.size]) + np.copyto(a, self.cache[ptx : ptx + a.size]) return a elif isinstance(idx, slice): # Hack just to make this work, can optimizer later if necessary @@ -261,18 +248,10 @@ class IndexedCachedDataset(IndexedDataset): class IndexedDatasetBuilder(object): - element_sizes = { - np.uint8: 1, - np.int8: 1, - np.int16: 2, - np.int32: 4, - np.int64: 8, - np.float: 4, - np.double: 8 - } + element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, float: 4, np.double: 8} def __init__(self, out_file, dtype=np.int32): - self.out_file = open(out_file, 'wb') + self.out_file = open(out_file, "wb") self.dtype = dtype self.data_offsets = [0] self.dim_offsets = [0] @@ -302,7 +281,7 @@ class IndexedDatasetBuilder(object): for dim_offset in index.dim_offsets[1:]: self.dim_offsets.append(begin + dim_offset) - with open(data_file_path(another_file), 'rb') as f: + with open(data_file_path(another_file), "rb") as f: while True: data = f.read(1024) if data: @@ -312,12 +291,12 @@ class IndexedDatasetBuilder(object): def finalize(self, index_file): self.out_file.close() - index = open(index_file, 'wb') - index.write(b'TNTIDX\x00\x00') - index.write(struct.pack('= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # return True return False @@ -320,7 +320,7 @@ class BasicTokenizer(object): output = [] for char in text: cp = ord(char) - if cp == 0 or cp == 0xfffd or _is_control(char): + if cp == 0 or cp == 0xFFFD or _is_control(char): continue if _is_whitespace(char): output.append(" ") @@ -422,8 +422,7 @@ def _is_punctuation(char): # Characters such as "^", "$", and "`" are not in the Unicode # Punctuation class but we treat them as punctuation anyways, for # consistency. - if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or - (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): return True cat = unicodedata.category(char) if cat.startswith("P"): diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py index ee3c923e8e76517b8a5244608d6c18d080082952..6c7bfd69f3f934a530badea44d8eafad31940f76 100644 --- a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py +++ b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py @@ -12,13 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Megatron tokenizers.""" -from abc import ABC -from abc import abstractmethod -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode +from abc import ABC, abstractmethod + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from .bert_tokenization import FullTokenizer as FullBertTokenizer @@ -26,21 +25,15 @@ from .bert_tokenization import FullTokenizer as FullBertTokenizer def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): """Initialize tokenizer.""" if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: - print('> building {} tokenizer ...'.format(tokenizer_type), - flush=True) + print("> building {} tokenizer ...".format(tokenizer_type), flush=True) # Select and instantiate the tokenizer. - if tokenizer_type == 'BertWordPieceLowerCase': - tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, - lower_case=True, - vocab_extra_ids=vocab_extra_ids) - elif tokenizer_type == 'BertWordPieceCase': - tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, - lower_case=False, - vocab_extra_ids=vocab_extra_ids) + if tokenizer_type == "BertWordPieceLowerCase": + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=True, vocab_extra_ids=vocab_extra_ids) + elif tokenizer_type == "BertWordPieceCase": + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=False, vocab_extra_ids=vocab_extra_ids) else: - raise NotImplementedError('{} tokenizer is not ' - 'implemented.'.format(tokenizer_type)) + raise NotImplementedError("{} tokenizer is not " "implemented.".format(tokenizer_type)) # Add vocab size. padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size) @@ -61,9 +54,11 @@ def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128): while (after % multiple) != 0: after += 1 if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: - print(' > padded vocab (size: {}) with {} dummy tokens ' - '(new size: {})'.format( - orig_vocab_size, after - orig_vocab_size, after), flush=True) + print( + " > padded vocab (size: {}) with {} dummy tokens " + "(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after), + flush=True, + ) return after @@ -83,46 +78,38 @@ class AbstractTokenizer(ABC): @abstractmethod def vocab(self): """Dictionary from vocab text token to id token.""" - pass @property @abstractmethod def inv_vocab(self): """Dictionary from vocab id token to text token.""" - pass @abstractmethod def tokenize(self, text): pass def detokenize(self, token_ids): - raise NotImplementedError('detokenizer is not implemented for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("detokenizer is not implemented for {} " "tokenizer".format(self.name)) @property def cls(self): - raise NotImplementedError('CLS is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("CLS is not provided for {} " "tokenizer".format(self.name)) @property def sep(self): - raise NotImplementedError('SEP is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("SEP is not provided for {} " "tokenizer".format(self.name)) @property def pad(self): - raise NotImplementedError('PAD is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("PAD is not provided for {} " "tokenizer".format(self.name)) @property def eod(self): - raise NotImplementedError('EOD is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("EOD is not provided for {} " "tokenizer".format(self.name)) @property def mask(self): - raise NotImplementedError('MASK is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("MASK is not provided for {} " "tokenizer".format(self.name)) class _BertWordPieceTokenizer(AbstractTokenizer): @@ -130,33 +117,31 @@ class _BertWordPieceTokenizer(AbstractTokenizer): def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): if lower_case: - name = 'BERT Lower Case' + name = "BERT Lower Case" else: - name = 'BERT Upper Case' + name = "BERT Upper Case" super().__init__(name) self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) - self.cls_id = self.tokenizer.vocab['[CLS]'] - self.sep_id = self.tokenizer.vocab['[SEP]'] - self.pad_id = self.tokenizer.vocab['[PAD]'] - self.mask_id = self.tokenizer.vocab['[MASK]'] + self.cls_id = self.tokenizer.vocab["[CLS]"] + self.sep_id = self.tokenizer.vocab["[SEP]"] + self.pad_id = self.tokenizer.vocab["[PAD]"] + self.mask_id = self.tokenizer.vocab["[MASK]"] self._additional_special_tokens = [] # (dsachan) Add BOS and EOS tokens - SPECIAL_TOKENS = {'eos_token': '[EOS]', - 'bos_token': '[BOS]'} - self._bos_token = '[BOS]' + SPECIAL_TOKENS = {"eos_token": "[EOS]", "bos_token": "[BOS]"} + self._bos_token = "[BOS]" self.add_token(self._bos_token) self._bos_token_id = self.vocab.get(self._bos_token) - self._eos_token = '[EOS]' + self._eos_token = "[EOS]" self.add_token(self._eos_token) self._eos_token_id = self.vocab.get(self._eos_token) # (dsachan) Add additional special tokens # These can be used as sentinel tokens in T5 model inputs additional_special_tokens = [] - additional_special_tokens.extend( - ["".format(i) for i in range(vocab_extra_ids)]) + additional_special_tokens.extend(["".format(i) for i in range(vocab_extra_ids)]) self.add_additional_special_tokens(additional_special_tokens) def add_token(self, token): @@ -193,7 +178,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer): def decode_token_ids(self, token_ids): tokens = self.tokenizer.convert_ids_to_tokens(token_ids) - exclude_list = ['[PAD]', '[CLS]'] + exclude_list = ["[PAD]", "[CLS]"] non_pads = [t for t in tokens if t not in exclude_list] result = "" @@ -223,32 +208,32 @@ class _BertWordPieceTokenizer(AbstractTokenizer): @property def bos_token(self): - """ Beginning of sentence token id """ + """Beginning of sentence token id""" return self._bos_token @property def eos_token(self): - """ End of sentence token id """ + """End of sentence token id""" return self._eos_token @property def additional_special_tokens(self): - """ All the additional special tokens you may want to use (list of strings).""" + """All the additional special tokens you may want to use (list of strings).""" return self._additional_special_tokens @property def bos_token_id(self): - """ Id of the beginning of sentence token in the vocabulary.""" + """Id of the beginning of sentence token in the vocabulary.""" return self._bos_token_id @property def eos_token_id(self): - """ Id of the end of sentence token in the vocabulary.""" + """Id of the end of sentence token in the vocabulary.""" return self._eos_token_id @property def additional_special_tokens_ids(self): - """ Ids of all the additional special tokens in the vocabulary (list of integers).""" + """Ids of all the additional special tokens in the vocabulary (list of integers).""" return [self.vocab.get(token) for token in self._additional_special_tokens] @additional_special_tokens.setter diff --git a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py index e87a778cf5d5959f68c134957f101787ca87cffe..869ff720f4b0b546095c5d66cdcddfaebbdca941 100644 --- a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py +++ b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py @@ -1,37 +1,24 @@ import torch import torch.nn as nn -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from colossalai.logging import get_dist_logger import torch.nn.functional as F -import torch.distributed as dist -from .cross_entropy import vocab_cross_entropy +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc -class BertLoss(nn.Module): - def forward(self, - lm_loss, - sop_logits, - loss_mask, - sentence_order): +class BertLoss(nn.Module): + def forward(self, lm_loss, sop_logits, loss_mask, sentence_order): lm_loss_ = lm_loss.float() loss_mask = loss_mask.float() loss_mask_sum = loss_mask.sum() - lm_loss = torch.sum( - lm_loss_.view(-1) * loss_mask.reshape(-1)) + lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) lm_loss /= loss_mask_sum - torch.distributed.all_reduce( - lm_loss, - group=gpc.get_group(ParallelMode.SEQUENCE) - ) + torch.distributed.all_reduce(lm_loss, group=gpc.get_group(ParallelMode.SEQUENCE)) if sop_logits is not None: - sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), - sentence_order.view(-1), - ignore_index=-1) + sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1) sop_loss = sop_loss.float() loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE) else: diff --git a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py index 54553c29a61f91b34d9b192ab734ff201a92892e..b5d9ea919261af97776e0be8776203098188f612 100644 --- a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py +++ b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py @@ -1,10 +1,8 @@ -from colossalai.context.parallel_mode import ParallelMode import torch from torch.cuda.amp import custom_bwd, custom_fwd class _VocabCrossEntropy(torch.autograd.Function): - @staticmethod @custom_fwd def forward(ctx, vocab_parallel_logits, target): @@ -24,8 +22,7 @@ class _VocabCrossEntropy(torch.autograd.Function): # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)) masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], - device=logits_2d.device) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] predicted_logits_1d = predicted_logits_1d.clone().contiguous() predicted_logits = predicted_logits_1d.view_as(target) @@ -58,10 +55,8 @@ class _VocabCrossEntropy(torch.autograd.Function): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], - device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= ( - 1.0 - target_mask.view(-1).float()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(grad_output.unsqueeze(dim=-1)) diff --git a/examples/tutorial/sequence_parallel/loss_func/utils.py b/examples/tutorial/sequence_parallel/loss_func/utils.py index a3d92f294326312a6f3cb9be04ed38124b104485..35fa73896c469a1d35ea8e0ee726d8cc561629e5 100644 --- a/examples/tutorial/sequence_parallel/loss_func/utils.py +++ b/examples/tutorial/sequence_parallel/loss_func/utils.py @@ -1,11 +1,9 @@ - import torch def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, '{} is not divisible by {}'.format( - numerator, denominator) + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) def divide(numerator, denominator): @@ -15,8 +13,7 @@ def divide(numerator, denominator): return numerator // denominator -def split_tensor_along_last_dim(tensor, num_partitions, - contiguous_split_chunks=False): +def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): """Split a tensor along its last dimension. Arguments: tensor: input tensor. @@ -38,12 +35,11 @@ def split_tensor_along_last_dim(tensor, num_partitions, class VocabUtility: """Split the vocabulary into `world_size` chunks amd return the - first and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last)""" + first and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last)""" @staticmethod - def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, - rank, world_size): + def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size): index_f = rank * per_partition_vocab_size index_l = index_f + per_partition_vocab_size return index_f, index_l @@ -51,5 +47,4 @@ class VocabUtility: @staticmethod def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): per_partition_vocab_size = divide(global_vocab_size, world_size) - return VocabUtility.vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size) diff --git a/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py index 8d95679ff76dcdab3020c02b30e9e8616d086c73..866d0d54583b7ba08118e99b2f59872796236df5 100644 --- a/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py +++ b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py @@ -21,16 +21,17 @@ import math class AnnealingLR(object): """Anneals the learning rate.""" - def __init__(self, - optimizer, - max_lr, - min_lr, - warmup_steps, - decay_steps, - decay_style, - use_checkpoint_lr_scheduler=True, - override_lr_scheduler=False): - + def __init__( + self, + optimizer, + max_lr, + min_lr, + warmup_steps, + decay_steps, + decay_style, + use_checkpoint_lr_scheduler=True, + override_lr_scheduler=False, + ): # Class values. self.optimizer = optimizer @@ -50,23 +51,21 @@ class AnnealingLR(object): self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler if self.override_lr_scheduler: - assert not self.use_checkpoint_lr_scheduler, 'both override and '\ - 'use-checkpoint are set.' + assert not self.use_checkpoint_lr_scheduler, "both override and " "use-checkpoint are set." # Set the learning rate self.step(0) def get_lr(self): """Learning rate decay functions from: - https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" + https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" # Use linear warmup for the initial part. if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: - return self.max_lr * float(self.num_steps) / \ - float(self.warmup_steps) + return self.max_lr * float(self.num_steps) / float(self.warmup_steps) # If the learning rate is constant, just return the initial value. - if self.decay_style == 'constant': + if self.decay_style == "constant": return self.max_lr # For any steps larger than `self.decay_steps`, use `self.min_lr`. @@ -81,13 +80,12 @@ class AnnealingLR(object): assert decay_ratio <= 1.0 delta_lr = self.max_lr - self.min_lr - if self.decay_style == 'linear': - coeff = (1.0 - decay_ratio) - elif self.decay_style == 'cosine': + if self.decay_style == "linear": + coeff = 1.0 - decay_ratio + elif self.decay_style == "cosine": coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) else: - raise Exception('{} decay style is not supported.'.format( - self.decay_style)) + raise Exception("{} decay style is not supported.".format(self.decay_style)) return self.min_lr + coeff * delta_lr @@ -96,16 +94,16 @@ class AnnealingLR(object): self.num_steps += increment new_lr = self.get_lr() for group in self.optimizer.param_groups: - group['lr'] = new_lr + group["lr"] = new_lr def state_dict(self): state_dict = { - 'max_lr': self.max_lr, - 'warmup_steps': self.warmup_steps, - 'num_steps': self.num_steps, - 'decay_style': self.decay_style, - 'decay_steps': self.decay_steps, - 'min_lr': self.min_lr + "max_lr": self.max_lr, + "warmup_steps": self.warmup_steps, + "num_steps": self.num_steps, + "decay_style": self.decay_style, + "decay_steps": self.decay_steps, + "min_lr": self.min_lr, } return state_dict @@ -116,43 +114,35 @@ class AnnealingLR(object): return cls_value if not self.use_checkpoint_lr_scheduler: - assert cls_value == sd_value, \ - f'AnnealingLR: class input value {cls_value} and checkpoint' \ - f'value {sd_value} for {name} do not match' + assert cls_value == sd_value, ( + f"AnnealingLR: class input value {cls_value} and checkpoint" f"value {sd_value} for {name} do not match" + ) return sd_value def load_state_dict(self, sd): - - if 'start_lr' in sd: - max_lr_ = sd['start_lr'] + if "start_lr" in sd: + max_lr_ = sd["start_lr"] else: - max_lr_ = sd['max_lr'] - self.max_lr = self._check_and_set(self.max_lr, max_lr_, - 'learning rate') + max_lr_ = sd["max_lr"] + self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate") - self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], - 'minimum learning rate') + self.min_lr = self._check_and_set(self.min_lr, sd["min_lr"], "minimum learning rate") - if 'warmup_iter' in sd: - warmup_steps_ = sd['warmup_iter'] + if "warmup_iter" in sd: + warmup_steps_ = sd["warmup_iter"] else: - warmup_steps_ = sd['warmup_steps'] - self.warmup_steps = self._check_and_set(self.warmup_steps, - warmup_steps_, - 'warmup iterations') + warmup_steps_ = sd["warmup_steps"] + self.warmup_steps = self._check_and_set(self.warmup_steps, warmup_steps_, "warmup iterations") - if 'end_iter' in sd: - decay_steps_ = sd['end_iter'] + if "end_iter" in sd: + decay_steps_ = sd["end_iter"] else: - decay_steps_ = sd['decay_steps'] - self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, - 'total number of iterations') - self.decay_style = self._check_and_set(self.decay_style, - sd['decay_style'], - 'decay style') - - if 'num_iters' in sd: - num_steps = sd['num_iters'] + decay_steps_ = sd["decay_steps"] + self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, "total number of iterations") + self.decay_style = self._check_and_set(self.decay_style, sd["decay_style"], "decay style") + + if "num_iters" in sd: + num_steps = sd["num_iters"] else: - num_steps = sd['num_steps'] + num_steps = sd["num_steps"] self.step(increment=num_steps) diff --git a/examples/tutorial/sequence_parallel/model/__init__.py b/examples/tutorial/sequence_parallel/model/__init__.py index 139597f9cb07c5d48bed18984ec4747f4b4f3438..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/examples/tutorial/sequence_parallel/model/__init__.py +++ b/examples/tutorial/sequence_parallel/model/__init__.py @@ -1,2 +0,0 @@ - - diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py index 049579c5a639c2bab29348fe0290b75497dbdc31..7b0e93d958ca30ba9f103f0c21730cfdd034a7b9 100644 --- a/examples/tutorial/sequence_parallel/model/bert.py +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -1,36 +1,41 @@ -from colossalai.context.parallel_mode import ParallelMode +import inspect + import torch import torch.nn as nn -import inspect -from .layers import Embedding, BertLayer, BertDualHead, PreProcessor, VocabEmbedding -from .layers.init_method import init_normal, output_init_normal -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode + from colossalai.kernel import LayerNorm -from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.legacy.pipeline.utils import partition_uniform from colossalai.logging import get_dist_logger -from colossalai.pipeline.utils import partition_uniform +from .layers import BertDualHead, BertLayer, Embedding, PreProcessor, VocabEmbedding +from .layers.init_method import init_normal, output_init_normal -class BertForPretrain(nn.Module): - def __init__(self, - vocab_size, - hidden_size, - max_sequence_length, - num_attention_heads, - num_layers, - add_binary_head, - is_naive_fp16, - num_tokentypes=2, - dropout_prob=0.1, - mlp_ratio=4, - init_std=0.02, - convert_fp16_to_fp32_in_softmax=False, - ): +class BertForPretrain(nn.Module): + def __init__( + self, + vocab_size, + hidden_size, + max_sequence_length, + num_attention_heads, + num_layers, + add_binary_head, + is_naive_fp16, + num_tokentypes=2, + dropout_prob=0.1, + mlp_ratio=4, + init_std=0.02, + convert_fp16_to_fp32_in_softmax=False, + ): super().__init__() self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) - assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + assert ( + max_sequence_length % self.seq_parallel_size == 0 + ), "sequence length is not divisible by the sequence parallel size" self.sub_seq_length = max_sequence_length // self.seq_parallel_size self.init_std = init_std self.num_layers = num_layers @@ -39,28 +44,32 @@ class BertForPretrain(nn.Module): num_tokentypes = 0 self.preprocessor = PreProcessor(self.sub_seq_length) - self.embedding = Embedding(hidden_size=hidden_size, - vocab_size=vocab_size, - max_sequence_length=max_sequence_length, - embedding_dropout_prob=dropout_prob, - num_tokentypes=num_tokentypes) + self.embedding = Embedding( + hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes, + ) self.bert_layers = nn.ModuleList() for i in range(num_layers): - bert_layer = BertLayer(layer_number=i+1, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - attention_dropout=dropout_prob, - mlp_ratio=mlp_ratio, - hidden_dropout=dropout_prob, - convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - is_naive_fp16=is_naive_fp16 - ) + bert_layer = BertLayer( + layer_number=i + 1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16, + ) self.bert_layers.append(bert_layer) self.layer_norm = LayerNorm(hidden_size) - self.head = BertDualHead(hidden_size, self.embedding.word_embedding_weight.size(0), - add_binary_head=add_binary_head) + self.head = BertDualHead( + hidden_size, self.embedding.word_embedding_weight.size(0), add_binary_head=add_binary_head + ) self.reset_parameters() def _init_normal(self, tensor): @@ -118,27 +127,30 @@ class BertForPretrain(nn.Module): class PipelineBertForPretrain(nn.Module): - - def __init__(self, - vocab_size, - hidden_size, - max_sequence_length, - num_attention_heads, - num_layers, - add_binary_head, - is_naive_fp16, - num_tokentypes=2, - dropout_prob=0.1, - mlp_ratio=4, - init_std=0.02, - convert_fp16_to_fp32_in_softmax=False, - first_stage=True, - last_stage=True, - start_idx=None, - end_idx=None): + def __init__( + self, + vocab_size, + hidden_size, + max_sequence_length, + num_attention_heads, + num_layers, + add_binary_head, + is_naive_fp16, + num_tokentypes=2, + dropout_prob=0.1, + mlp_ratio=4, + init_std=0.02, + convert_fp16_to_fp32_in_softmax=False, + first_stage=True, + last_stage=True, + start_idx=None, + end_idx=None, + ): super().__init__() self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) - assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + assert ( + max_sequence_length % self.seq_parallel_size == 0 + ), "sequence length is not divisible by the sequence parallel size" self.sub_seq_length = max_sequence_length // self.seq_parallel_size self.init_std = init_std self.num_layers = num_layers @@ -152,11 +164,13 @@ class PipelineBertForPretrain(nn.Module): self.preprocessor = PreProcessor(self.sub_seq_length) if self.first_stage: - self.embedding = Embedding(hidden_size=hidden_size, - vocab_size=vocab_size, - max_sequence_length=max_sequence_length, - embedding_dropout_prob=dropout_prob, - num_tokentypes=num_tokentypes) + self.embedding = Embedding( + hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes, + ) # transformer layers self.bert_layers = nn.ModuleList() @@ -166,22 +180,22 @@ class PipelineBertForPretrain(nn.Module): end_idx = num_layers for i in range(start_idx, end_idx): - bert_layer = BertLayer(layer_number=i+1, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - attention_dropout=dropout_prob, - mlp_ratio=mlp_ratio, - hidden_dropout=dropout_prob, - convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - is_naive_fp16=is_naive_fp16 - ) + bert_layer = BertLayer( + layer_number=i + 1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16, + ) self.bert_layers.append(bert_layer) if self.last_stage: self.word_embeddings = VocabEmbedding(vocab_size, hidden_size) self.layer_norm = LayerNorm(hidden_size) - self.head = BertDualHead(hidden_size, vocab_size, - add_binary_head=add_binary_head) + self.head = BertDualHead(hidden_size, vocab_size, add_binary_head=add_binary_head) self.reset_parameters() def _init_normal(self, tensor): @@ -254,7 +268,7 @@ def _filter_kwargs(func, kwargs): return {k: v for k, v in kwargs.items() if k in sig.parameters} -def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): +def build_pipeline_bert(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): logger = get_dist_logger() pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) @@ -263,12 +277,12 @@ def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **k parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] models = [] for start, end in parts: - kwargs['num_layers'] = num_layers - kwargs['start_idx'] = start - kwargs['end_idx'] = end - kwargs['first_stage'] = start == 0 - kwargs['last_stage'] = end == num_layers - logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + kwargs["num_layers"] = num_layers + kwargs["start_idx"] = start + kwargs["end_idx"] = end + kwargs["first_stage"] = start == 0 + kwargs["last_stage"] = end == num_layers + logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers") chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device) if start == 0: wrapper.register_module(chunk.embedding.word_embeddings) diff --git a/examples/tutorial/sequence_parallel/model/layers/__init__.py b/examples/tutorial/sequence_parallel/model/layers/__init__.py index 3a8823caa81b4e1f3cdbf9ea4f79e012a1b99505..58495c516239eefe1736e5d16969ecdcc2766740 100644 --- a/examples/tutorial/sequence_parallel/model/layers/__init__.py +++ b/examples/tutorial/sequence_parallel/model/layers/__init__.py @@ -1,4 +1,4 @@ -from .embedding import VocabEmbedding, Embedding from .bert_layer import BertLayer +from .embedding import Embedding, VocabEmbedding from .head import BertDualHead from .preprocess import PreProcessor diff --git a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py index 4ede21516f65a8bc008a51dea505811fb0ceda48..1ef16ee6ad795cd8b62f8c54917458d8fc18f61e 100644 --- a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py +++ b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py @@ -1,10 +1,12 @@ import torch import torch.nn as nn -from colossalai.nn.layer.parallel_sequence import TransformerSelfAttentionRing -from colossalai.kernel.jit import bias_dropout_add_fused_train, bias_dropout_add_fused_inference + from colossalai.kernel.cuda_native import LayerNorm -from .mlp import TransformerMLP +from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train +from colossalai.legacy.nn.layer.parallel_sequence import TransformerSelfAttentionRing + from .dropout import get_bias_dropout_add +from .mlp import TransformerMLP def attention_mask_func(attention_scores, attention_mask): @@ -18,18 +20,20 @@ class BertLayer(nn.Module): output of the same size. """ - def __init__(self, - layer_number, - hidden_size, - num_attention_heads, - attention_dropout, - mlp_ratio, - hidden_dropout, - is_naive_fp16, - apply_residual_connection_post_layernorm=False, - fp32_residual_connection=False, - bias_dropout_fusion: bool = True, - convert_fp16_to_fp32_in_softmax: bool = False): + def __init__( + self, + layer_number, + hidden_size, + num_attention_heads, + attention_dropout, + mlp_ratio, + hidden_dropout, + is_naive_fp16, + apply_residual_connection_post_layernorm=False, + fp32_residual_connection=False, + bias_dropout_fusion: bool = True, + convert_fp16_to_fp32_in_softmax: bool = False, + ): super().__init__() self.layer_number = layer_number @@ -48,7 +52,7 @@ class BertLayer(nn.Module): layer_number=layer_number, apply_query_key_layer_scaling=True, convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - fp16=is_naive_fp16 + fp16=is_naive_fp16, ) self.hidden_dropout = hidden_dropout @@ -90,10 +94,8 @@ class BertLayer(nn.Module): # re-enable torch grad to enable fused optimization. with torch.enable_grad(): layernorm_input = bias_dropout_add_func( - attention_output, - attention_bias.expand_as(residual), - residual, - self.hidden_dropout) + attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout + ) # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) @@ -109,10 +111,6 @@ class BertLayer(nn.Module): # re-enable torch grad to enable fused optimization. with torch.enable_grad(): - output = bias_dropout_add_func( - mlp_output, - mlp_bias.expand_as(residual), - residual, - self.hidden_dropout) + output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout) return output diff --git a/examples/tutorial/sequence_parallel/model/layers/dropout.py b/examples/tutorial/sequence_parallel/model/layers/dropout.py index 0e99105b8f7e9d1ce2f884eba12678d5b18f4468..18eae0d63cd11d7a7398e8429b2e0ac1f9823bb2 100644 --- a/examples/tutorial/sequence_parallel/model/layers/dropout.py +++ b/examples/tutorial/sequence_parallel/model/layers/dropout.py @@ -1,5 +1,6 @@ import torch + def bias_dropout_add(x, bias, residual, prob, training): # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor out = torch.nn.functional.dropout(x + bias, p=prob, training=training) @@ -10,4 +11,5 @@ def bias_dropout_add(x, bias, residual, prob, training): def get_bias_dropout_add(training): def _bias_dropout_add(x, bias, residual, prob): return bias_dropout_add(x, bias, residual, prob, training) - return _bias_dropout_add \ No newline at end of file + + return _bias_dropout_add diff --git a/examples/tutorial/sequence_parallel/model/layers/embedding.py b/examples/tutorial/sequence_parallel/model/layers/embedding.py index 0700d960d845ff1ab1f0258d6e174b88ad3ac902..03183c55243f38d9aab1415e58fc475f58822d4b 100644 --- a/examples/tutorial/sequence_parallel/model/layers/embedding.py +++ b/examples/tutorial/sequence_parallel/model/layers/embedding.py @@ -5,7 +5,6 @@ import torch.nn.init as init class VocabEmbedding(torch.nn.Module): - def __init__(self, num_embeddings, embedding_dim): super(VocabEmbedding, self).__init__() # Keep the input dimensions. @@ -13,26 +12,29 @@ class VocabEmbedding(torch.nn.Module): self.embedding_dim = embedding_dim self.padding_idx = None self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None # Allocate weights and initialize. - self.weight = nn.Parameter(torch.empty( - self.num_embeddings, self.embedding_dim)) + self.weight = nn.Parameter(torch.empty(self.num_embeddings, self.embedding_dim)) init.xavier_uniform_(self.weight) def forward(self, hidden_state): - output = F.embedding(hidden_state, self.weight, - self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, - self.sparse) + output = F.embedding( + hidden_state, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) return output def __repr__(self): - return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \ - f'embedding_dim={self.embedding_dim})' + return f"VocabEmbedding(num_embeddings={self.num_embeddings}, " f"embedding_dim={self.embedding_dim})" class Embedding(nn.Module): @@ -48,12 +50,7 @@ class Embedding(nn.Module): will ignore this embedding """ - def __init__(self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - num_tokentypes): + def __init__(self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes): super(Embedding, self).__init__() self.hidden_size = hidden_size @@ -62,16 +59,14 @@ class Embedding(nn.Module): self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size) # Position embedding (serial). - self.position_embeddings = torch.nn.Embedding( - max_sequence_length, self.hidden_size) + self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) # Token type embedding. # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. if self.num_tokentypes > 0: - self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, - self.hidden_size) + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) else: self.tokentype_embeddings = None diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py index ea336b9d131e3405dc22d03f3e7b51add4e54466..75afeee60ad40965a204dcd7646ebebb742e1509 100644 --- a/examples/tutorial/sequence_parallel/model/layers/head.py +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -1,15 +1,15 @@ -import colossalai import torch import torch.nn as nn import torch.nn.functional as F -from .pooler import Pooler -from .linear import Linear -from .embedding import VocabEmbedding -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from colossalai.kernel import LayerNorm from loss_func.cross_entropy import vocab_cross_entropy +from colossalai.kernel import LayerNorm +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc + +from .linear import Linear +from .pooler import Pooler + class BertLMHead(nn.Module): """Masked LM head for Bert @@ -19,11 +19,11 @@ class BertLMHead(nn.Module): layernorm_epsilon: tolerance for layer norm divisions """ - def __init__(self, - vocab_size, - hidden_size, - ): - + def __init__( + self, + vocab_size, + hidden_size, + ): super(BertLMHead, self).__init__() self.bias = torch.nn.Parameter(torch.zeros(vocab_size)) @@ -43,7 +43,6 @@ class BertLMHead(nn.Module): class BertBinaryHead(nn.Module): - def __init__(self, hidden_size): super().__init__() self.pooler = Pooler(hidden_size) @@ -59,7 +58,6 @@ class BertBinaryHead(nn.Module): class BertDualHead(nn.Module): - def __init__(self, hidden_size, vocab_size, add_binary_head): super().__init__() self.lm_head = BertLMHead(vocab_size, hidden_size) diff --git a/examples/tutorial/sequence_parallel/model/layers/init_method.py b/examples/tutorial/sequence_parallel/model/layers/init_method.py index 1b409dfe40541524891f70fc7c7d8297afa86999..22d12a504fab77eab739c33687bdd2df50ed3946 100644 --- a/examples/tutorial/sequence_parallel/model/layers/init_method.py +++ b/examples/tutorial/sequence_parallel/model/layers/init_method.py @@ -1,6 +1,8 @@ -import torch import math +import torch + + def init_normal(tensor, sigma): """Init method based on N(0, sigma).""" torch.nn.init.normal_(tensor, mean=0.0, std=sigma) diff --git a/examples/tutorial/sequence_parallel/model/layers/linear.py b/examples/tutorial/sequence_parallel/model/layers/linear.py index 5ae7d671e2bf2312da315d35629dcdc12ca075de..5592f6e8c209448930381f1d80885125f19f7afe 100644 --- a/examples/tutorial/sequence_parallel/model/layers/linear.py +++ b/examples/tutorial/sequence_parallel/model/layers/linear.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn -from torch.nn import Parameter import torch.nn.functional as F import torch.nn.init as init +from torch.nn import Parameter class Linear(nn.Module): @@ -24,11 +24,7 @@ class Linear(nn.Module): adding bias but instead return it. """ - def __init__(self, - input_size, - output_size, - bias=True, - skip_bias_add=False): + def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): super(Linear, self).__init__() # Keep input parameters @@ -36,9 +32,12 @@ class Linear(nn.Module): self.output_size = output_size self.skip_bias_add = skip_bias_add - self.weight = Parameter(torch.empty(self.output_size, - self.input_size, - )) + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size, + ) + ) init.normal_(self.weight) if bias: self.bias = Parameter(torch.empty(self.output_size)) @@ -46,7 +45,7 @@ class Linear(nn.Module): with torch.no_grad(): self.bias.zero_() else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def forward(self, input_): # Matrix multiply. @@ -59,5 +58,7 @@ class Linear(nn.Module): return output def __repr__(self): - return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ - f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' + return ( + f"Linear(in_features={self.input_size}, out_features={self.output_size}, " + + f"bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})" + ) diff --git a/examples/tutorial/sequence_parallel/model/layers/mlp.py b/examples/tutorial/sequence_parallel/model/layers/mlp.py index a255de813d135e5c79a281c86462340337c5c036..54a695fda4023ff6b22fe3c8254a2b5288b31a5e 100644 --- a/examples/tutorial/sequence_parallel/model/layers/mlp.py +++ b/examples/tutorial/sequence_parallel/model/layers/mlp.py @@ -1,10 +1,10 @@ -import torch import torch.nn as nn import torch.nn.functional as F -from .linear import Linear from colossalai.kernel.jit import bias_gelu_impl +from .linear import Linear + class TransformerMLP(nn.Module): """MLP. @@ -18,19 +18,13 @@ class TransformerMLP(nn.Module): super(TransformerMLP, self).__init__() # Project to 4h. - self.dense_h_to_4h = Linear( - hidden_size, - int(hidden_size*mlp_ratio), - skip_bias_add=True) + self.dense_h_to_4h = Linear(hidden_size, int(hidden_size * mlp_ratio), skip_bias_add=True) self.bias_gelu_fusion = fuse_gelu self.activation_func = F.gelu # Project back to h. - self.dense_4h_to_h = Linear( - int(hidden_size*mlp_ratio), - hidden_size, - skip_bias_add=True) + self.dense_4h_to_h = Linear(int(hidden_size * mlp_ratio), hidden_size, skip_bias_add=True) def forward(self, hidden_states): # hidden states should be in the shape of [s, b, h] @@ -39,11 +33,9 @@ class TransformerMLP(nn.Module): intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) if self.bias_gelu_fusion: - intermediate_parallel = \ - bias_gelu_impl(intermediate_parallel, bias_parallel) + intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) else: - intermediate_parallel = \ - self.activation_func(intermediate_parallel + bias_parallel) + intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) # [s, b, h] output, output_bias = self.dense_4h_to_h(intermediate_parallel) diff --git a/examples/tutorial/sequence_parallel/model/layers/pooler.py b/examples/tutorial/sequence_parallel/model/layers/pooler.py index 282ed114790b32618c1a92924bce403167d1b89e..c3397787aecf035b7ed1e6ec74b3346538de77f1 100644 --- a/examples/tutorial/sequence_parallel/model/layers/pooler.py +++ b/examples/tutorial/sequence_parallel/model/layers/pooler.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn + from .linear import Linear diff --git a/examples/tutorial/sequence_parallel/model/layers/preprocess.py b/examples/tutorial/sequence_parallel/model/layers/preprocess.py index 53a326ddacf14f7babaa18390a04b962bd84c809..55dd20e1e948f26c1d874a68a3f32de626aeb788 100644 --- a/examples/tutorial/sequence_parallel/model/layers/preprocess.py +++ b/examples/tutorial/sequence_parallel/model/layers/preprocess.py @@ -1,11 +1,11 @@ -from colossalai.context.parallel_mode import ParallelMode import torch import torch.nn as nn -from colossalai.core import global_context as gpc +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc -class PreProcessor(nn.Module): +class PreProcessor(nn.Module): def __init__(self, sub_seq_length): super().__init__() self.sub_seq_length = sub_seq_length @@ -14,10 +14,9 @@ class PreProcessor(nn.Module): # Create position ids seq_length = token_ids.size(1) local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) - position_ids = torch.arange(seq_length*local_rank, - seq_length * (local_rank+1), - dtype=torch.long, - device=token_ids.device) + position_ids = torch.arange( + seq_length * local_rank, seq_length * (local_rank + 1), dtype=torch.long, device=token_ids.device + ) position_ids = position_ids.unsqueeze(0).expand_as(token_ids) return position_ids @@ -41,7 +40,7 @@ class PreProcessor(nn.Module): extended_attention_mask = attention_mask_bss.unsqueeze(1) # Convert attention mask to binary: - extended_attention_mask = (extended_attention_mask < 0.5) + extended_attention_mask = extended_attention_mask < 0.5 return extended_attention_mask diff --git a/examples/tutorial/sequence_parallel/requirements.txt b/examples/tutorial/sequence_parallel/requirements.txt index b49a94554afb699639f9da7d04808be760c051c0..4fc576453de8806c13c6b7e28f41ec48323cf6a7 100644 --- a/examples/tutorial/sequence_parallel/requirements.txt +++ b/examples/tutorial/sequence_parallel/requirements.txt @@ -1,2 +1,3 @@ colossalai torch +six diff --git a/examples/tutorial/sequence_parallel/test_ci.sh b/examples/tutorial/sequence_parallel/test_ci.sh index 7bc20de3b6e414a8bc74bca5fcb2d17ec18b1106..1cd646526d994f50209eb62fc412e08c9081e6f2 100644 --- a/examples/tutorial/sequence_parallel/test_ci.sh +++ b/examples/tutorial/sequence_parallel/test_ci.sh @@ -1,7 +1,8 @@ #!/bin/bash set -euxo pipefail -pip install -r requirements.txt +echo "this test is outdated" +# pip install -r requirements.txt # run test -colossalai run --nproc_per_node 4 train.py +# colossalai run --nproc_per_node 4 train.py diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py index a89747b5845eb1a4f9e663373a015daf39d169a2..e9ceb8d70cb81d781cebdc497d2a1b21af7de592 100644 --- a/examples/tutorial/sequence_parallel/train.py +++ b/examples/tutorial/sequence_parallel/train.py @@ -8,14 +8,14 @@ from lr_scheduler import AnnealingLR from model.bert import BertForPretrain, build_pipeline_bert import colossalai -from colossalai.amp import AMP_TYPE -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.engine.schedule import PipelineSchedule from colossalai.kernel import LayerNorm +from colossalai.legacy.amp import AMP_TYPE +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import is_using_pp from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import FusedAdam -from colossalai.utils import MultiTimer, is_using_pp +from colossalai.utils import MultiTimer def process_batch_data(batch_data): @@ -30,7 +30,7 @@ def process_batch_data(batch_data): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data") + parser.add_argument("-s", "--synthetic", action="store_true", help="whether use synthetic data") return parser.parse_args() @@ -47,37 +47,39 @@ def pipeline_data_process_func(stage_output, micro_batch_data): def main(): # initialize - args = parse_args() - colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl') + parse_args() + colossalai.launch_from_torch(config="./config.py", seed=1234, backend="nccl") logger = get_dist_logger() # build synthetic dataloader BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA) VOCAB_SIZE = 30528 - trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, - vocab_size=VOCAB_SIZE, - seq_length=gpc.config.SEQ_LENGTH) - validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, - vocab_size=VOCAB_SIZE, - seq_length=gpc.config.SEQ_LENGTH) + trainloader = DummyDataloader( + batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH + ) + validloader = DummyDataloader( + batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH + ) logger.info("Dataloaders are built", ranks=[0]) # build model - if hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE: + if hasattr(gpc.config, "fp16") and gpc.config.fp16.get("mode") == AMP_TYPE.NAIVE: is_naive_fp16 = True else: is_naive_fp16 = False use_pipeline = is_using_pp() - kwargs = dict(vocab_size=VOCAB_SIZE, - hidden_size=gpc.config.HIDDEN_SIZE, - max_sequence_length=gpc.config.SEQ_LENGTH, - num_attention_heads=gpc.config.NUM_ATTENTION_HEADS, - convert_fp16_to_fp32_in_softmax=True, - is_naive_fp16=is_naive_fp16, - add_binary_head=gpc.config.ADD_BINARY_HEAD) + kwargs = dict( + vocab_size=VOCAB_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + max_sequence_length=gpc.config.SEQ_LENGTH, + num_attention_heads=gpc.config.NUM_ATTENTION_HEADS, + convert_fp16_to_fp32_in_softmax=True, + is_naive_fp16=is_naive_fp16, + add_binary_head=gpc.config.ADD_BINARY_HEAD, + ) if use_pipeline: model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs) @@ -98,35 +100,39 @@ def main(): logger.info("Criterion is built", ranks=[0]) # layernorm and bias has no weight decay - weight_decay_params = {'params': []} - no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + weight_decay_params = {"params": []} + no_weight_decay_params = {"params": [], "weight_decay": 0.0} for module_ in model.modules(): if isinstance(module_, LayerNorm): - no_weight_decay_params['params'].extend([p for p in list(module_._parameters.values()) if p is not None]) + no_weight_decay_params["params"].extend([p for p in list(module_._parameters.values()) if p is not None]) else: - weight_decay_params['params'].extend( - [p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias']) - no_weight_decay_params['params'].extend( - [p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias']) + weight_decay_params["params"].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n != "bias"] + ) + no_weight_decay_params["params"].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n == "bias"] + ) logger.info( f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}" ) # optimizer - optimizer = FusedAdam((weight_decay_params, no_weight_decay_params), - lr=gpc.config.LR, - weight_decay=gpc.config.WEIGHT_DECAY) + optimizer = FusedAdam( + (weight_decay_params, no_weight_decay_params), lr=gpc.config.LR, weight_decay=gpc.config.WEIGHT_DECAY + ) logger.info("Optimizer is built", ranks=[0]) # lr scheduler # follow Megatron-LM setting warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION) - lr_scheduler = AnnealingLR(optimizer=optimizer, - max_lr=gpc.config.LR, - min_lr=gpc.config.MIN_LR, - warmup_steps=warmup_steps, - decay_steps=gpc.config.DECAY_ITERS, - decay_style='linear') + lr_scheduler = AnnealingLR( + optimizer=optimizer, + max_lr=gpc.config.LR, + min_lr=gpc.config.MIN_LR, + warmup_steps=warmup_steps, + decay_steps=gpc.config.DECAY_ITERS, + decay_style="linear", + ) logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps") # # init @@ -134,7 +140,6 @@ def main(): # build timer timer = MultiTimer() - skip_iters = 0 # build loss tracker accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda() @@ -149,7 +154,7 @@ def main(): logger.info("start training") for step in range(1, gpc.config.TRAIN_ITERS + 1): - timer.start('train-iterations') + timer.start("train-iterations") engine.train() if use_pipeline: engine.zero_grad() @@ -157,13 +162,14 @@ def main(): engine.step() else: tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( - trainloader) + trainloader + ) engine.zero_grad() lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) engine.backward(train_loss) engine.step() - timer.stop('train-iterations', keep_in_history=True) + timer.stop("train-iterations", keep_in_history=True) if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): accumulated_train_loss += train_loss @@ -176,12 +182,18 @@ def main(): for j in range(gpc.config.EVAL_ITERS): with torch.no_grad(): if use_pipeline: - _, _, eval_loss = engine.execute_schedule(valid_data_iter, - forward_only=True, - return_output_label=False) + _, _, eval_loss = engine.execute_schedule( + valid_data_iter, forward_only=True, return_output_label=False + ) else: - tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( - validloader) + ( + tokens, + types, + sentence_order, + loss_mask, + lm_labels, + padding_mask, + ) = get_batch_for_sequence_parallel(validloader) lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) @@ -195,18 +207,22 @@ def main(): timer_string = [] for n, t in timer: timer_string.append(f"{n}: {t.get_history_mean()*1000:.5f}") - timer_string = ' | '.join(timer_string) - lr = list(engine.optimizer.param_groups)[0]['lr'] + timer_string = " | ".join(timer_string) + lr = list(engine.optimizer.param_groups)[0]["lr"] loss_scale = engine.optimizer.optim.loss_scale.item() if gpc.is_initialized(ParallelMode.PIPELINE): ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]] else: ranks = [0] - logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' + - f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale}' + - f"| Learning rate: {lr} | " + timer_string, - ranks=ranks) + logger.info( + f"Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} " + + f"| Eval Loss: {accumulated_eval_loss.item():.5g} " + + f"| Loss Scale: {loss_scale}" + + f"| Learning rate: {lr} | " + + timer_string, + ranks=ranks, + ) for n, t in timer: t.reset() @@ -214,5 +230,5 @@ def main(): accumulated_train_loss.zero_() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/op_builder/__init__.py b/op_builder/__init__.py index 5ae7223b8c692d5570d74a1ef72c0fe1ad61c28a..808559ec9c2d1c7014b51b3ee0982c0b46046565 100644 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -7,17 +7,26 @@ from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder ALL_OPS = { - 'cpu_adam': CPUAdamBuilder, - 'fused_optim': FusedOptimBuilder, - 'moe': MOEBuilder, - 'multi_head_attn': MultiHeadAttnBuilder, - 'scaled_masked_softmax': ScaledMaskedSoftmaxBuilder, - 'scaled_upper_triangle_masked_softmax': ScaledUpperTrainglemaskedSoftmaxBuilder, - 'layernorm': LayerNormBuilder, + "cpu_adam": CPUAdamBuilder, + "fused_optim": FusedOptimBuilder, + "moe": MOEBuilder, + "multi_head_attn": MultiHeadAttnBuilder, + "scaled_masked_softmax": ScaledMaskedSoftmaxBuilder, + "scaled_upper_triangle_masked_softmax": ScaledUpperTrainglemaskedSoftmaxBuilder, + "layernorm": LayerNormBuilder, } __all__ = [ - 'ALL_OPS', 'CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledMaskedSoftmaxBuilder', - 'ScaledUpperTrainglemaskedSoftmaxBuilder', 'MOEBuilder', 'MultiTensorSGDBuilder', 'MultiTensorAdamBuilder', - 'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder' + "ALL_OPS", + "CPUAdamBuilder", + "FusedOptimBuilder", + "MultiHeadAttnBuilder", + "ScaledMaskedSoftmaxBuilder", + "ScaledUpperTrainglemaskedSoftmaxBuilder", + "MOEBuilder", + "MultiTensorSGDBuilder", + "MultiTensorAdamBuilder", + "MultiTensorLambBuilder", + "MultiTensorScaleBuilder", + "MultiTensorL2NormBuilder", ] diff --git a/op_builder/builder.py b/op_builder/builder.py index 8396235e5cfe89c6a4edccc4c7eb33ceefbbe11b..75823ef105c7b87f786e2f90d7b7cf0d60942d2d 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -24,13 +24,14 @@ class Builder(ABC): def __init__(self, name: str, prebuilt_import_path: str): self.name = name self.prebuilt_import_path = prebuilt_import_path - self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] # we store the op as an attribute to avoid repeated building and loading self.cached_op_module = None - assert prebuilt_import_path.startswith('colossalai._C'), \ - f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}' + assert prebuilt_import_path.startswith( + "colossalai._C" + ), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}" def relative_to_abs_path(self, code_path: str) -> str: """ @@ -46,10 +47,10 @@ class Builder(ABC): # this symlink will be replaced with actual files if we install via pypi # thus we cannot tell the colossalai root directory by checking whether the op_builder # is a symlink, we can only tell whether it is inside or outside colossalai - if str(op_builder_module_path).endswith('colossalai/kernel/op_builder'): + if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"): root_path = op_builder_module_path.parent.parent else: - root_path = op_builder_module_path.parent.joinpath('colossalai') + root_path = op_builder_module_path.parent.joinpath("colossalai") code_abs_path = root_path.joinpath(code_path) return str(code_abs_path) @@ -59,13 +60,14 @@ class Builder(ABC): return include path inside the cuda home. """ from torch.utils.cpp_extension import CUDA_HOME + if CUDA_HOME is None: raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") cuda_include = os.path.join(CUDA_HOME, "include") return cuda_include def csrc_abs_path(self, path): - return os.path.join(self.relative_to_abs_path('kernel/cuda_native/csrc'), path) + return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path) # functions must be overrided begin @abstractmethod @@ -80,27 +82,24 @@ class Builder(ABC): """ This function should return a list of include files for extensions. """ - pass @abstractmethod def cxx_flags(self) -> List[str]: """ This function should return a list of cxx compilation flags for extensions. """ - pass @abstractmethod def nvcc_flags(self) -> List[str]: """ This function should return a list of nvcc compilation flags for extensions. """ - pass # functions must be overrided over def strip_empty_entries(self, args): - ''' + """ Drop any empty strings from the list of compile and link flags - ''' + """ return [x for x in args if len(x) > 0] def import_op(self): @@ -114,8 +113,8 @@ class Builder(ABC): Check whether the system environment is ready for extension compilation. """ try: - import torch from torch.utils.cpp_extension import CUDA_HOME + TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False @@ -123,7 +122,8 @@ class Builder(ABC): if not TORCH_AVAILABLE: raise ModuleNotFoundError( - "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions") + "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions" + ) if CUDA_HOME is None: raise RuntimeError( @@ -150,7 +150,7 @@ class Builder(ABC): verbose (bool, optional): show detailed info. Defaults to True. """ if verbose is None: - verbose = os.environ.get('CAI_KERNEL_VERBOSE', '0') == '1' + verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1" # if the kernel has be compiled and cached, we directly use it if self.cached_op_module is not None: return self.cached_op_module @@ -161,7 +161,8 @@ class Builder(ABC): op_module = self.import_op() if verbose: print_rank_0( - f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building.") + f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building." + ) except ImportError: # check environment self.check_runtime_build_environment() @@ -172,10 +173,11 @@ class Builder(ABC): # construct the build directory import torch from torch.utils.cpp_extension import load - torch_version_major = torch.__version__.split('.')[0] - torch_version_minor = torch.__version__.split('.')[1] + + torch_version_major = torch.__version__.split(".")[0] + torch_version_minor = torch.__version__.split(".")[1] torch_cuda_version = torch.version.cuda - home_directory = os.path.expanduser('~') + home_directory = os.path.expanduser("~") extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" build_directory = os.path.join(home_directory, extension_directory) Path(build_directory).mkdir(parents=True, exist_ok=True) @@ -184,14 +186,16 @@ class Builder(ABC): print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now") # load the kernel - op_module = load(name=self.name, - sources=self.strip_empty_entries(self.sources_files()), - extra_include_paths=self.strip_empty_entries(self.include_dirs()), - extra_cflags=self.cxx_flags(), - extra_cuda_cflags=self.nvcc_flags(), - extra_ldflags=[], - build_directory=build_directory, - verbose=verbose) + op_module = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=build_directory, + verbose=verbose, + ) build_duration = time.time() - start_build @@ -204,16 +208,18 @@ class Builder(ABC): return op_module - def builder(self) -> 'CUDAExtension': + def builder(self) -> "CUDAExtension": """ get a CUDAExtension instance used for setup.py """ from torch.utils.cpp_extension import CUDAExtension - return CUDAExtension(name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args={ - 'cxx': self.strip_empty_entries(self.cxx_flags()), - 'nvcc': self.strip_empty_entries(self.nvcc_flags()) - }) + return CUDAExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + "cxx": self.strip_empty_entries(self.cxx_flags()), + "nvcc": self.strip_empty_entries(self.nvcc_flags()), + }, + ) diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py index 500e2cc0eddc55c0ecb5b40be5e979d912adb0c2..5a2a2e3e6a566a552c6bbe9ba054db4b3fdb93d7 100644 --- a/op_builder/cpu_adam.py +++ b/op_builder/cpu_adam.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads @@ -10,29 +8,29 @@ class CPUAdamBuilder(Builder): def __init__(self): super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] # necessary 4 functions def sources_files(self): ret = [ - self.csrc_abs_path('cpu_adam.cpp'), + self.csrc_abs_path("cpu_adam.cpp"), ] return ret def include_dirs(self): - return [ - self.csrc_abs_path("includes"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("includes"), self.get_cuda_home_include()] def cxx_flags(self): - extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] - return ['-O3'] + self.version_dependent_macros + extra_cxx_flags + extra_cxx_flags = ["-std=c++14", "-lcudart", "-lcublas", "-g", "-Wno-reorder", "-fopenmp", "-march=native"] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags def nvcc_flags(self): extra_cuda_flags = [ - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", ] - ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/fused_optim.py b/op_builder/fused_optim.py index 31ddfced1db24f18fb37cb549dc27060ae2fcff6..3baa0880d801f69abe5c0b2cf829eb7687e8bbca 100644 --- a/op_builder/fused_optim.py +++ b/op_builder/fused_optim.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import get_cuda_cc_flag @@ -10,25 +8,30 @@ class FusedOptimBuilder(Builder): def __init__(self): super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH) - + def sources_files(self): ret = [ - self.csrc_abs_path(fname) for fname in [ - 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu', - 'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu' + self.csrc_abs_path(fname) + for fname in [ + "colossal_C_frontend.cpp", + "multi_tensor_sgd_kernel.cu", + "multi_tensor_scale_kernel.cu", + "multi_tensor_adam.cu", + "multi_tensor_l2norm_kernel.cu", + "multi_tensor_lamb.cu", ] ] return ret def include_dirs(self): - ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] return ret def cxx_flags(self): - version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] - return ['-O3'] + version_dependent_macros + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros def nvcc_flags(self): - extra_cuda_flags = ['-lineinfo'] + extra_cuda_flags = ["-lineinfo"] extra_cuda_flags.extend(get_cuda_cc_flag()) - return ['-O3', '--use_fast_math'] + extra_cuda_flags + return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/op_builder/gptq.py b/op_builder/gptq.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4f445de0674e08532d4c67b7645930007cfffc --- /dev/null +++ b/op_builder/gptq.py @@ -0,0 +1,56 @@ +import re + +import torch + +from .builder import Builder +from .utils import append_nvcc_threads + + +class GPTQBuilder(Builder): + NAME = "cu_gptq" + PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq" + + def __init__(self): + super().__init__(name=GPTQBuilder.NAME, prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH) + + def include_dirs(self): + ret = [self.csrc_abs_path("gptq"), self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "gptq/linear_gptq.cpp", + "gptq/column_remap.cu", + "gptq/cuda_buffers.cu", + "gptq/q4_matmul.cu", + "gptq/q4_matrix.cu", + ] + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + "-v", + "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + "-lcublas", + "-std=c++17", + ] + + for arch in torch.cuda.get_arch_list(): + res = re.search(r"sm_(\d+)", arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 80: + extra_cuda_flags.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) + + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/layernorm.py b/op_builder/layernorm.py index 61d9417419293c4cc3d4835406f7b7756ffda845..2684c6ddb7f7e0424cab54666c8647693b85cb58 100644 --- a/op_builder/layernorm.py +++ b/op_builder/layernorm.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag @@ -12,18 +10,18 @@ class LayerNormBuilder(Builder): super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH) def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu']] + ret = [self.csrc_abs_path(fname) for fname in ["layer_norm_cuda.cpp", "layer_norm_cuda_kernel.cu"]] return ret def include_dirs(self): - ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): - extra_cuda_flags = ['-maxrregcount=50'] + extra_cuda_flags = ["-maxrregcount=50"] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + self.version_dependent_macros + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros return append_nvcc_threads(ret) diff --git a/op_builder/moe.py b/op_builder/moe.py index eeb7d8e3980c095d297fb985fbd961a3f852602a..6f8028b1720cc00bc4716ed8eed678703b4effb4 100644 --- a/op_builder/moe.py +++ b/op_builder/moe.py @@ -1,11 +1,8 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag class MOEBuilder(Builder): - NAME = "moe" PREBUILT_IMPORT_PATH = "colossalai._C.moe" @@ -13,24 +10,23 @@ class MOEBuilder(Builder): super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH) def include_dirs(self): - ret = [ - self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] return ret def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']] + ret = [self.csrc_abs_path(fname) for fname in ["moe_cuda.cpp", "moe_cuda_kernel.cu"]] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', - '--expt-extended-lambda' + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py index f9103fe947297441d708dbbb42a5ce748ac5c3d7..b70f041db7d6a51e932cac9403db1f8c2f847c26 100644 --- a/op_builder/multi_head_attn.py +++ b/op_builder/multi_head_attn.py @@ -1,18 +1,13 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag class MultiHeadAttnBuilder(Builder): - NAME = "multihead_attention" PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention" def __init__(self): - super().__init__(name=MultiHeadAttnBuilder.NAME, - prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) - + super().__init__(name=MultiHeadAttnBuilder.NAME, prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) def include_dirs(self): ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] @@ -20,22 +15,31 @@ class MultiHeadAttnBuilder(Builder): def sources_files(self): ret = [ - self.csrc_abs_path(fname) for fname in [ - 'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu', - 'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu', - 'kernels/general_kernels.cu', 'kernels/cuda_util.cu' + self.csrc_abs_path(fname) + for fname in [ + "multihead_attention_1d.cpp", + "kernels/cublas_wrappers.cu", + "kernels/transform_kernels.cu", + "kernels/dropout_kernels.cu", + "kernels/normalize_kernels.cu", + "kernels/softmax_kernels.cu", + "kernels/general_kernels.cu", + "kernels/cuda_util.cu", ] ] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/scaled_masked_softmax.py b/op_builder/scaled_masked_softmax.py index 11cfda39a85c799578653c909cf79539fc24cb48..b2f1de7792c8e44ab3fb74a67bc8c38b4feabbab 100644 --- a/op_builder/scaled_masked_softmax.py +++ b/op_builder/scaled_masked_softmax.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads @@ -9,29 +7,28 @@ class ScaledMaskedSoftmaxBuilder(Builder): PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax" def __init__(self): - super().__init__(name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + super().__init__( + name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH + ) # necessary 4 functions def sources_files(self): - ret = [ - self.csrc_abs_path(fname) for fname in - ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'] - ] + ret = [self.csrc_abs_path(fname) for fname in ["scaled_masked_softmax.cpp", "scaled_masked_softmax_cuda.cu"]] return ret def include_dirs(self): - return [ - self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", ] - ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/scaled_upper_triangle_masked_softmax.py b/op_builder/scaled_upper_triangle_masked_softmax.py index d0d2433aa64527abb662381b08b42aabb9d61807..1445230acbc128424f0fa90e25aa6e08ad5eb8fc 100644 --- a/op_builder/scaled_upper_triangle_masked_softmax.py +++ b/op_builder/scaled_upper_triangle_masked_softmax.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag @@ -9,29 +7,31 @@ class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder): PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax" def __init__(self): - super().__init__(name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + super().__init__( + name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, + prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH, + ) def include_dirs(self): - return [ - self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] def sources_files(self): ret = [ self.csrc_abs_path(fname) - for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'] + for fname in ["scaled_upper_triang_masked_softmax.cpp", "scaled_upper_triang_masked_softmax_cuda.cu"] ] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', - '--expt-extended-lambda' + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/utils.py b/op_builder/utils.py index 1b1bd5f499707b685903771d14d466023f0d66b7..3f75f952d57bd94e68dbd8e3e0b4c6bd3c2edc58 100644 --- a/op_builder/utils.py +++ b/op_builder/utils.py @@ -11,6 +11,7 @@ def print_rank_0(message: str) -> None: """ try: import torch.distributed as dist + if not dist.is_initialized(): is_main_rank = True else: @@ -36,7 +37,8 @@ def get_cuda_version_in_pytorch() -> List[int]: torch_cuda_minor = torch.version.cuda.split(".")[1] except: raise ValueError( - "[extension] Cannot retrive the CUDA version in the PyTorch binary given by torch.version.cuda") + "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda" + ) return torch_cuda_major, torch_cuda_minor @@ -50,7 +52,7 @@ def get_cuda_bare_metal_version(cuda_dir) -> List[int]: Returns: The CUDA version required by PyTorch, in the form of tuple (major, minor). """ - nvcc_path = os.path.join(cuda_dir, 'bin/nvcc') + nvcc_path = os.path.join(cuda_dir, "bin/nvcc") if cuda_dir is None: raise ValueError( @@ -85,9 +87,9 @@ def check_system_pytorch_cuda_match(cuda_dir): if bare_metal_major != torch_cuda_major: raise Exception( - f'[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) ' - f'mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor}).' - 'Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ .' + f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) " + f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})." + "Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ." ) if bare_metal_minor != torch_cuda_minor: @@ -107,10 +109,11 @@ def get_pytorch_version() -> List[int]: A tuple of integers in the form of (major, minor, patch). """ import torch - torch_version = torch.__version__.split('+')[0] - TORCH_MAJOR = int(torch_version.split('.')[0]) - TORCH_MINOR = int(torch_version.split('.')[1]) - TORCH_PATCH = int(torch_version.split('.')[2]) + + torch_version = torch.__version__.split("+")[0] + TORCH_MAJOR = int(torch_version.split(".")[0]) + TORCH_MINOR = int(torch_version.split(".")[1]) + TORCH_PATCH = int(torch_version.split(".")[2], 16) return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH @@ -132,7 +135,8 @@ def check_pytorch_version(min_major_version, min_minor_version) -> bool: if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version): raise RuntimeError( f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n" - "The latest stable release can be obtained from https://pytorch.org/get-started/locally/") + "The latest stable release can be obtained from https://pytorch.org/get-started/locally/" + ) def check_cuda_availability(): @@ -143,6 +147,7 @@ def check_cuda_availability(): A boolean value. True if CUDA is available and False otherwise. """ import torch + return torch.cuda.is_available() @@ -155,29 +160,31 @@ def set_cuda_arch_list(cuda_dir): # we only need to set this when CUDA is not available for cross-compilation if not cuda_available: - warnings.warn('\n[extension] PyTorch did not find available GPUs on this system.\n' - 'If your intention is to cross-compile, this is not an error.\n' - 'By default, Colossal-AI will cross-compile for \n' - '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n' - '2. Volta (compute capability 7.0)\n' - '3. Turing (compute capability 7.5),\n' - '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n' - '\nIf you wish to cross-compile for a single specific architecture,\n' - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') + warnings.warn( + "\n[extension] PyTorch did not find available GPUs on this system.\n" + "If your intention is to cross-compile, this is not an error.\n" + "By default, Colossal-AI will cross-compile for \n" + "1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "2. Volta (compute capability 7.0)\n" + "3. Turing (compute capability 7.5),\n" + "4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n" + "\nIf you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n' + ) if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) - arch_list = ['6.0', '6.1', '6.2', '7.0', '7.5'] + arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"] if int(bare_metal_major) == 11: if int(bare_metal_minor) == 0: - arch_list.append('8.0') + arch_list.append("8.0") else: - arch_list.append('8.0') - arch_list.append('8.6') + arch_list.append("8.0") + arch_list.append("8.6") - arch_list_str = ';'.join(arch_list) + arch_list_str = ";".join(arch_list) os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str return False return True @@ -197,12 +204,13 @@ def get_cuda_cc_flag() -> List[str]: import torch cc_flag = [] + max_arch = "".join(str(i) for i in torch.cuda.get_device_capability()) for arch in torch.cuda.get_arch_list(): - res = re.search(r'sm_(\d+)', arch) + res = re.search(r"sm_(\d+)", arch) if res: arch_cap = res[1] - if int(arch_cap) >= 60: - cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch): + cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) return cc_flag diff --git a/pytest.ini b/pytest.ini index ac31ace4bfae025025b1098719aba873db615d1c..38ad7d76de506c569a13b7889af8d4e131f58193 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,5 @@ [pytest] markers = - cpu: tests which can run on CPU - gpu: tests which requires a single GPU - dist: tests which are run in a multi-GPU or multi-machine environment - experiment: tests for experimental features \ No newline at end of file + dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs) + largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs) +addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 82b6173b351791654e71034c4b445b04babf480b..467f83610eb0daf9cf5922ce168c23e0262238d3 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,15 +1,21 @@ diffusers fbgemm-gpu==0.2.0 pytest -pytest-cov +coverage==7.2.3 +git+https://github.com/hpcaitech/pytest-testmon torchvision -transformers +transformers==4.33.0 timm titans torchaudio +torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package is updated every day. We fix the version to a specific date to avoid breaking changes. torchrec==0.2.0 contexttimer einops triton==2.0.0.dev20221202 -git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 +SentencePiece +ninja +flash_attn==2.0.5 +datasets +#auto-gptq now not support torch1.12 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index b34dc2e223ae7c8144c8e0b95eadf6646aaca4e9..9aa5f2822e40197012eff1aa9965e29c77689947 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,5 +8,6 @@ click fabric contexttimer ninja -torch>=1.11 +torch>=1.12 safetensors +einops diff --git a/setup.py b/setup.py index 5d8f831218d95a3d1cad775083f2652da8a19c15..cda1ba7ee7a63f672265538bfba6d6ee6476b5bd 100644 --- a/setup.py +++ b/setup.py @@ -15,8 +15,8 @@ from op_builder.utils import ( ) try: - import torch from torch.utils.cpp_extension import CUDA_HOME, BuildExtension + TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False @@ -26,14 +26,14 @@ except ImportError: MIN_PYTORCH_VERSION_MAJOR = 1 MIN_PYTORCH_VERSION_MINOR = 10 THIS_DIR = os.path.dirname(os.path.abspath(__file__)) -BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1 -IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1 +BUILD_CUDA_EXT = int(os.environ.get("CUDA_EXT", "0")) == 1 +IS_NIGHTLY = int(os.environ.get("NIGHTLY", "0")) == 1 # a variable to store the op builder ext_modules = [] # we do not support windows currently -if sys.platform == 'win32': +if sys.platform == "win32": raise RuntimeError("Windows is not supported yet. Please try again within the Windows Subsystem for Linux (WSL).") @@ -64,7 +64,7 @@ def fetch_requirements(path) -> List[str]: Returns: The lines in the requirements file. """ - with open(path, 'r') as fd: + with open(path, "r") as fd: return [r.strip() for r in fd.readlines()] @@ -75,7 +75,7 @@ def fetch_readme() -> str: Returns: The lines in the README file. """ - with open('README.md', encoding='utf-8') as f: + with open("README.md", encoding="utf-8") as f: return f.read() @@ -89,21 +89,21 @@ def get_version() -> str: setup_file_path = os.path.abspath(__file__) project_path = os.path.dirname(setup_file_path) - version_txt_path = os.path.join(project_path, 'version.txt') - version_py_path = os.path.join(project_path, 'colossalai/version.py') + version_txt_path = os.path.join(project_path, "version.txt") + version_py_path = os.path.join(project_path, "colossalai/version.py") with open(version_txt_path) as f: version = f.read().strip() # write version into version.py - with open(version_py_path, 'w') as f: + with open(version_py_path, "w") as f: f.write(f"__version__ = '{version}'\n") # look for pytorch and cuda version if BUILD_CUDA_EXT: torch_major, torch_minor, _ = get_pytorch_version() - torch_version = f'{torch_major}.{torch_minor}' - cuda_version = '.'.join(get_cuda_bare_metal_version(CUDA_HOME)) + torch_version = f"{torch_major}.{torch_minor}" + cuda_version = ".".join(get_cuda_bare_metal_version(CUDA_HOME)) else: torch_version = None cuda_version = None @@ -112,12 +112,12 @@ def get_version() -> str: if torch_version: f.write(f'torch = "{torch_version}"\n') else: - f.write('torch = None\n') + f.write("torch = None\n") if cuda_version: f.write(f'cuda = "{cuda_version}"\n') else: - f.write('cuda = None\n') + f.write("cuda = None\n") return version @@ -127,6 +127,7 @@ if BUILD_CUDA_EXT: set_cuda_arch_list(CUDA_HOME) from op_builder import ALL_OPS + op_names = [] # load all builders @@ -135,7 +136,7 @@ if BUILD_CUDA_EXT: ext_modules.append(builder_cls().builder()) # show log - op_name_list = ', '.join(op_names) + op_name_list = ", ".join(op_names) print(f"[extension] loaded builders for {op_name_list}") # always put not nightly branch as the if branch @@ -143,56 +144,62 @@ if BUILD_CUDA_EXT: # and it will mess up with the dependency graph insights if not IS_NIGHTLY: version = get_version() - package_name = 'colossalai' + package_name = "colossalai" else: # use date as the nightly version - version = datetime.today().strftime('%Y.%m.%d') - package_name = 'colossalai-nightly' - -setup(name=package_name, - version=version, - packages=find_packages(exclude=( - 'op_builder', - 'benchmark', - 'docker', - 'tests', - 'docs', - 'examples', - 'tests', - 'scripts', - 'requirements', - '*.egg-info', - )), - description='An integrated large-scale model training system with efficient parallelization techniques', - long_description=fetch_readme(), - long_description_content_type='text/markdown', - license='Apache Software License 2.0', - url='https://www.colossalai.org', - project_urls={ - 'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions', - 'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues', - 'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples', - 'Documentation': 'http://colossalai.readthedocs.io', - 'Github': 'https://github.com/hpcaitech/ColossalAI', - }, - ext_modules=ext_modules, - cmdclass={'build_ext': BuildExtension} if ext_modules else {}, - install_requires=fetch_requirements('requirements/requirements.txt'), - entry_points=''' + version = datetime.today().strftime("%Y.%m.%d") + package_name = "colossalai-nightly" + +setup( + name=package_name, + version=version, + packages=find_packages( + exclude=( + "op_builder", + "benchmark", + "docker", + "tests", + "docs", + "examples", + "tests", + "scripts", + "requirements", + "*.egg-info", + ) + ), + description="An integrated large-scale model training system with efficient parallelization techniques", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://www.colossalai.org", + project_urls={ + "Forum": "https://github.com/hpcaitech/ColossalAI/discussions", + "Bug Tracker": "https://github.com/hpcaitech/ColossalAI/issues", + "Examples": "https://github.com/hpcaitech/ColossalAI-Examples", + "Documentation": "http://colossalai.readthedocs.io", + "Github": "https://github.com/hpcaitech/ColossalAI", + }, + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension} if ext_modules else {}, + install_requires=fetch_requirements("requirements/requirements.txt"), + entry_points=""" [console_scripts] colossalai=colossalai.cli:cli - ''', - python_requires='>=3.6', - classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', - 'Environment :: GPU :: NVIDIA CUDA', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: System :: Distributed Computing', - ], - package_data={ - 'colossalai': [ - '_C/*.pyi', 'kernel/cuda_native/csrc/*', 'kernel/cuda_native/csrc/kernel/*', - 'kernel/cuda_native/csrc/kernels/include/*' - ] - }) + """, + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", + ], + package_data={ + "colossalai": [ + "_C/*.pyi", + "kernel/cuda_native/csrc/*", + "kernel/cuda_native/csrc/kernel/*", + "kernel/cuda_native/csrc/kernels/include/*", + ] + }, +) diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py index f29efefce4a461bb71ea2f65d10a57a8b1439320..65eaa72d6e849cd2154de77e4ba1e9558e76d098 100644 --- a/tests/components_to_test/__init__.py +++ b/tests/components_to_test/__init__.py @@ -11,9 +11,19 @@ from . import ( ) from .utils import run_fwd, run_fwd_bwd -from . import albert # isort:skip +from . import albert # isort:skip __all__ = [ - 'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet', - 'simple_net', 'run_fwd_bwd', 'albert', 'beit', 'run_fwd' + "bert", + "gpt2", + "hanging_param_model", + "inline_op_model", + "nested_model", + "repeated_computed_layers", + "resnet", + "simple_net", + "run_fwd_bwd", + "albert", + "beit", + "run_fwd", ] diff --git a/tests/components_to_test/albert.py b/tests/components_to_test/albert.py index d5b6bc89a83e0ae49a4045800b6e9d57ff848604..0ba4d19655cd7013e363ce868ae6a36f59615759 100644 --- a/tests/components_to_test/albert.py +++ b/tests/components_to_test/albert.py @@ -1,13 +1,11 @@ import torch -import transformers -from packaging import version from transformers import AlbertConfig, AlbertForSequenceClassification from .bert import get_bert_data_loader from .registry import non_distributed_component_funcs -@non_distributed_component_funcs.register(name='albert') +@non_distributed_component_funcs.register(name="albert") def get_training_components(): hidden_dim = 8 num_head = 4 @@ -16,20 +14,21 @@ def get_training_components(): vocab_size = 32 def bert_model_builder(checkpoint: bool = False): - config = AlbertConfig(vocab_size=vocab_size, - gradient_checkpointing=checkpoint, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - hidden_dropout_prob=0., - attention_probs_dropout_prob=0.) - print('building AlbertForSequenceClassification model') - - # adapting huggingface BertForSequenceClassification for single unitest calling interface - class ModelAaptor(AlbertForSequenceClassification): - + config = AlbertConfig( + vocab_size=vocab_size, + gradient_checkpointing=checkpoint, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + ) + print("building AlbertForSequenceClassification model") + + # adapting huggingface BertForSequenceClassification for single unittest calling interface + class ModelAdaptor(AlbertForSequenceClassification): def forward(self, input_ids, labels): """ inputs: data, label @@ -37,23 +36,27 @@ def get_training_components(): """ return super().forward(input_ids=input_ids, labels=labels)[0] - model = ModelAaptor(config) + model = ModelAdaptor(config) # if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): # model.gradient_checkpointing_enable() return model - is_distrbuted = torch.distributed.is_initialized() - trainloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distrbuted=is_distrbuted) - testloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distrbuted=is_distrbuted) + is_distributed = torch.distributed.is_initialized() + trainloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) + testloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) criterion = None return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/beit.py b/tests/components_to_test/beit.py index 1252071f40759149606267b3d18b495d1a3f490d..d33474ea9a6bfdf6a843a1c3ea69a6d0196dcff4 100644 --- a/tests/components_to_test/beit.py +++ b/tests/components_to_test/beit.py @@ -14,29 +14,31 @@ class DummyDataLoader(DummyDataGenerator): batch_size = 4 def generate(self): - data = torch.randn((DummyDataLoader.batch_size, DummyDataLoader.num_channel, DummyDataLoader.img_size, - DummyDataLoader.img_size), - device=get_current_device()) - label = torch.randint(low=0, - high=DummyDataLoader.num_class, - size=(DummyDataLoader.batch_size,), - device=get_current_device()) + data = torch.randn( + ( + DummyDataLoader.batch_size, + DummyDataLoader.num_channel, + DummyDataLoader.img_size, + DummyDataLoader.img_size, + ), + device=get_current_device(), + ) + label = torch.randint( + low=0, high=DummyDataLoader.num_class, size=(DummyDataLoader.batch_size,), device=get_current_device() + ) return data, label -@non_distributed_component_funcs.register(name='beit') +@non_distributed_component_funcs.register(name="beit") def get_training_components(): - - def model_buider(checkpoint=False): - model = Beit(img_size=DummyDataLoader.img_size, - num_classes=DummyDataLoader.num_class, - embed_dim=32, - depth=2, - num_heads=4) + def model_builder(checkpoint=False): + model = Beit( + img_size=DummyDataLoader.img_size, num_classes=DummyDataLoader.num_class, embed_dim=32, depth=2, num_heads=4 + ) return model trainloader = DummyDataLoader() testloader = DummyDataLoader() criterion = torch.nn.CrossEntropyLoss() - return model_buider, trainloader, testloader, torch.optim.Adam, criterion + return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py index c1faa6f9d892650c4717d7baa359165822a8165b..f0061ad18c843b72e8faf5a424a9b27b66357bf7 100644 --- a/tests/components_to_test/bert.py +++ b/tests/components_to_test/bert.py @@ -8,12 +8,12 @@ from .registry import non_distributed_component_funcs def get_bert_data_loader( - n_class, - batch_size, - total_samples, - sequence_length, - device=torch.device('cpu:0'), - is_distrbuted=False, + n_class, + batch_size, + total_samples, + sequence_length, + device=torch.device("cpu:0"), + is_distributed=False, ): train_data = torch.randint( low=0, @@ -24,7 +24,7 @@ def get_bert_data_loader( ) train_label = torch.randint(low=0, high=2, size=(total_samples,), device=device, dtype=torch.long) train_dataset = torch.utils.data.TensorDataset(train_data, train_label) - if is_distrbuted: + if is_distributed: sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) else: sampler = SequentialSampler(train_dataset) @@ -32,7 +32,7 @@ def get_bert_data_loader( return train_loader -@non_distributed_component_funcs.register(name='bert') +@non_distributed_component_funcs.register(name="bert") def get_training_components(): hidden_dim = 8 num_head = 4 @@ -41,20 +41,21 @@ def get_training_components(): vocab_size = 32 def bert_model_builder(checkpoint: bool = False): - config = BertConfig(vocab_size=vocab_size, - gradient_checkpointing=checkpoint, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - hidden_dropout_prob=0., - attention_probs_dropout_prob=0.) - print('building BertForSequenceClassification model') - - # adapting huggingface BertForSequenceClassification for single unitest calling interface - class ModelAaptor(BertForSequenceClassification): + config = BertConfig( + vocab_size=vocab_size, + gradient_checkpointing=checkpoint, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + ) + print("building BertForSequenceClassification model") + # adapting huggingface BertForSequenceClassification for single unittest calling interface + class ModelAdaptor(BertForSequenceClassification): def forward(self, input_ids, labels): """ inputs: data, label @@ -62,23 +63,27 @@ def get_training_components(): """ return super().forward(input_ids=input_ids, labels=labels)[0] - model = ModelAaptor(config) + model = ModelAdaptor(config) if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): model.gradient_checkpointing_enable() return model - is_distrbuted = torch.distributed.is_initialized() - trainloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distrbuted=is_distrbuted) - testloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distrbuted=is_distrbuted) + is_distributed = torch.distributed.is_initialized() + trainloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) + testloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) criterion = None return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/gpt2.py b/tests/components_to_test/gpt2.py index fe25b4923fa27493be7f006956ce642be6047503..7f826497d2abe9a5ffb93fbdc38021c0c0d58590 100644 --- a/tests/components_to_test/gpt2.py +++ b/tests/components_to_test/gpt2.py @@ -14,33 +14,40 @@ class DummyDataLoader(DummyDataGenerator): seq_len = 64 def generate(self): - input_ids = torch.randint(0, - DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len), - device=get_current_device()) + input_ids = torch.randint( + 0, + DummyDataLoader.vocab_size, + (DummyDataLoader.batch_size, DummyDataLoader.seq_len), + device=get_current_device(), + ) return input_ids, input_ids class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50304, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50304, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size, - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + ) + ) if checkpoint: self.model.gradient_checkpointing_enable() @@ -51,12 +58,9 @@ class GPTLMModel(nn.Module): def gpt2_micro(checkpoint=True): - return GPTLMModel(checkpoint=checkpoint, - hidden_size=32, - num_layers=2, - num_attention_heads=4, - max_seq_len=64, - vocab_size=128) + return GPTLMModel( + checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128 + ) def gpt2_s(checkpoint=True): @@ -68,7 +72,6 @@ def gpt2_m(checkpoint=True): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -80,9 +83,8 @@ class GPTLMLoss(nn.Module): return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) -@non_distributed_component_funcs.register(name='gpt2') +@non_distributed_component_funcs.register(name="gpt2") def get_training_components(): - trainloader = DummyDataLoader() testloader = DummyDataLoader() diff --git a/tests/components_to_test/hanging_param_model.py b/tests/components_to_test/hanging_param_model.py index 329a08ea28f0224226c6061c2894982f3cd1d397..5531c8d081a0ecbc2b31f501ab274c96d763c360 100644 --- a/tests/components_to_test/hanging_param_model.py +++ b/tests/components_to_test/hanging_param_model.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from colossalai.nn import CheckpointModule +from colossalai.legacy.nn import CheckpointModule from .registry import non_distributed_component_funcs from .utils.dummy_data_generator import DummyDataGenerator @@ -28,16 +28,14 @@ class HangingParamModule(CheckpointModule): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 4) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='hanging_param_model') +@non_distributed_component_funcs.register(name="hanging_param_model") def get_training_components(): - def model_builder(checkpoint=False): return HangingParamModule(checkpoint) @@ -46,4 +44,5 @@ def get_training_components(): criterion = torch.nn.CrossEntropyLoss() from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/inline_op_model.py b/tests/components_to_test/inline_op_model.py index f061d48f92c6eb7a2095fed039f2d2e49c1e07d3..8bfa9cf343539de613b2d69f8063d027db66c604 100644 --- a/tests/components_to_test/inline_op_model.py +++ b/tests/components_to_test/inline_op_model.py @@ -1,8 +1,7 @@ import torch import torch.nn as nn -import torch.nn.functional as F -from colossalai.nn import CheckpointModule +from colossalai.legacy.nn import CheckpointModule from .registry import non_distributed_component_funcs from .utils.dummy_data_generator import DummyDataGenerator @@ -19,7 +18,6 @@ class InlineOpModule(CheckpointModule): self.proj2 = nn.Linear(8, 8) def forward(self, x): - x = self.proj1(x) # inline add_ x.add_(10) @@ -31,16 +29,14 @@ class InlineOpModule(CheckpointModule): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 4) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='inline_op_model') +@non_distributed_component_funcs.register(name="inline_op_model") def get_training_components(): - def model_builder(checkpoint=False): return InlineOpModule(checkpoint) @@ -49,4 +45,5 @@ def get_training_components(): criterion = torch.nn.CrossEntropyLoss() from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/nested_model.py b/tests/components_to_test/nested_model.py index 339084639244ef4f434218ff77673b46b9deb5c1..44577456dec570f7b6714b7eedfbbe6c7438cf80 100644 --- a/tests/components_to_test/nested_model.py +++ b/tests/components_to_test/nested_model.py @@ -2,14 +2,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from colossalai.nn import CheckpointModule +from colossalai.legacy.nn import CheckpointModule from .registry import non_distributed_component_funcs from .utils import DummyDataGenerator class SubNet(nn.Module): - def __init__(self, out_features) -> None: super().__init__() self.bias = nn.Parameter(torch.zeros(out_features)) @@ -19,7 +18,6 @@ class SubNet(nn.Module): class NestedNet(CheckpointModule): - def __init__(self, checkpoint=False) -> None: super().__init__(checkpoint) self.fc1 = nn.Linear(5, 5) @@ -35,16 +33,14 @@ class NestedNet(CheckpointModule): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 5) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='nested_model') +@non_distributed_component_funcs.register(name="nested_model") def get_training_components(): - def model_builder(checkpoint=False): return NestedNet(checkpoint) diff --git a/tests/components_to_test/registry.py b/tests/components_to_test/registry.py index 728ed9eba6ea176b70f152751f877781a5beb214..ec561b7831ad9bf84bbcdd2ce04e214856903bb4 100644 --- a/tests/components_to_test/registry.py +++ b/tests/components_to_test/registry.py @@ -2,17 +2,16 @@ class Registry: - def __init__(self): self._registry = dict() def register(self, name): assert name not in self._registry - def _regsiter(callable_): + def _register(callable_): self._registry[name] = callable_ - return _regsiter + return _register def get_callable(self, name: str): return self._registry[name] @@ -34,6 +33,6 @@ class Registry: non_distributed_component_funcs = Registry() -model_paralle_component_funcs = Registry() +model_parallel_component_funcs = Registry() -__all__ = ['non_distributed_component_funcs', 'model_paralle_component_funcs'] +__all__ = ["non_distributed_component_funcs", "model_parallel_component_funcs"] diff --git a/tests/components_to_test/repeated_computed_layers.py b/tests/components_to_test/repeated_computed_layers.py index b3f84bd0e203eb3fd2d23ae627048f422cfc8a51..3da64de3fb64cb0dcfd7e5d6278fa008b0fca17e 100644 --- a/tests/components_to_test/repeated_computed_layers.py +++ b/tests/components_to_test/repeated_computed_layers.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from colossalai.nn import CheckpointModule +from colossalai.legacy.nn import CheckpointModule from .registry import non_distributed_component_funcs from .utils.dummy_data_generator import DummyDataGenerator @@ -29,16 +29,14 @@ class NetWithRepeatedlyComputedLayers(CheckpointModule): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 5) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='repeated_computed_layers') +@non_distributed_component_funcs.register(name="repeated_computed_layers") def get_training_components(): - def model_builder(checkpoint=False): return NetWithRepeatedlyComputedLayers(checkpoint) diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py index 193832ebc12da0c3f6d18a8bb21fa9a67a2ac02c..a43becc162337294874c125920d1e9ca8640a26e 100644 --- a/tests/components_to_test/resnet.py +++ b/tests/components_to_test/resnet.py @@ -1,28 +1,32 @@ -from torchvision.models import resnet18 -from .registry import non_distributed_component_funcs -from pathlib import Path import os +from pathlib import Path + import torch -from torchvision.transforms import transforms from torchvision.datasets import CIFAR10 -from colossalai.utils import get_dataloader +from torchvision.models import resnet18 +from torchvision.transforms import transforms + +from colossalai.legacy.utils import get_dataloader + +from .registry import non_distributed_component_funcs def get_cifar10_dataloader(train): # build dataloaders - dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - train=train, - transform=transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])) + dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + train=train, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))] + ), + ) dataloader = get_dataloader(dataset=dataset, shuffle=True, batch_size=16, drop_last=True) return dataloader -@non_distributed_component_funcs.register(name='resnet18') +@non_distributed_component_funcs.register(name="resnet18") def get_resnet_training_components(): - def model_builder(checkpoint=False): return resnet18(num_classes=10) diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py index cd9d7ebc0b1a1a699e427905269e8e8a9b0936bf..0f0ac5cff49a201f58e10084b098f376d49014e5 100644 --- a/tests/components_to_test/simple_net.py +++ b/tests/components_to_test/simple_net.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from colossalai.nn import CheckpointModule +from colossalai.legacy.nn import CheckpointModule from colossalai.utils.cuda import get_current_device from .registry import non_distributed_component_funcs @@ -33,16 +33,14 @@ class SimpleNet(CheckpointModule): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.randint(low=0, high=20, size=(16,), device=get_current_device()) label = torch.randint(low=0, high=2, size=(16,), device=get_current_device()) return data, label -@non_distributed_component_funcs.register(name='simple_net') +@non_distributed_component_funcs.register(name="simple_net") def get_training_components(): - def model_builder(checkpoint=False): return SimpleNet(checkpoint) @@ -51,4 +49,5 @@ def get_training_components(): criterion = torch.nn.CrossEntropyLoss() from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/utils/dummy_data_generator.py b/tests/components_to_test/utils/dummy_data_generator.py index 5ab33e86de230da778d0e2bc1b8d1e8e581c1f79..7b3af46c8f3530d0e0395eddd6a0c6312ce0523b 100644 --- a/tests/components_to_test/utils/dummy_data_generator.py +++ b/tests/components_to_test/utils/dummy_data_generator.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod class DummyDataGenerator(ABC): - def __init__(self, length=10): self.length = length diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index 466a2a55882955cf855ab8c9a07991f9a6e833e4..c08fd365d871e659f0b22c2a9b547ffb51764772 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -1,4 +1,4 @@ from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers from .registry import model_zoo -__all__ = ['model_zoo'] +__all__ = ["model_zoo"] diff --git a/tests/kit/model_zoo/diffusers/diffusers.py b/tests/kit/model_zoo/diffusers/diffusers.py index 204c1d7773ca9e87bdbbb49d5cf0bb72b572315f..895ee7967f6b57ee99cea553c47e2401c1280ba1 100644 --- a/tests/kit/model_zoo/diffusers/diffusers.py +++ b/tests/kit/model_zoo/diffusers/diffusers.py @@ -4,7 +4,7 @@ import diffusers import torch import transformers -from ..registry import ModelAttribute, model_zoo +from ..registry import model_zoo BATCH_SIZE = 2 SEQ_LENGTH = 5 @@ -26,10 +26,9 @@ def data_clip_model(): attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32) - return dict(input_ids=input_ids, - pixel_values=pixel_values, - attention_mask=attention_mask, - position_ids=position_ids) + return dict( + input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids + ) def data_clip_text(): @@ -43,32 +42,41 @@ def data_clip_vision(): return dict(pixel_values=pixel_values) -model_zoo.register(name='diffusers_auto_encoder_kl', - model_fn=diffusers.AutoencoderKL, - data_gen_fn=data_vae_fn, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_vq_model', - model_fn=diffusers.VQModel, - data_gen_fn=data_vae_fn, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_clip_model', - model_fn=partial(transformers.CLIPModel, config=transformers.CLIPConfig()), - data_gen_fn=data_clip_model, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_clip_text_model', - model_fn=partial(transformers.CLIPTextModel, config=transformers.CLIPTextConfig()), - data_gen_fn=data_clip_text, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_clip_vision_model', - model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()), - data_gen_fn=data_clip_vision, - output_transform_fn=clip_vision_model_output) - -model_zoo.register(name='diffusers_unet2d_model', - model_fn=diffusers.UNet2DModel, - data_gen_fn=data_unet_fn, - output_transform_fn=identity_output) +model_zoo.register( + name="diffusers_auto_encoder_kl", + model_fn=diffusers.AutoencoderKL, + data_gen_fn=data_vae_fn, + output_transform_fn=identity_output, +) + +model_zoo.register( + name="diffusers_vq_model", model_fn=diffusers.VQModel, data_gen_fn=data_vae_fn, output_transform_fn=identity_output +) + +model_zoo.register( + name="diffusers_clip_model", + model_fn=partial(transformers.CLIPModel, config=transformers.CLIPConfig()), + data_gen_fn=data_clip_model, + output_transform_fn=identity_output, +) + +model_zoo.register( + name="diffusers_clip_text_model", + model_fn=partial(transformers.CLIPTextModel, config=transformers.CLIPTextConfig()), + data_gen_fn=data_clip_text, + output_transform_fn=identity_output, +) + +model_zoo.register( + name="diffusers_clip_vision_model", + model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()), + data_gen_fn=data_clip_vision, + output_transform_fn=clip_vision_model_output, +) + +model_zoo.register( + name="diffusers_unet2d_model", + model_fn=diffusers.UNet2DModel, + data_gen_fn=data_unet_fn, + output_transform_fn=identity_output, +) diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 7470327a65b62e4d126e5724a43ad2748a02962d..b909722918708e3fdc74f2cf108987b85f82b269 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Callable -__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo'] +__all__ = ["ModelZooRegistry", "ModelAttribute", "model_zoo"] @dataclass @@ -14,6 +14,7 @@ class ModelAttribute: has_control_flow (bool): Whether the model contains branching in its forward method. has_stochastic_depth_prob (bool): Whether the model contains stochastic depth probability. Often seen in the torchvision models. """ + has_control_flow: bool = False has_stochastic_depth_prob: bool = False @@ -23,32 +24,42 @@ class ModelZooRegistry(dict): A registry to map model names to model and data generation functions. """ - def register(self, - name: str, - model_fn: Callable, - data_gen_fn: Callable, - output_transform_fn: Callable, - model_attribute: ModelAttribute = None): + def register( + self, + name: str, + model_fn: Callable, + data_gen_fn: Callable, + output_transform_fn: Callable, + loss_fn: Callable = None, + model_attribute: ModelAttribute = None, + ): """ Register a model and data generation function. Examples: - >>> # Register - >>> model_zoo = ModelZooRegistry() - >>> model_zoo.register('resnet18', resnet18, resnet18_data_gen) - >>> # Run the model - >>> data = resnresnet18_data_gen() # do not input any argument - >>> model = resnet18() # do not input any argument - >>> out = model(**data) + + ```python + # normal forward workflow + model = resnet18() + data = resnet18_data_gen() + output = model(**data) + transformed_output = output_transform_fn(output) + loss = loss_fn(transformed_output) + + # Register + model_zoo = ModelZooRegistry() + model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn) + ``` Args: name (str): Name of the model. - model_fn (callable): A function that returns a model. **It must not contain any arguments.** - output_transform_fn (callable): A function that transforms the output of the model into Dict. - data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.** + model_fn (Callable): A function that returns a model. **It must not contain any arguments.** + data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.** + output_transform_fn (Callable): A function that transforms the output of the model into Dict. + loss_fn (Callable): a function to compute the loss from the given output. Defaults to None model_attribute (ModelAttribute): Attributes of the model. Defaults to None. """ - self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute) + self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute) def get_sub_registry(self, keyword: str): """ @@ -62,6 +73,8 @@ class ModelZooRegistry(dict): for k, v in self.items(): if keyword in k: new_dict[k] = v + + assert len(new_dict) > 0, f"No model found with keyword {keyword}" return new_dict diff --git a/tests/kit/model_zoo/timm/timm.py b/tests/kit/model_zoo/timm/timm.py index b29ac12a6b534f554f6917a78dc38a0b4f533abc..eb6d2f6bc7579b42703b01c5febf2bd8b047380d 100644 --- a/tests/kit/model_zoo/timm/timm.py +++ b/tests/kit/model_zoo/timm/timm.py @@ -9,151 +9,183 @@ from ..registry import ModelAttribute, model_zoo data_gen_fn = lambda: dict(x=torch.rand(2, 3, 224, 224)) output_transform_fn = lambda x: dict(output=x) -model_zoo.register(name='timm_resnet', - model_fn=tm.resnest.resnest50d, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_beit', - model_fn=tm.beit.beit_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_cait', - model_fn=tm.cait.cait_s24_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_convmixer', - model_fn=tm.convmixer.convmixer_768_32, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_efficientnetv2', - model_fn=tm.efficientnet.efficientnetv2_m, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_resmlp', - model_fn=tm.resmlp_12_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_vision_transformer', - model_fn=tm.vision_transformer.vit_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_deit', - model_fn=tm.deit_base_distilled_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_beitv2', - model_fn=tm.beitv2_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_coat', - model_fn=tm.coat.coat_lite_mini, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="timm_resnet", model_fn=tm.resnest.resnest50d, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_beit", + model_fn=tm.beit.beit_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_cait", model_fn=tm.cait.cait_s24_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_convmixer", + model_fn=tm.convmixer.convmixer_768_32, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_efficientnetv2", + model_fn=tm.efficientnet.efficientnetv2_m, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_resmlp", model_fn=tm.resmlp_12_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_vision_transformer", + model_fn=tm.vision_transformer.vit_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_deit", + model_fn=tm.deit_base_distilled_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_beitv2", + model_fn=tm.beitv2_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_coat", model_fn=tm.coat.coat_lite_mini, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) -model_zoo.register(name='timm_deit3', - model_fn=tm.deit3_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="timm_deit3", + model_fn=tm.deit3_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) -model_zoo.register(name='timm_eca_nfnet', - model_fn=tm.eca_nfnet_l0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_efficientformer', - model_fn=tm.efficientformer_l1, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_ese_vovnet19b_dw', - model_fn=tm.ese_vovnet19b_dw, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_gmixer_12_224', - model_fn=tm.gmixer_12_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_gmlp_b16_224', - model_fn=tm.gmlp_b16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_hardcorenas_a', - model_fn=tm.hardcorenas_a, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_hrnet_w18_small', - model_fn=tm.hrnet_w18_small, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_inception_v3', - model_fn=tm.inception_v3, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_mixer_b16_224', - model_fn=tm.mixer_b16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_nf_ecaresnet101', - model_fn=tm.nf_ecaresnet101, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_nf_regnet_b0', - model_fn=tm.nf_regnet_b0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_regnetv_040', - model_fn=tm.regnetv_040, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_skresnet18', - model_fn=tm.skresnet18, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_tnt_b_patch16_224', - model_fn=tm.tnt_b_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_wide_resnet50_2', - model_fn=tm.wide_resnet50_2, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_convit', - model_fn=tm.convit_base, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_dm_nfnet', - model_fn=tm.dm_nfnet_f0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="timm_eca_nfnet", model_fn=tm.eca_nfnet_l0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_efficientformer", + model_fn=tm.efficientformer_l1, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_ese_vovnet19b_dw", + model_fn=tm.ese_vovnet19b_dw, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_gmixer_12_224", + model_fn=tm.gmixer_12_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_gmlp_b16_224", model_fn=tm.gmlp_b16_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_hardcorenas_a", + model_fn=tm.hardcorenas_a, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_hrnet_w18_small", + model_fn=tm.hrnet_w18_small, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_inception_v3", model_fn=tm.inception_v3, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_mixer_b16_224", + model_fn=tm.mixer_b16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_nf_ecaresnet101", + model_fn=tm.nf_ecaresnet101, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_nf_regnet_b0", model_fn=tm.nf_regnet_b0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_regnetv_040", model_fn=tm.regnetv_040, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_skresnet18", model_fn=tm.skresnet18, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_tnt_b_patch16_224", + model_fn=tm.tnt_b_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_wide_resnet50_2", + model_fn=tm.wide_resnet50_2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_convit", model_fn=tm.convit_base, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_dm_nfnet", model_fn=tm.dm_nfnet_f0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) # ============== # Register models with control flow # ============== -model_zoo.register(name='timm_convnext', - model_fn=tm.convnext.convnext_base, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_vgg', - model_fn=tm.vgg.vgg11, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_dpn', - model_fn=tm.dpn.dpn68, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_densenet', - model_fn=tm.densenet.densenet121, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_rexnet', - model_fn=tm.rexnet.rexnet_100, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_swin_transformer', - model_fn=tm.swin_transformer.swin_base_patch4_window7_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="timm_convnext", + model_fn=tm.convnext.convnext_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_vgg", + model_fn=tm.vgg.vgg11, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_dpn", + model_fn=tm.dpn.dpn68, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_densenet", + model_fn=tm.densenet.densenet121, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_rexnet", + model_fn=tm.rexnet.rexnet_100, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_swin_transformer", + model_fn=tm.swin_transformer.swin_base_patch4_window7_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/torchaudio/torchaudio.py b/tests/kit/model_zoo/torchaudio/torchaudio.py index 9a244ac312c0bfa0479844c887dfdee6bf4f942a..03f565c045539d2263f143d94cd32a0bad418992 100644 --- a/tests/kit/model_zoo/torchaudio/torchaudio.py +++ b/tests/kit/model_zoo/torchaudio/torchaudio.py @@ -23,24 +23,31 @@ def conformer_data_gen_fn(): transformer_output_transform_fn = lambda outputs: dict(frames=outputs[0], lengths=outputs[1]) -model_zoo.register(name='torchaudio_conformer', - model_fn=lambda: tm.Conformer( - input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31), - data_gen_fn=conformer_data_gen_fn, - output_transform_fn=transformer_output_transform_fn) +model_zoo.register( + name="torchaudio_conformer", + model_fn=lambda: tm.Conformer( + input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31 + ), + data_gen_fn=conformer_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, +) single_output_transform_fn = lambda output: dict(output=output) -model_zoo.register(name='torchaudio_convtasnet', - model_fn=tm.ConvTasNet, - data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)), - output_transform_fn=single_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_convtasnet", + model_fn=tm.ConvTasNet, + data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)), + output_transform_fn=single_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='torchaudio_deepspeech', - model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4), - data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)), - output_transform_fn=single_output_transform_fn) +model_zoo.register( + name="torchaudio_deepspeech", + model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4), + data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)), + output_transform_fn=single_output_transform_fn, +) def emformer_data_gen_fn(): @@ -50,21 +57,26 @@ def emformer_data_gen_fn(): model_zoo.register( - name='torchaudio_emformer', + name="torchaudio_emformer", model_fn=lambda: tm.Emformer(input_dim=IN_FEATURES, num_heads=4, ffn_dim=128, num_layers=4, segment_length=4), data_gen_fn=emformer_data_gen_fn, output_transform_fn=transformer_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='torchaudio_wav2letter_waveform', - model_fn=lambda: tm.Wav2Letter(input_type='waveform', num_features=40), - data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), - output_transform_fn=single_output_transform_fn) +model_zoo.register( + name="torchaudio_wav2letter_waveform", + model_fn=lambda: tm.Wav2Letter(input_type="waveform", num_features=40), + data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), + output_transform_fn=single_output_transform_fn, +) -model_zoo.register(name='torchaudio_wav2letter_mfcc', - model_fn=lambda: tm.Wav2Letter(input_type='mfcc', num_features=40), - data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), - output_transform_fn=single_output_transform_fn) +model_zoo.register( + name="torchaudio_wav2letter_mfcc", + model_fn=lambda: tm.Wav2Letter(input_type="mfcc", num_features=40), + data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), + output_transform_fn=single_output_transform_fn, +) def wavernn_data_gen_fn(): @@ -73,20 +85,24 @@ def wavernn_data_gen_fn(): return dict(waveform=waveform, specgram=specgram) -model_zoo.register(name='torchaudio_wavernn', - model_fn=lambda: tm.WaveRNN(upsample_scales=[2, 2, 5], - n_classes=N_CLASSES, - hop_length=HOP_LENGTH, - kernel_size=KERNEL_SIZE, - n_freq=N_FREQ, - n_res_block=2, - n_rnn=64, - n_fc=64, - n_hidden=16, - n_output=16), - data_gen_fn=wavernn_data_gen_fn, - output_transform_fn=single_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_wavernn", + model_fn=lambda: tm.WaveRNN( + upsample_scales=[2, 2, 5], + n_classes=N_CLASSES, + hop_length=HOP_LENGTH, + kernel_size=KERNEL_SIZE, + n_freq=N_FREQ, + n_res_block=2, + n_rnn=64, + n_fc=64, + n_hidden=16, + n_output=16, + ), + data_gen_fn=wavernn_data_gen_fn, + output_transform_fn=single_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) def tacotron_data_gen_fn(): @@ -97,17 +113,18 @@ def tacotron_data_gen_fn(): token_lengths = max_text_length * torch.ones((n_batch,)) mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length) mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,)) - return dict(tokens=tokens, - token_lengths=token_lengths, - mel_specgram=mel_specgram, - mel_specgram_lengths=mel_specgram_lengths) + return dict( + tokens=tokens, token_lengths=token_lengths, mel_specgram=mel_specgram, mel_specgram_lengths=mel_specgram_lengths + ) -model_zoo.register(name='torchaudio_tacotron', - model_fn=lambda: tm.Tacotron2(n_mels=N_MELS), - data_gen_fn=tacotron_data_gen_fn, - output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)), - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_tacotron", + model_fn=lambda: tm.Tacotron2(n_mels=N_MELS), + data_gen_fn=tacotron_data_gen_fn, + output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)), + model_attribute=ModelAttribute(has_control_flow=True), +) def wav2vec_data_gen_fn(): @@ -117,14 +134,18 @@ def wav2vec_data_gen_fn(): return dict(waveforms=waveforms, lengths=lengths) -model_zoo.register(name='torchaudio_wav2vec2_base', - model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0), - data_gen_fn=wav2vec_data_gen_fn, - output_transform_fn=transformer_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_wav2vec2_base", + model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0), + data_gen_fn=wav2vec_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='torchaudio_hubert_base', - model_fn=tm.hubert_base, - data_gen_fn=wav2vec_data_gen_fn, - output_transform_fn=transformer_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_hubert_base", + model_fn=tm.hubert_base, + data_gen_fn=wav2vec_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/torchrec/torchrec.py b/tests/kit/model_zoo/torchrec/torchrec.py index dda563155fcac3c996867b244fd5215a2cd65e22..dce66c3d3509dbb8f3671c28b7b8fc9112fbf780 100644 --- a/tests/kit/model_zoo/torchrec/torchrec.py +++ b/tests/kit/model_zoo/torchrec/torchrec.py @@ -1,4 +1,3 @@ -from collections import namedtuple from functools import partial import torch @@ -7,7 +6,7 @@ from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor -from ..registry import ModelAttribute, model_zoo +from ..registry import model_zoo BATCH = 2 SHAPE = 10 @@ -20,9 +19,9 @@ def gen_kt(): # KeyedJaggedTensor def gen_kjt(): - KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"], - values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), - offsets=torch.tensor([0, 2, 4, 6, 8])) + KJT = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), offsets=torch.tensor([0, 2, 4, 6, 8]) + ) return KJT @@ -54,21 +53,11 @@ def output_transform_fn(x): return dict(output=x) -def output_transform_fn(x): - if isinstance(x, KeyedTensor): - output = dict() - for key in x.keys(): - output[key] = x[key] - return output - else: - return dict(output=x) - - def get_ebc(): # EmbeddingBagCollection eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]) eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"]) - return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device('cpu')) + return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device("cpu")) def sparse_arch_model_fn(): @@ -91,52 +80,69 @@ def dlrm_sparsearch_model_fn(): return dlrm.SparseArch(ebc) -model_zoo.register(name='deepfm_densearch', - model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_interactionarch', - model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE), - data_gen_fn=interaction_arch_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_overarch', - model_fn=partial(deepfm.OverArch, SHAPE), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_simpledeepfmnn', - model_fn=simple_deep_fmnn_model_fn, - data_gen_fn=simple_dfm_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_sparsearch', - model_fn=sparse_arch_model_fn, - data_gen_fn=sparse_arch_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm', - model_fn=dlrm_model_fn, - data_gen_fn=simple_dfm_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_densearch', - model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_interactionarch', - model_fn=partial(dlrm.InteractionArch, 2), - data_gen_fn=interaction_arch_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_overarch', - model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_sparsearch', - model_fn=dlrm_sparsearch_model_fn, - data_gen_fn=sparse_arch_data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="deepfm_densearch", + model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_interactionarch", + model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE), + data_gen_fn=interaction_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_overarch", + model_fn=partial(deepfm.OverArch, SHAPE), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_simpledeepfmnn", + model_fn=simple_deep_fmnn_model_fn, + data_gen_fn=simple_dfm_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_sparsearch", + model_fn=sparse_arch_model_fn, + data_gen_fn=sparse_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm", model_fn=dlrm_model_fn, data_gen_fn=simple_dfm_data_gen_fn, output_transform_fn=output_transform_fn +) + +model_zoo.register( + name="dlrm_densearch", + model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm_interactionarch", + model_fn=partial(dlrm.InteractionArch, 2), + data_gen_fn=interaction_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm_overarch", + model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm_sparsearch", + model_fn=dlrm_sparsearch_model_fn, + data_gen_fn=sparse_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) diff --git a/tests/kit/model_zoo/torchvision/torchvision.py b/tests/kit/model_zoo/torchvision/torchvision.py index ddc3ec24b2ff5d486df1054ca58fcc9c6095097a..57b633e9d6765f2ed50035b30b6ab535056323ab 100644 --- a/tests/kit/model_zoo/torchvision/torchvision.py +++ b/tests/kit/model_zoo/torchvision/torchvision.py @@ -1,5 +1,3 @@ -from collections import namedtuple - import torch import torchvision import torchvision.models as tm @@ -29,103 +27,133 @@ def swin_s(): depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=[7, 7], - stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic + stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic weights=weights, progress=progress, ) # special output transform fn -google_net_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs - ) else dict(output=x) -swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val - for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) -inception_v3_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs - ) else dict(output=x) +google_net_output_transform_fn = ( + lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x) +) +swin_s_output_output_transform_fn = ( + lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) +) +inception_v3_output_transform_fn = ( + lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x) +) -model_zoo.register(name='torchvision_alexnet', - model_fn=tm.alexnet, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_densenet121', - model_fn=tm.densenet121, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_efficientnet_b0', - model_fn=tm.efficientnet_b0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) -model_zoo.register(name='torchvision_googlenet', - model_fn=tm.googlenet, - data_gen_fn=data_gen_fn, - output_transform_fn=google_net_output_transform_fn) -model_zoo.register(name='torchvision_inception_v3', - model_fn=tm.inception_v3, - data_gen_fn=inception_v3_data_gen_fn, - output_transform_fn=inception_v3_output_transform_fn) -model_zoo.register(name='torchvision_mobilenet_v2', - model_fn=tm.mobilenet_v2, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_mobilenet_v3_small', - model_fn=tm.mobilenet_v3_small, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_mnasnet0_5', - model_fn=tm.mnasnet0_5, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_resnet18', - model_fn=tm.resnet18, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_regnet_x_16gf', - model_fn=tm.regnet_x_16gf, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_resnext50_32x4d', - model_fn=tm.resnext50_32x4d, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_shufflenet_v2_x0_5', - model_fn=tm.shufflenet_v2_x0_5, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_squeezenet1_0', - model_fn=tm.squeezenet1_0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="torchvision_alexnet", model_fn=tm.alexnet, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="torchvision_densenet121", + model_fn=tm.densenet121, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_efficientnet_b0", + model_fn=tm.efficientnet_b0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True), +) +model_zoo.register( + name="torchvision_googlenet", + model_fn=tm.googlenet, + data_gen_fn=data_gen_fn, + output_transform_fn=google_net_output_transform_fn, +) +model_zoo.register( + name="torchvision_inception_v3", + model_fn=tm.inception_v3, + data_gen_fn=inception_v3_data_gen_fn, + output_transform_fn=inception_v3_output_transform_fn, +) +model_zoo.register( + name="torchvision_mobilenet_v2", + model_fn=tm.mobilenet_v2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_mobilenet_v3_small", + model_fn=tm.mobilenet_v3_small, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_mnasnet0_5", + model_fn=tm.mnasnet0_5, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_resnet18", model_fn=tm.resnet18, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="torchvision_regnet_x_16gf", + model_fn=tm.regnet_x_16gf, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_resnext50_32x4d", + model_fn=tm.resnext50_32x4d, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_shufflenet_v2_x0_5", + model_fn=tm.shufflenet_v2_x0_5, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_squeezenet1_0", + model_fn=tm.squeezenet1_0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) -model_zoo.register(name='torchvision_vgg11', - model_fn=tm.vgg11, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_wide_resnet50_2', - model_fn=tm.wide_resnet50_2, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="torchvision_vgg11", model_fn=tm.vgg11, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="torchvision_wide_resnet50_2", + model_fn=tm.wide_resnet50_2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) -if version.parse(torchvision.__version__) >= version.parse('0.12.0'): - model_zoo.register(name='torchvision_vit_b_16', - model_fn=tm.vit_b_16, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - model_zoo.register(name='torchvision_convnext_base', - model_fn=tm.convnext_base, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) +if version.parse(torchvision.__version__) >= version.parse("0.12.0"): + model_zoo.register( + name="torchvision_vit_b_16", + model_fn=tm.vit_b_16, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + ) + model_zoo.register( + name="torchvision_convnext_base", + model_fn=tm.convnext_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True), + ) -if version.parse(torchvision.__version__) >= version.parse('0.13.0'): +if version.parse(torchvision.__version__) >= version.parse("0.13.0"): model_zoo.register( - name='torchvision_swin_s', + name="torchvision_swin_s", model_fn=swin_s, data_gen_fn=data_gen_fn, output_transform_fn=swin_s_output_output_transform_fn, ) - model_zoo.register(name='torchvision_efficientnet_v2_s', - model_fn=tm.efficientnet_v2_s, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) + model_zoo.register( + name="torchvision_efficientnet_v2_s", + model_fn=tm.efficientnet_v2_s, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True), + ) diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index f56ff7ad84eb1b1360afc0fd47d817756c6df35b..2a492361b13b1beac548d41e10b9afaf04e3b108 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -1,5 +1,12 @@ from .albert import * from .bert import * +from .blip2 import * +from .bloom import * +from .chatglm2 import * from .gpt import * +from .llama import * from .opt import * +from .sam import * from .t5 import * +from .vit import * +from .whisper import * diff --git a/tests/kit/model_zoo/transformers/albert.py b/tests/kit/model_zoo/transformers/albert.py index e85f564e376a53669374dc21eeceb88a69ee691c..d1c23703b3e495f0f1108faa35ddeb3ac5d41c2f 100644 --- a/tests/kit/model_zoo/transformers/albert.py +++ b/tests/kit/model_zoo/transformers/albert.py @@ -17,39 +17,54 @@ def data_gen_fn(): return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) +def data_gen_for_pretrain(): + inputs = data_gen_fn() + inputs["labels"] = inputs["input_ids"].clone() + inputs["sentence_order_label"] = torch.zeros(BATCH_SIZE, dtype=torch.int64) + return inputs + + output_transform_fn = lambda x: x -config = transformers.AlbertConfig(embedding_size=128, - hidden_size=128, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=256) - -model_zoo.register(name='transformers_albert', - model_fn=lambda: transformers.AlbertModel(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_pretraining', - model_fn=lambda: transformers.AlbertForPreTraining(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_masked_lm', - model_fn=lambda: transformers.AlbertForMaskedLM(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_sequence_classification', - model_fn=lambda: transformers.AlbertForSequenceClassification(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_token_classification', - model_fn=lambda: transformers.AlbertForTokenClassification(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +config = transformers.AlbertConfig( + embedding_size=128, hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256 +) + +model_zoo.register( + name="transformers_albert", + model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_pretraining", + model_fn=lambda: transformers.AlbertForPreTraining(config), + data_gen_fn=data_gen_for_pretrain, + output_transform_fn=lambda x: dict(loss=x.loss), + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_masked_lm", + model_fn=lambda: transformers.AlbertForMaskedLM(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_sequence_classification", + model_fn=lambda: transformers.AlbertForSequenceClassification(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_token_classification", + model_fn=lambda: transformers.AlbertForTokenClassification(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) # =============================== # Register multi-sentence ALBERT @@ -73,13 +88,17 @@ def data_gen_for_mcq(): return encoding -model_zoo.register(name='transformers_albert_for_question_answering', - model_fn=lambda: transformers.AlbertForQuestionAnswering(config), - data_gen_fn=data_gen_for_qa, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_multiple_choice', - model_fn=lambda: transformers.AlbertForMultipleChoice(config), - data_gen_fn=data_gen_for_mcq, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_albert_for_question_answering", + model_fn=lambda: transformers.AlbertForQuestionAnswering(config), + data_gen_fn=data_gen_for_qa, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_multiple_choice", + model_fn=lambda: transformers.AlbertForMultipleChoice(config), + data_gen_fn=data_gen_for_mcq, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 99135704da70f96237bdcb6553ebd847ea8fcba1..8b90a3c7372c33855b29b8a2b19ca6ab9e1b143c 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -6,83 +6,442 @@ from ..registry import ModelAttribute, model_zoo # =============================== # Register single-sentence BERT # =============================== -BATCH_SIZE = 2 -SEQ_LENGTH = 16 -def data_gen_fn(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from transformers import BertTokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + # token_type_ids = tokenized_input['token_type_ids'] + input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64) + token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.int64) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) -output_transform_fn = lambda x: x +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data -config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256) -# register the BERT variants -model_zoo.register(name='transformers_bert', - model_fn=lambda: transformers.BertModel(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_pretraining', - model_fn=lambda: transformers.BertForPreTraining(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_lm_head_model', - model_fn=lambda: transformers.BertLMHeadModel(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_masked_lm', - model_fn=lambda: transformers.BertForMaskedLM(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_sequence_classification', - model_fn=lambda: transformers.BertForSequenceClassification(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_token_classification', - model_fn=lambda: transformers.BertForTokenClassification(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +def data_gen_for_pretraining(): + # pretraining data gen + # `next_sentence_label` is the label for next sentence prediction, 0 or 1 + data = data_gen_for_lm() + data["next_sentence_label"] = torch.tensor([1], dtype=torch.int64) + return data -# =============================== -# Register multi-sentence BERT -# =============================== -def data_gen_for_next_sentence(): - tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") - prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - next_sentence = "The sky is blue due to the shorter wavelength of blue light." - encoding = tokenizer(prompt, next_sentence, return_tensors="pt") - return encoding +def data_gen_for_sequence_classification(): + # sequence classification data gen + # `labels` is the label for sequence classification, 0 or 1 + data = data_gen() + data["labels"] = torch.tensor([1], dtype=torch.int64) + return data + + +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data def data_gen_for_mcq(): - tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") - prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - choice0 = "It is eaten with a fork and a knife." - choice1 = "It is eaten while held in the hand." - encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) - encoding = {k: v.unsqueeze(0) for k, v in encoding.items()} - return encoding - - -# register the following models -model_zoo.register(name='transformers_bert_for_next_sentence', - model_fn=lambda: transformers.BertForNextSentencePrediction(config), - data_gen_fn=data_gen_for_next_sentence, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_mcq', - model_fn=lambda: transformers.BertForMultipleChoice(config), - data_gen_fn=data_gen_for_mcq, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) + # multiple choice question data gen + # Generated from following code snippet + # + # tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + # prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + # choice0 = "It is eaten with a fork and a knife." + # choice1 = "It is eaten while held in the hand." + # data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) + # data = {k: v.unsqueeze(0) for k, v in encoding.items()} + # data['labels'] = torch.tensor([0], dtype=torch.int64) + input_ids = torch.tensor( + [ + [ + [ + 101, + 1999, + 3304, + 1010, + 10733, + 2366, + 1999, + 5337, + 10906, + 1010, + 2107, + 2004, + 2012, + 1037, + 4825, + 1010, + 2003, + 3591, + 4895, + 14540, + 6610, + 2094, + 1012, + 102, + 2009, + 2003, + 8828, + 2007, + 1037, + 9292, + 1998, + 1037, + 5442, + 1012, + 102, + 102, + 5442, + 1012, + 102, + 102, + ], + [ + 101, + 1999, + 3304, + 1010, + 10733, + 2366, + 1999, + 5337, + 10906, + 1010, + 2107, + 2004, + 2012, + 1037, + 4825, + 1010, + 2003, + 3591, + 4895, + 14540, + 6610, + 2094, + 1012, + 102, + 2009, + 2003, + 8828, + 2096, + 2218, + 1999, + 1996, + 2192, + 1012, + 102, + 0, + 0, + 1012, + 102, + 0, + 0, + ], + ] + ] + ) + token_type_ids = torch.tensor( + [ + [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 1, + 1, + 0, + 0, + ], + ] + ] + ) + attention_mask = torch.tensor( + [ + [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 1, + 1, + 0, + 0, + ], + ] + ] + ) + labels = torch.tensor([0], dtype=torch.int64) + + return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) + + +def data_gen_for_qa(): + # generating data for question answering + # no need for labels and use start and end position instead + data = data_gen() + start_positions = torch.tensor([0], dtype=torch.int64) + data["start_positions"] = start_positions + end_positions = torch.tensor([1], dtype=torch.int64) + data["end_positions"] = end_positions + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss funciton + +loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn = lambda x: x.loss + +config = transformers.BertConfig( + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256, + hidden_dropout_prob=0, + attention_probs_dropout_prob=0, +) + +# register the BERT variants +model_zoo.register( + name="transformers_bert", + model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_bert_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_pretraining", + model_fn=lambda: transformers.BertForPreTraining(config), + data_gen_fn=data_gen_for_pretraining, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_lm_head_model", + model_fn=lambda: transformers.BertLMHeadModel(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_masked_lm", + model_fn=lambda: transformers.BertForMaskedLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_sequence_classification", + model_fn=lambda: transformers.BertForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_token_classification", + model_fn=lambda: transformers.BertForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_next_sentence", + model_fn=lambda: transformers.BertForNextSentencePrediction(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_mcq", + model_fn=lambda: transformers.BertForMultipleChoice(config), + data_gen_fn=data_gen_for_mcq, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_question_answering", + model_fn=lambda: transformers.BertForQuestionAnswering(config), + data_gen_fn=data_gen_for_qa, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/blip2.py b/tests/kit/model_zoo/transformers/blip2.py new file mode 100644 index 0000000000000000000000000000000000000000..887b11c7f54e1d51914380a23d249de3fe4440b4 --- /dev/null +++ b/tests/kit/model_zoo/transformers/blip2.py @@ -0,0 +1,66 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-image SAM +# =============================== + + +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from PIL import Image + # import requests + # from transformers import Blip2Processor, Blip2Model + # import torch + + # processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + # url = "http://images.cocodataset.org/val2017/000000039769.jpg" + # image = Image.open(requests.get(url, stream=True).raw) + + # prompt = "Question: how many cats are there? Answer:" + # inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16) + + pixel_values = torch.rand(1, 3, 224, 224, dtype=torch.float32) + input_ids = torch.tensor([[2, 45641, 35, 141, 171, 10017, 32, 89, 116, 31652, 35]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + labels = torch.tensor([[34, 56]], dtype=torch.int64) + return dict(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss funciton +loss_fn_blip2_model = lambda x: x.loss + +config = transformers.Blip2Config() +config.vision_config.patch_size = 14 +config.text_config.num_hidden_layers = 1 +config.qformer_config.num_hidden_layers = 1 +config.vision_config.num_hidden_layers = 1 +config.qformer_config.attention_probs_dropout_prob = 0 +config.qformer_config.hidden_dropout_prob = 0 +config.text_config.dropout = 0 + +# register the blip2 variants +model_zoo.register( + name="transformers_blip2", + model_fn=lambda: transformers.Blip2Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_blip2_model, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_blip2_conditional_gerneration", + model_fn=lambda: transformers.Blip2ForConditionalGeneration(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_blip2_model, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py new file mode 100644 index 0000000000000000000000000000000000000000..12dcd71d5d1b24da62e731e36fa0fd4b4e28882f --- /dev/null +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -0,0 +1,122 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register Bloom +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import BloomTokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595, 632, 207595]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + + +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([0], dtype=torch.int64) + return data + + +def data_gen_for_question_answering(): + # obtained with the following code + # + # from transformers import AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + # question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + # inputs = tokenizer(question, text, return_tensors="pt") + + input_ids = torch.tensor( + [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], + dtype=torch.int64, + ) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + start_positions = torch.tensor([1], dtype=torch.int64) + end_positions = torch.tensor([10], dtype=torch.int64) + return dict( + input_ids=input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions + ) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn_for_causal_lm = lambda x: x.loss +loss_fn_for_classification = lambda x: x.loss +loss_fn_for_question_answering = lambda x: x.loss + +config = transformers.BloomConfig( + n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256 +) + +# register the following models +model_zoo.register( + name="transformers_bloom", + model_fn=lambda: transformers.BloomModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_bloom_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bloom_for_causal_lm", + model_fn=lambda: transformers.BloomForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_causal_lm, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bloom_for_sequence_classification", + model_fn=lambda: transformers.BloomForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bloom_for_token_classification", + model_fn=lambda: transformers.BloomForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bloom_for_question_answering", + model_fn=lambda: transformers.BloomForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_question_answering, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py new file mode 100644 index 0000000000000000000000000000000000000000..f4369cb7d1715913dfcfb91190c9c7693496f22a --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -0,0 +1,79 @@ +import torch + +from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + +from ..registry import ModelAttribute, model_zoo + +# ================================ +# Register single-sentence ChatGLM +# ================================ + + +def data_gen(): + input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075, 632, 2075]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_conditional_generation(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + labels = data["input_ids"].clone() + data["labels"] = labels + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn = lambda x: x.loss + +config = ChatGLMConfig( + num_layers=2, + padded_vocab_size=65024, + hidden_size=64, + num_attention_heads=8, + kv_channels=16, + rmsnorm=True, + original_rope=True, + use_cache=True, + torch_dtype=torch.float32, +) + +infer_config = ChatGLMConfig( + num_layers=2, + padded_vocab_size=65024, + hidden_size=128, + num_attention_heads=8, + multi_query_attention=True, + multi_query_group_num=2, + kv_channels=16, + rmsnorm=True, + original_rope=True, + use_cache=True, + torch_dtype=torch.float32, +) + +model_zoo.register( + name="transformers_chatglm", + model_fn=lambda: ChatGLMModel(config, empty_init=False), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_chatglm_model, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_chatglm_for_conditional_generation", + model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 5ed4fbe70dc9df559ad8285ddae49616c43c2f9e..2af6176fbe4a558097a335086fedead34f0a1618 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -1,3 +1,5 @@ +import copy + import torch import transformers @@ -6,52 +8,151 @@ from ..registry import ModelAttribute, model_zoo # =============================== # Register single-sentence GPT # =============================== -BATCH_SIZE = 1 # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined. -SEQ_LENGTH = 16 def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + # Generated from following code snippet + # + # from transformers import GPT2Tokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + + +def data_gen_for_question_answering(): + # question answering data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + start_positions = torch.tensor([0], dtype=torch.int64) + data["start_positions"] = start_positions + end_positions = torch.tensor([1], dtype=torch.int64) + data["end_positions"] = end_positions + return data + + +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64) + return data -def seq_classification_data_gen(): - # batch sizes should be 1 if no padding token is defined. - input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([1], dtype=torch.int64) + return data +def date_gen_for_double_heads(): + num_choices = 2 + batch_size = 2 + input_ids = torch.tensor( + [[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]], + dtype=torch.int64, + ) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64) + + mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64) + mc_token_ids = mc_token_ids.expand((batch_size, num_choices)) + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous() + multiple_choice_input_mask = attention_mask.unsqueeze(1).expand(-1, num_choices, -1).contiguous() + + inputs = { + "input_ids": multiple_choice_inputs_ids, + "mc_token_ids": mc_token_ids, + "attention_mask": multiple_choice_input_mask, + "labels": multiple_choice_inputs_ids, + "mc_labels": mc_labels, + } + return inputs + + +# define output transform function output_transform_fn = lambda x: x -config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4) +# define loss function +loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn = lambda x: x.loss + +config = transformers.GPT2Config( + n_layer=2, + n_head=4, + vocab_size=50258, + attn_pdrop=0, + embd_pdrop=0, + resid_pdrop=0, + summary_first_dropout=0, + hidden_dropout=0, + problem_type="single_label_classification", + pad_token_id=50256, +) + +config_for_token_classification = copy.deepcopy(config) +config_for_token_classification.num_labels = 2 # register the following models -model_zoo.register(name='transformers_gpt', - model_fn=lambda: transformers.GPT2Model(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_lm', - model_fn=lambda: transformers.GPT2LMHeadModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_double_heads', - model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_for_token_classification', - model_fn=lambda: transformers.GPT2ForTokenClassification(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_for_sequence_classification', - model_fn=lambda: transformers.GPT2ForSequenceClassification(config), - data_gen_fn=seq_classification_data_gen, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_gpt", + model_fn=lambda: transformers.GPT2Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_gpt2_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_lm", + model_fn=lambda: transformers.GPT2LMHeadModel(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_double_heads", + model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), + data_gen_fn=date_gen_for_double_heads, + output_transform_fn=output_transform_fn, + loss_fn=lambda x: x.loss + x.mc_loss, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_question_answering", + model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_token_classification", + model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_sequence_classification", + model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..bc229b17e08c110e7f164b104999661ac710a661 --- /dev/null +++ b/tests/kit/model_zoo/transformers/llama.py @@ -0,0 +1,88 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +try: + from transformers import LlamaConfig + + HAS_LLAMA = True +except ImportError: + HAS_LLAMA = False + +if HAS_LLAMA: + # =============================== + # Register LLaMA + # =============================== + + def data_gen(): + # the input ids are corresponding to the sentence + # 'Hello, my dog is cute' + # + # the code is give below: + # ----------------------------------- + # from transformers import LlamaTokenizerFast + # tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + # ----------------------------------- + + input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() + return dict(input_ids=input_ids, attention_mask=attention_mask) + + # label is needed for casual lm + def data_gen_for_casual_lm(): + data = data_gen() + labels = data["input_ids"].clone() + data["labels"] = labels + return data + + # transform the output to a dict + output_transform_fn = lambda x: x + + # function to get the loss + loss_fn = lambda output: output.last_hidden_state.mean() + loss_fn_for_casual_lm = lambda output: output.loss + loss_fn_for_seq_classification = lambda output: output.logits.mean() + + config = LlamaConfig( + num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4, + max_position_embeddings=128, + num_labels=16, + ) + + if hasattr(config, "pad_token_id"): + config.pad_token_id = config.eos_token_id + + # register the following models + # transformers.LlamaModel, + # transformers.LlamaForCausalLM, + # transformers.LlamaForSequenceClassification, + model_zoo.register( + name="transformers_llama", + model_fn=lambda: transformers.LlamaModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), + ) + model_zoo.register( + name="transformers_llama_for_casual_lm", + model_fn=lambda: transformers.LlamaForCausalLM(config), + data_gen_fn=data_gen_for_casual_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_casual_lm, + model_attribute=ModelAttribute(has_control_flow=True), + ) + model_zoo.register( + name="transformers_llama_for_sequence_classification", + model_fn=lambda: transformers.LlamaForSequenceClassification(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_seq_classification, + model_attribute=ModelAttribute(has_control_flow=True), + ) diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py index d9c4a0b3c23c52ac54fcdff786cae7eacdb2137c..07ca41ef21ae5e18986eebe775f57b4cac1e8611 100644 --- a/tests/kit/model_zoo/transformers/opt.py +++ b/tests/kit/model_zoo/transformers/opt.py @@ -11,25 +11,82 @@ SEQ_LENGTH = 16 def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) -output_transform_fn = lambda x: x +def data_gen_for_causal_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + labels = data["input_ids"].clone() + data["labels"] = labels + return data + + +def data_gen_for_sequence_classification(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["input_ids"].clone() + data["labels"] = torch.tensor([1]) + return data + -config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4) +def data_gen_for_question_answering(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["start_positions"] = torch.tensor([0]) + data["end_positions"] = torch.tensor([1]) + return data + + +output_transform_fn = lambda x: x +loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn_for_lm = lambda x: x.loss +config = transformers.OPTConfig( + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + dropout=0, +) # register the following models # transformers.OPTModel, # transformers.OPTForCausalLM, -model_zoo.register(name='transformers_opt', - model_fn=lambda: transformers.OPTModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_opt_for_causal_lm', - model_fn=lambda: transformers.OPTForCausalLM(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_opt", + model_fn=lambda: transformers.OPTModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_opt_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_opt_for_causal_lm", + model_fn=lambda: transformers.OPTForCausalLM(config), + data_gen_fn=data_gen_for_causal_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_opt_for_question_answering", + model_fn=lambda: transformers.OPTForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True), +) + +# TODO The loss and gradient check in the test are failing, to be fixed. +# model_zoo.register(name='transformers_opt_for_sequence_classification', +# model_fn=lambda: transformers.OPTForSequenceClassification(config), +# data_gen_fn=data_gen_for_sequence_classification, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn_for_lm, +# model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/sam.py b/tests/kit/model_zoo/transformers/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..b928a8f14e75c8ed212412145d4b74fa394d9a12 --- /dev/null +++ b/tests/kit/model_zoo/transformers/sam.py @@ -0,0 +1,56 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-image SAM +# =============================== + + +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from PIL import Image + # import requests + # from transformers import SamModel, SamProcessor + # + # model = SamModel.from_pretrained("facebook/sam-vit-base") + # processor = SamProcessor.from_pretrained("facebook/sam-vit-base") + # + # img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + # raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + # input_points = [[[450, 600]]] # 2D localization of a window + # inputs = processor(raw_image, input_points=input_points, return_tensors="pt") + + pixel_values = torch.rand(1, 3, 1024, 1024, dtype=torch.float32) + original_sizes = torch.tensor([[1764, 2646]], dtype=torch.int64) + reshaped_input_sizes = torch.tensor([[683, 1024]], dtype=torch.int64) + input_points = torch.tensor([[[[174.1497, 232.3129]]]], dtype=torch.float64) + return dict( + pixel_values=pixel_values, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + input_points=input_points, + ) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss funciton +loss_fn = lambda x: x.iou_scores.mean() + +config = transformers.SamConfig() +config.vision_config.num_hidden_layers = 2 + +# register the BERT variants +model_zoo.register( + name="transformers_sam", + model_fn=lambda: transformers.SamModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index b81bcad90db87f1b3edebe48316d9d87bbd3b795..1b63cccc42ee43c980fb43f8fde11b17452f95aa 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -6,41 +6,76 @@ from ..registry import ModelAttribute, model_zoo # =============================== # Register single-sentence T5 # =============================== -BATCH_SIZE = 2 -SEQ_LENGTH = 16 -def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids) +# define data gen function +def data_gen_for_encoder_only(): + # Generated from following code snippet + # + # from transformers import T5Config, T5Tokenizer + # config = T5Config(decoder_start_token_id=0) + # tokenizer = T5Tokenizer.from_pretrained("t5-small") + # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids + input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12, 1627, 5, 1, 12]]).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long() + return dict(input_ids=input_ids, attention_mask=attention_mask) -def data_gen_for_encoder_only(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids) +def data_gen_for_conditional_generation(): + # labels is generated with the following code + # + # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids + data = data_gen_for_encoder_only() + labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long() + data["labels"] = labels + return data + +def data_gen_for_t5_model(): + # decoder_inputs_ids is obtained with the following code + # decoder_input_ids = model._shift_right(input_ids) + data = data_gen_for_encoder_only() + decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long() + data["decoder_input_ids"] = decoder_input_ids + return data + +# output transform function output_transform_fn = lambda x: x -config = transformers.T5Config(d_model=128, num_layers=2) +# define loss function +loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean() +loss_fn_for_conditional_generation = lambda x: x.loss + +# define model config +config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0) # register the following models # transformers.T5Model, # transformers.T5ForConditionalGeneration, # transformers.T5EncoderModel, -model_zoo.register(name='transformers_t5', - model_fn=lambda: transformers.T5Model(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_t5_for_conditional_generation', - model_fn=lambda: transformers.T5ForConditionalGeneration(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_t5_encoder_model', - model_fn=lambda: transformers.T5EncoderModel(config), - data_gen_fn=data_gen_for_encoder_only, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_t5", + model_fn=lambda: transformers.T5Model(config), + data_gen_fn=data_gen_for_t5_model, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_t5_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_t5_for_conditional_generation", + model_fn=lambda: transformers.T5ForConditionalGeneration(config), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_conditional_generation, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_t5_encoder_model", + model_fn=lambda: transformers.T5EncoderModel(config), + data_gen_fn=data_gen_for_encoder_only, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_encoder_only, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..f1990751b0168ef485219029464f9ac92192324c --- /dev/null +++ b/tests/kit/model_zoo/transformers/vit.py @@ -0,0 +1,70 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence VIT +# =============================== + +config = transformers.ViTConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4) + + +# define data gen function +def data_gen(): + pixel_values = torch.randn(1, 3, 224, 224) + return dict(pixel_values=pixel_values) + + +def data_gen_for_image_classification(): + data = data_gen() + data["labels"] = torch.tensor([0]) + return data + + +def data_gen_for_masked_image_modeling(): + data = data_gen() + num_patches = (config.image_size // config.patch_size) ** 2 + bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + data["bool_masked_pos"] = bool_masked_pos + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# function to get the loss +loss_fn_for_vit_model = lambda x: x.pooler_output.mean() +loss_fn_for_image_classification = lambda x: x.logits.mean() +loss_fn_for_masked_image_modeling = lambda x: x.loss + +# register the following models +# transformers.ViTModel, +# transformers.ViTForMaskedImageModeling, +# transformers.ViTForImageClassification, +model_zoo.register( + name="transformers_vit", + model_fn=lambda: transformers.ViTModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_vit_model, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_vit_for_masked_image_modeling", + model_fn=lambda: transformers.ViTForMaskedImageModeling(config), + data_gen_fn=data_gen_for_masked_image_modeling, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_masked_image_modeling, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_vit_for_image_classification", + model_fn=lambda: transformers.ViTForImageClassification(config), + data_gen_fn=data_gen_for_image_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_image_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..928be4468c01befa9f3098c4aaefe94bcb6c8158 --- /dev/null +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -0,0 +1,97 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence Whisper +# =============================== + + +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoFeatureExtractor, WhisperModel + # from datasets import load_dataset + + # model = WhisperModel.from_pretrained("openai/whisper-base") + # feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") + # ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + # inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + # input_features = inputs.input_features + # decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + + input_features = torch.rand(1, 80, 3000) + decoder_input_ids = torch.tensor([[1, 1]]) * 50258 + return dict(input_features=input_features, decoder_input_ids=decoder_input_ids) + + +def data_gen_for_conditional_generation(): + # labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + # Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + # or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + # only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + data = data_gen() + data["labels"] = torch.tensor([[0, 1]], dtype=torch.int64) + return data + + +def data_gen_for_audio_classification(): + # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + # Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + # config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + # `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + # `WhisperForAudioClassification` does not need `decoder_input_ids` + data = data_gen() + data.pop("decoder_input_ids") + data["labels"] = torch.tensor([1], dtype=torch.int64) + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss funciton +loss_fn = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state)) +loss_fn_attr = lambda x: x.loss + +config = transformers.WhisperConfig( + classifier_proj_size=256, + d_model=256, + decoder_attention_heads=4, + decoder_ffn_dim=1536, + decoder_layers=2, + encoder_attention_heads=4, + encoder_ffn_dim=1536, + encoder_layers=2, + vocab_size=51866, +) + +# register the Whisper variants +model_zoo.register( + name="transformers_whisper", + model_fn=lambda: transformers.WhisperModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_whisper_for_conditional_generation", + model_fn=lambda: transformers.WhisperForConditionalGeneration(config), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_attr, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_whisper_for_audio_classification", + model_fn=lambda: transformers.WhisperForAudioClassification(config), + data_gen_fn=data_gen_for_audio_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_attr, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_analyzer/test_fx/test_bias_addition.py b/tests/test_analyzer/test_fx/test_bias_addition.py index f7b5eb140f2437c03bbf3f6f399960f5c5d6ed98..f72c1cb3f533f3cab0cd0b6b6db92e7cafea6e77 100644 --- a/tests/test_analyzer/test_fx/test_bias_addition.py +++ b/tests/test_analyzer/test_fx/test_bias_addition.py @@ -12,7 +12,6 @@ except: class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features, bias): super().__init__() self.linear = torch.nn.Linear(in_features, out_features, bias=bias) @@ -23,25 +22,14 @@ class LinearModel(torch.nn.Module): class ConvModel(torch.nn.Module): - def __init__(self, in_channel, out_channels, kernel_size, bias) -> None: super().__init__() - self.conv = torch.nn.Conv2d(in_channel, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) - self.conv_transpose = torch.nn.ConvTranspose2d(in_channel, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) + self.conv = torch.nn.Conv2d( + in_channel, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) + self.conv_transpose = torch.nn.ConvTranspose2d( + in_channel, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) def forward(self, x, select=0): if select == 0: @@ -52,7 +40,6 @@ class ConvModel(torch.nn.Module): class SiuModel(torch.nn.Module): - def __init__(self, bias) -> None: super().__init__() self.linear = LinearModel(3, 3, bias) @@ -69,7 +56,6 @@ class SiuModel(torch.nn.Module): class AddmmModel(torch.nn.Module): - def __init__(self, alpha, beta) -> None: super().__init__() self.alpha = alpha @@ -80,7 +66,7 @@ class AddmmModel(torch.nn.Module): return x -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() @parameterize("bias", [True, False]) @parameterize("bias_addition_split", [True, False]) @@ -89,19 +75,21 @@ class AddmmModel(torch.nn.Module): def test_siu_model(bias, bias_addition_split, shape, select): model = SiuModel(bias=bias) x = torch.rand(shape) - gm = symbolic_trace(model, - meta_args={'x': x}, - concrete_args={'select': select}, - trace_act_ckpt=True, - bias_addition_split=bias_addition_split) - assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!' + gm = symbolic_trace( + model, + meta_args={"x": x}, + concrete_args={"select": select}, + trace_act_ckpt=True, + bias_addition_split=bias_addition_split, + ) + assert torch.allclose(model(x, select), gm(x)), "original model and traced model should be the same!" if bias and bias_addition_split: - assert '+' in gm.code, 'bias addition should be split!' + assert "+" in gm.code, "bias addition should be split!" else: - assert '+' not in gm.code, 'bias addition should not be split!' + assert "+" not in gm.code, "bias addition should not be split!" -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @parameterize("alpha", [1, 2]) @parameterize("beta", [1, 2]) @parameterize("bias_addition_split", [True, False]) @@ -109,14 +97,14 @@ def test_siu_model(bias, bias_addition_split, shape, select): def test_addmm_model(alpha, beta, bias_addition_split, shape): model = AddmmModel(alpha=alpha, beta=beta) x = torch.rand(shape) - gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split) - assert torch.allclose(model(x), gm(x)), 'original model and traced model should be the same!' + gm = symbolic_trace(model, meta_args={"x": x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split) + assert torch.allclose(model(x), gm(x)), "original model and traced model should be the same!" if (alpha == 1 and beta == 1) or not bias_addition_split: - assert '*' not in gm.code, 'bias addition should not be split!' + assert "*" not in gm.code, "bias addition should not be split!" elif bias_addition_split: - assert '+' in gm.code, 'bias addition should be split!' + assert "+" in gm.code, "bias addition should be split!" -if __name__ == '__main__': +if __name__ == "__main__": test_siu_model() test_addmm_model() diff --git a/tests/test_analyzer/test_fx/test_mod_dir.py b/tests/test_analyzer/test_fx/test_mod_dir.py index f62147b297a2ccd6ede7cbbf56804125172fa847..be151b1edd8040e44a4700457a1e108391cf8664 100644 --- a/tests/test_analyzer/test_fx/test_mod_dir.py +++ b/tests/test_analyzer/test_fx/test_mod_dir.py @@ -10,7 +10,6 @@ except: class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features, bias): super().__init__() self.linear = torch.nn.Linear(in_features, out_features, bias=bias) @@ -21,25 +20,14 @@ class LinearModel(torch.nn.Module): class ConvModel(torch.nn.Module): - def __init__(self, in_channel, out_channels, kernel_size, bias) -> None: super().__init__() - self.conv = torch.nn.Conv2d(in_channel, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) - self.conv_transpose = torch.nn.ConvTranspose2d(out_channels, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) + self.conv = torch.nn.Conv2d( + in_channel, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) + self.conv_transpose = torch.nn.ConvTranspose2d( + out_channels, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) def forward(self, x): x = self.conv(x) @@ -48,7 +36,6 @@ class ConvModel(torch.nn.Module): class AModel(torch.nn.Module): - def __init__(self, bias) -> None: super().__init__() self.linear_1 = LinearModel(3, 3, bias) @@ -63,7 +50,7 @@ class AModel(torch.nn.Module): return x -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="torch version < 12") @clear_cache_before_run() @parameterize("bias", [True, False]) @parameterize("bias_addition_split", [True, False]) @@ -71,11 +58,11 @@ class AModel(torch.nn.Module): def test_mod_dir(bias, bias_addition_split, shape): model = AModel(bias=bias) x = torch.rand(shape) - gm = symbolic_trace(model, meta_args={'x': x}, bias_addition_split=bias_addition_split) + gm = symbolic_trace(model, meta_args={"x": x}, bias_addition_split=bias_addition_split) for node in gm.graph.nodes: - assert len(node.meta['info'].mod_dir), f"{node} should have non-trivial ``mod_dir``." - print(node, node.meta['info'].mod_dir) + assert len(node.meta["info"].mod_dir), f"{node} should have non-trivial ``mod_dir``." + print(node, node.meta["info"].mod_dir) -if __name__ == '__main__': +if __name__ == "__main__": test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3)) diff --git a/tests/test_analyzer/test_fx/test_nested_ckpt.py b/tests/test_analyzer/test_fx/test_nested_ckpt.py index bd16f5a4f95dd2ad4b60221a3b0281ce6b5a766d..d7b96fb9f04363bb590e11c906f9f6333fc9c9dd 100644 --- a/tests/test_analyzer/test_fx/test_nested_ckpt.py +++ b/tests/test_analyzer/test_fx/test_nested_ckpt.py @@ -12,7 +12,6 @@ except: class MyModule(nn.Module): - def __init__(self): super().__init__() self.a = nn.Linear(10, 10) @@ -43,14 +42,14 @@ class MyModule(nn.Module): return checkpoint(self.checkpoint_0, x) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="torch version < 12") @clear_cache_before_run() def test_nested_ckpt(): model = MyModule() x = torch.rand(10, 10) - gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True) + gm = symbolic_trace(model, meta_args={"x": x}, trace_act_ckpt=True) assert torch.allclose(gm(x), model(x)), "The traced model should generate the same output as the original model." - for ckpt_def in filter(lambda s: s.startswith('checkpoint'), dir(model)): + for ckpt_def in filter(lambda s: s.startswith("checkpoint"), dir(model)): assert ckpt_def in gm.code, f"Checkpoint {ckpt_def} should be in the traced code.\n Traced code = {gm.code}" diff --git a/tests/test_analyzer/test_fx/test_shape_prop.py b/tests/test_analyzer/test_fx/test_shape_prop.py index a849feb795e5d4a935892d88ddfc099ec2a6c1b6..609fc9c7b02211d30bbf9df25f4786d3f61d0450 100644 --- a/tests/test_analyzer/test_fx/test_shape_prop.py +++ b/tests/test_analyzer/test_fx/test_shape_prop.py @@ -1,6 +1,5 @@ import pytest import torch -import torchvision.models as tm from packaging import version from colossalai.testing.utils import clear_cache_before_run, parameterize @@ -16,24 +15,25 @@ try: def linear_impl(*args, **kwargs): assert True return torch.nn.functional.linear(*args, **kwargs) + except: pass def _check_gm_validity(gm: torch.fx.GraphModule): for node in gm.graph.nodes: - assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.' + assert node.meta["info"].outputs, f"In {gm.__class__.__name__}, {node} has no output shape." if node.op in [ - 'call_module', # can apply to params - 'call_function', # can apply to params - 'call_method', # can apply to params + "call_module", # can apply to params + "call_function", # can apply to params + "call_method", # can apply to params ]: - assert hasattr(node.meta['info'], 'inputs'), f'In {gm.__class__.__name__}, {node} has no input shape.' + assert hasattr(node.meta["info"], "inputs"), f"In {gm.__class__.__name__}, {node} has no input shape." -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tm_models) +@parameterize("m", tm_models) def test_torchvision_shape_prop(m): with MetaTensorMode(): model = m() @@ -46,9 +46,9 @@ def test_torchvision_shape_prop(m): _check_gm_validity(gm) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tmm_models) +@parameterize("m", tmm_models) def test_timm_shape_prop(m): with MetaTensorMode(): model = m() diff --git a/tests/test_analyzer/test_fx/test_symbolic_profile.py b/tests/test_analyzer/test_fx/test_symbolic_profile.py index 17deee7a71188ba60def6e79e9995727c3589001..8d8ee2445d5891380bb8e53d5bbc73c72f1ce7e0 100644 --- a/tests/test_analyzer/test_fx/test_symbolic_profile.py +++ b/tests/test_analyzer/test_fx/test_symbolic_profile.py @@ -1,6 +1,5 @@ import pytest import torch -import torchvision.models as tm from packaging import version from colossalai.testing.utils import clear_cache_before_run, parameterize @@ -15,12 +14,12 @@ except: def _check_gm_validity(gm: torch.fx.GraphModule): for node in gm.graph.nodes: - assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.' + assert len(node.meta["info"].global_ctx), f"In {gm.__class__.__name__}, {node} has empty global context." -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tm_models) +@parameterize("m", tm_models) def test_torchvision_profile(m, verbose=False, bias_addition_split=False): with MetaTensorMode(): model = m() @@ -33,9 +32,9 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False): _check_gm_validity(gm) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tmm_models) +@parameterize("m", tmm_models) def test_timm_profile(m, verbose=False, bias_addition_split=False): with MetaTensorMode(): model = m() diff --git a/tests/test_analyzer/test_subclasses/test_aten.py b/tests/test_analyzer/test_subclasses/test_aten.py index b7858110ac0939c4d6d09825143167e18b32e58f..61c1d25f7b3d1d692c6dfe5b78a556b5ff45bc9f 100644 --- a/tests/test_analyzer/test_subclasses/test_aten.py +++ b/tests/test_analyzer/test_subclasses/test_aten.py @@ -14,35 +14,41 @@ except: aten = torch.ops.aten registered_meta = { - ('aten.convolution.default', True): [ # (aten ops, requires_backward) + ("aten.convolution.default", True): [ # (aten ops, requires_backward) (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)), (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)), (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), - (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), torch.rand(2, 3, 4, 4)), - (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), torch.rand(2, 3, 4, 4, 4)), + ( + nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4), + ), + ( + nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4, 4), + ), ], - ('aten.native_batch_norm.default', True): [ + ("aten.native_batch_norm.default", True): [ (nn.BatchNorm1d(4), torch.rand(2, 4)), (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)), (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)), ], - ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),], - ('aten.avg_pool1d.default', True): [ + ("aten.native_layer_norm.default", True): [ + (nn.LayerNorm(4), torch.rand(1, 2, 3, 4)), + ], + ("aten.avg_pool1d.default", True): [ (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)), (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)), ], - ('aten.avg_pool2d.default', True): [ + ("aten.avg_pool2d.default", True): [ (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)), ], - ('aten.relu.default', True): [ + ("aten.relu.default", True): [ (nn.ReLU(), torch.rand(4, 3, 1, 2)), (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)), (nn.SiLU(), torch.rand(4, 3, 1, 2)), @@ -51,15 +57,20 @@ registered_meta = { (nn.Sigmoid(), torch.rand(4, 3, 1, 2)), (nn.Tanh(), torch.rand(4, 3, 1, 2)), (nn.Hardswish(), torch.rand(4, 3, 1, 2)), - ] + ], } def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any: - assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' - assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' - assert tensor.stride() == meta_tensor.stride( - ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + assert ( + tensor.shape == meta_tensor.shape + ), f"the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match." + assert ( + tensor.dtype == meta_tensor.dtype + ), f"the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match." + assert ( + tensor.stride() == meta_tensor.stride() + ), f"the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match." def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any: @@ -73,7 +84,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac compare_all(x.grad, meta_x.grad) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="torch version < 12") @clear_cache_before_run() def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): @@ -81,5 +92,5 @@ def test_meta_aten(): run_and_compare(f, x, requires_backward) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_aten() diff --git a/tests/test_analyzer/test_subclasses/test_flop_tensor.py b/tests/test_analyzer/test_subclasses/test_flop_tensor.py index da3829e401468be885bc226cb50c67d9ebc8288e..b1b9a89fad970c9142776bf2268a4207a86db06a 100644 --- a/tests/test_analyzer/test_subclasses/test_flop_tensor.py +++ b/tests/test_analyzer/test_subclasses/test_flop_tensor.py @@ -4,7 +4,6 @@ import torch.nn.functional as F import torchvision.models as tm from packaging import version -from colossalai.testing import clear_cache_before_run, parameterize from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models try: @@ -13,41 +12,44 @@ except: pass -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') -@pytest.mark.parametrize('m', tm_models + tmm_models) +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") +@pytest.mark.parametrize("m", tm_models + tmm_models) def test_flop_count_module(m): x = torch.rand(2, 3, 224, 224) - with MetaTensorMode(): # save time for testing + with MetaTensorMode(): # save time for testing module = m() rs_fwd, rs_bwd = flop_count(module, x, verbose=True) - assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}' - assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}' + assert rs_fwd > 0, f"fwd flop count of {m.__name__} is {rs_fwd}" + assert rs_bwd > 0, f"bwd flop count of {m.__name__} is {rs_bwd}" odd_cases = [ - (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), { - 'inplace': True - }), - (F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), { - 'kernel_size': 3, - 'stride': 2, - 'padding': 1, - 'dilation': 2 - }), - (torch.where, (torch.rand(2, 3, 224, 224) > 0.5, torch.rand(2, 3, 224, 224, requires_grad=True), - torch.rand(2, 3, 224, 224, requires_grad=True)), {}), + (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True}), + ( + F.max_pool2d, + (torch.rand(2, 3, 224, 224, requires_grad=True),), + {"kernel_size": 3, "stride": 2, "padding": 1, "dilation": 2}, + ), + ( + torch.where, + ( + torch.rand(2, 3, 224, 224) > 0.5, + torch.rand(2, 3, 224, 224, requires_grad=True), + torch.rand(2, 3, 224, 224, requires_grad=True), + ), + {}, + ), ] -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') -@clear_cache_before_run() -@parameterize('func, args, kwargs', odd_cases) +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") +@pytest.mark.parametrize("func, args, kwargs", odd_cases) def test_flop_count_function(func, args, kwargs): rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True) - assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}' - assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}' + assert rs_fwd > 0, f"fwd flop count of {func.__name__} is {rs_fwd}" + assert rs_bwd > 0, f"bwd flop count of {func.__name__} is {rs_bwd}" -if __name__ == '__main__': +if __name__ == "__main__": test_flop_count_module(tm.resnet18) - test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True}) + test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True}) diff --git a/tests/test_analyzer/test_subclasses/test_meta_mode.py b/tests/test_analyzer/test_subclasses/test_meta_mode.py index d2a0a1b9cfb590b4a32b2fafe17ede5a9c4de0ef..c55c4ec427033c6025047f28e5914f8ed777004f 100644 --- a/tests/test_analyzer/test_subclasses/test_meta_mode.py +++ b/tests/test_analyzer/test_subclasses/test_meta_mode.py @@ -6,17 +6,22 @@ from packaging import version from colossalai.testing import clear_cache_before_run, parameterize try: - from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode + from colossalai._analyzer._subclasses import MetaTensorMode except: pass from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor): - assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' - assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' - assert tensor.stride() == meta_tensor.stride( - ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + assert ( + tensor.shape == meta_tensor.shape + ), f"the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match." + assert ( + tensor.dtype == meta_tensor.dtype + ), f"the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match." + assert ( + tensor.stride() == meta_tensor.stride() + ), f"the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match." def run_and_compare(model): @@ -31,12 +36,12 @@ def run_and_compare(model): compare_all(x.grad, meta_x.grad) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tm_models + tmm_models) +@parameterize("m", tm_models + tmm_models) def test_meta_mode_shape(m): run_and_compare(m()) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_mode_shape(tm.resnet18) diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py index f184f64b35d020196ded866936a142c6858533c4..03bba8e647721a496ac5d2af941d634d6c3de154 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py @@ -6,12 +6,13 @@ import torch.fx import torchvision.models as tm import colossalai -from colossalai.core import global_context as gpc from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta + # from colossalai.fx.passes.algorithms import solver_rotor # from colossalai.fx.passes.algorithms.operation import Sequence from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.legacy.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use, spawn if is_compatible_with_meta(): @@ -19,18 +20,18 @@ if is_compatible_with_meta(): try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + withcodegen = True except: - from colossalai.fx.codegen import python_code_with_activation_checkpoint withcodegen = False def _run_C_solver_consistency_test(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]: model = M() - data = torch.rand(128, 3, 224, 224, device='meta') + data = torch.rand(128, 3, 224, 224, device="meta") tracer = ColoTracer() graph = tracer.trace(model, meta_args={"x": data}) @@ -54,15 +55,17 @@ def _run_C_solver_consistency_test(rank, world_size, port): for m in range(len(opt_python)): for d in range(1, len(opt_python[0])): for i in range(len(opt_python[0]) - d): - assert opt_python[m][i][i + d] == opt_C[m][i][i + d], \ - f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}" + assert ( + opt_python[m][i][i + d] == opt_C[m][i][i + d] + ), f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}" sequence_python = sequence_python.list_operations() sequence_C = sequence_C.list_operations() # make sure the sequences are the same - assert len(sequence_python) == len(sequence_C) and \ - all(python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C)) + assert len(sequence_python) == len(sequence_C) and all( + python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C) + ) gpc.destroy() @@ -74,5 +77,5 @@ def test_C_solver_consistency(): spawn(_run_C_solver_consistency_test, 1) -if __name__ == '__main__': +if __name__ == "__main__": _run_C_solver_consistency_test(rank=0) diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py index db268b91d0a0ed1394b139ff503747f90ee81156..c46f57f75303d03be7bfe257170501a9b161e0d9 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py @@ -8,12 +8,13 @@ import torchvision.models as tm from torch.fx import GraphModule import colossalai -from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.graph_module import ColoGraphModule + # from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.legacy.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use, spawn if is_compatible_with_meta(): @@ -21,10 +22,12 @@ if is_compatible_with_meta(): try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False # SOLVERS = [chen_greedy, solver_rotor] @@ -33,7 +36,7 @@ SOLVERS = [] def _is_activation_checkpoint_available(gm: GraphModule): for n in gm.graph.nodes: - if hasattr(n, 'activation_checkpoint') and getattr(n, 'activation_checkpoint') is not None: + if hasattr(n, "activation_checkpoint") and getattr(n, "activation_checkpoint") is not None: return True @@ -47,15 +50,19 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): def _is_graph_linearized(gm: GraphModule): code = gm.code # find patterns like r' return output_1, output_2', which is not expected on a linearized graph - pattern = re.compile(r' return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+') + pattern = re.compile(r" return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+") if pattern.findall(code): return False else: return True -def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], - model_cls: Callable[[], torch.nn.Module]): +def check_backward_consistency( + m: torch.nn.Module, + gm: GraphModule, + solver: Callable[[GraphModule], GraphModule], + model_cls: Callable[[], torch.nn.Module], +): criterion = torch.nn.MSELoss() m.cuda() data = torch.rand(2, 3, 32, 32).cuda() @@ -64,18 +71,18 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call loss.backward() loss = criterion(gm(data), label) loss.backward() - assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' + assert _is_all_gradient_close(m, gm), f"Solver {solver} did not work correctly in backward pass on {model_cls}" def _run_ckpt_solver(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True tracer = ColoTracer(trace_act_ckpt=False) - data = torch.rand(8, 3, 224, 224, device='meta') + data = torch.rand(8, 3, 224, 224, device="meta") for solver in SOLVERS: for model_cls in MODEL_LIST: m = model_cls(num_classes=5) @@ -90,27 +97,28 @@ def _run_ckpt_solver(rank, world_size, port): gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." assert _is_activation_checkpoint_available( - gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + gm + ), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" check_backward_consistency(m, gm, solver, model_cls) gpc.destroy() @pytest.mark.skip("TODO(super-dainiu): refactor all tests.") -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @rerun_if_address_is_in_use() def test_ckpt_solver(): spawn(_run_ckpt_solver, 1) def _run_ckpt_solver_torch11(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True tracer = ColoTracer(trace_act_ckpt=False) - data = torch.rand(8, 3, 32, 32, device='meta') + data = torch.rand(8, 3, 32, 32, device="meta") for solver in SOLVERS: for model_cls in MODEL_LIST: m = model_cls(num_classes=5) @@ -124,19 +132,20 @@ def _run_ckpt_solver_torch11(rank, world_size, port): gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." assert _is_activation_checkpoint_available( - gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + gm + ), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" check_backward_consistency(m, gm, solver, model_cls) gpc.destroy() -@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") @rerun_if_address_is_in_use() def test_ckpt_solver_torch11(): spawn(_run_ckpt_solver_torch11, 1) -if __name__ == '__main__': +if __name__ == "__main__": _run_ckpt_solver(rank=0) test_ckpt_solver() test_ckpt_solver_torch11() diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py index 59880815dc5ebbd97dc7ee7fca52ca502be4207c..bb3be934456600f96680a95ef145cd203e62ad9d 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py @@ -5,6 +5,7 @@ import torchvision.models as tm from colossalai.fx import ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.graph_module import ColoGraphModule + # from colossalai.fx.passes.algorithms import linearize, solver_rotor # from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss) from colossalai.fx.passes.meta_info_prop import MetaInfoProp @@ -15,14 +16,16 @@ if is_compatible_with_meta(): try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False -@pytest.mark.skip(reason='TODO: modify the logger') +@pytest.mark.skip(reason="TODO: modify the logger") @pytest.mark.skip("TODO(lyl): refactor all tests.") @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @clear_cache_before_run() @@ -35,12 +38,12 @@ def test_linearize(): graph = tracer.trace(model) graph.set_codegen(ActivationCheckpointCodeGen()) gm = ColoGraphModule(model, graph, model.__class__.__name__) - MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device='cpu')) + MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device="cpu")) node_list = linearize(gm) gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2) op_list = gm.__sequence__.list_operations() loss_op = next(op for op in op_list if isinstance(op, Loss)) - op_list = op_list[:op_list.index(loss_op)] + op_list = op_list[: op_list.index(loss_op)] in_ckpt = False ckpt_idx = 0 for idx, op in enumerate(op_list): @@ -48,8 +51,9 @@ def test_linearize(): if isinstance(op, ForwardNograd): for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint[ - 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + assert ( + n.activation_checkpoint[0] == ckpt_idx + ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" continue @@ -65,8 +69,9 @@ def test_linearize(): ckpt_idx += 1 for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint[ - 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + assert ( + n.activation_checkpoint[0] == ckpt_idx + ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" continue @@ -75,8 +80,9 @@ def test_linearize(): in_ckpt = True for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint[ - 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + assert ( + n.activation_checkpoint[0] == ckpt_idx + ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" del model del gm @@ -100,7 +106,7 @@ def test_linearize_torch11(): gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2) op_list = gm.__sequence__.list_operations() loss_op = next(op for op in op_list if isinstance(op, Loss)) - op_list = op_list[:op_list.index(loss_op)] + op_list = op_list[: op_list.index(loss_op)] in_ckpt = False ckpt_idx = 0 for idx, op in enumerate(op_list): diff --git a/tests/test_auto_parallel/test_offload/model_utils.py b/tests/test_auto_parallel/test_offload/model_utils.py index c22b17ae42ba6b5f201698048dbfd7c0bb628341..0efe84655aac7b654bd7d070c08c3aaf6e18c89a 100644 --- a/tests/test_auto_parallel/test_offload/model_utils.py +++ b/tests/test_auto_parallel/test_offload/model_utils.py @@ -1,25 +1,23 @@ import torch import torch.nn as nn -from transformers import GPT2Config, GPT2LMHeadModel -from transformers import BertConfig, BertLMHeadModel +from transformers import BertConfig, BertLMHeadModel, GPT2Config, GPT2LMHeadModel + from tests.components_to_test.registry import non_distributed_component_funcs -class GPTLMModel(nn.Module): - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257): +class GPTLMModel(nn.Module): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257): super().__init__() self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) def forward(self, input_ids, attention_mask): # Only return lm_logits @@ -27,7 +25,6 @@ class GPTLMModel(nn.Module): class LMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -38,18 +35,27 @@ class LMLoss(nn.Module): # Flatten the tokens return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + class BertLMModel(nn.Module): def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522): super().__init__() - self.model = BertLMHeadModel(BertConfig(n_embd=hidden_size, num_hidden_layers=num_layers, hidden_size=hidden_size, - num_attention_heads=num_attention_heads, max_position_embeddings=hidden_size, - vocab_size=vocab_size)) + self.model = BertLMHeadModel( + BertConfig( + n_embd=hidden_size, + num_hidden_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + max_position_embeddings=hidden_size, + vocab_size=vocab_size, + ) + ) def forward(self, input_ids, attention_mask): # Only return lm_logits return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] -@non_distributed_component_funcs.register(name='bert_') + +@non_distributed_component_funcs.register(name="bert_") def get_bert_components(): vocab_size = 1024 seq_len = 64 @@ -67,7 +73,8 @@ def get_bert_components(): return bert_model_builder, bert_data_gen -@non_distributed_component_funcs.register(name='gpt2_') + +@non_distributed_component_funcs.register(name="gpt2_") def get_gpt2_components(): vocab_size = 1024 seq_len = 8 @@ -83,4 +90,4 @@ def get_gpt2_components(): kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) return kwargs - return gpt2_model_builder, gpt2_data_gen \ No newline at end of file + return gpt2_model_builder, gpt2_data_gen diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index 80f134fd85d0007226ce0883b3ee302dbf87ee52..2c8b260e649890f6ddca27409a00f43c7ad871fe 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -17,18 +17,22 @@ from tests.test_auto_parallel.test_offload.model_utils import * from tests.test_tensor.common_utils import set_seed -@parameterize('model_name', ['gpt2_']) -@parameterize('memory_budget', [5000]) -@parameterize('solver_name', ['asyn']) +@parameterize("model_name", ["gpt2_"]) +@parameterize("memory_budget", [5000]) +@parameterize("solver_name", ["asyn"]) def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): - # build model get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen = get_components_func() - label = torch.randint(low=0, high=128, size=( - 64, - 8, - ), device=get_current_device()) + label = torch.randint( + low=0, + high=128, + size=( + 64, + 8, + ), + device=get_current_device(), + ) criterion = LMLoss() set_seed(42) @@ -50,17 +54,19 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3) optim = AMPOptimizer(hybrid_optimizer, model) - with ColoInitContext(device=torch.device('cpu')): + with ColoInitContext(device=torch.device("cpu")): gemini_model = model_builder() gemini_model.train() hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) - gemini_config = dict(strict_ddp_mode=False, - device=torch.device('cpu'), - placement_policy='cpu', - pin_memory=True, - hidden_dim=8192, - search_range_mb=128) + gemini_config = dict( + strict_ddp_mode=False, + device=torch.device("cpu"), + placement_policy="cpu", + pin_memory=True, + hidden_dim=8192, + search_range_m=128, + ) gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config) optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config) @@ -89,9 +95,11 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): exec_time = sum(sorted(time_list)[:5]) / 5 runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 - print(f'gemini | model_name: {model_name}') - print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(f"gemini | model_name: {model_name}") + print( + f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB " + f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|" + ) print(time_list) del data_args @@ -124,24 +132,26 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): exec_time = sum(sorted(time_list)[:5]) / 5 runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 - print(f'solver_name: {solver_name} | model_name: {model_name}') - print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(f"solver_name: {solver_name} | model_name: {model_name}") + print( + f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB " + f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|" + ) print(time_list) def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_fwd_bwd() @pytest.mark.skip("this test failed") -@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed") @rerun_if_address_is_in_use() def test_perf(): spawn(run_dist, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_perf() diff --git a/tests/test_auto_parallel/test_offload/test_solver.py b/tests/test_auto_parallel/test_offload/test_solver.py index aa2c9a36849fa9027caa8c66f7d8cf9d95590d0c..6bb53aa6749513b1bd3d2b1a6edaa8b4dcedd62c 100644 --- a/tests/test_auto_parallel/test_offload/test_solver.py +++ b/tests/test_auto_parallel/test_offload/test_solver.py @@ -11,13 +11,12 @@ from colossalai.testing import clear_cache_before_run, parameterize from tests.test_auto_parallel.test_offload.model_utils import * -@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed") @clear_cache_before_run() -@parameterize('model_name', ['gpt2_', 'bert_']) -@parameterize('memory_budget', [4000]) -@parameterize('solver_name', ['syn', 'asyn']) +@parameterize("model_name", ["gpt2_", "bert_"]) +@parameterize("memory_budget", [4000]) +@parameterize("solver_name", ["syn", "asyn"]) def solver_test(model_name: str, memory_budget: float, solver_name: str): - get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen = get_components_func() data_args = data_gen(device="cpu") @@ -53,15 +52,15 @@ def solver_test(model_name: str, memory_budget: float, solver_name: str): need_offload = region.need_offload to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None print( - f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + f"| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}" ) for region in region_list.__reversed__(): need_offload = region.need_offload to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None print( - f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + f"| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}" ) -if __name__ == '__main__': +if __name__ == "__main__": solver_test() diff --git a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py index 429e89aae5d38c5ac76fddb66a67116c0398f6a3..2b89a73656b13a1e77598ca37bf926371b93a96c 100644 --- a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass from colossalai.device.device_mesh import DeviceMesh @@ -10,7 +9,6 @@ from colossalai.testing import clear_cache_before_run class TestModule(torch.nn.Module): - def forward(self, x): x = x.view(4, 4, 2) return x @@ -19,7 +17,7 @@ class TestModule(torch.nn.Module): def insert_narrow(gm, x_node): graph = gm.graph with graph.inserting_after(x_node): - shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={}) + shard_node = graph.create_node("call_method", "narrow", args=(x_node, 0, 0, 2), kwargs={}) view_node = list(x_node.users.keys())[0] new_args = list(view_node.args) new_args[0] = shard_node @@ -33,7 +31,7 @@ def test_node_args_converting_pass(): physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - meta_args = {'x': torch.rand(4, 8).to('meta')} + meta_args = {"x": torch.rand(4, 8).to("meta")} input = torch.rand(4, 8) tracer = ColoTracer() graph = tracer.trace(root=model, meta_args=meta_args) @@ -41,8 +39,8 @@ def test_node_args_converting_pass(): x_node = list(graph.nodes)[0] view_node = list(graph.nodes)[1] sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) - setattr(x_node, 'sharding_spec', sharding_spec) - setattr(view_node, 'sharding_spec', sharding_spec) + setattr(x_node, "sharding_spec", sharding_spec) + setattr(view_node, "sharding_spec", sharding_spec) gm = ColoGraphModule(model, graph) gm = node_args_converting_pass(gm, device_mesh) @@ -52,5 +50,5 @@ def test_node_args_converting_pass(): assert output.shape == torch.Size([2, 4, 2]) -if __name__ == '__main__': +if __name__ == "__main__": test_node_args_converting_pass() diff --git a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py index bca81201c6ef03f3263369f2af7a13de67d2262d..b6cc6c9b44fdd19cf63b4760e1f52b3733fc829b 100644 --- a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py @@ -1,6 +1,5 @@ import pytest import torch -import torch.nn.functional as F from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes import shape_prop_pass @@ -12,7 +11,6 @@ from colossalai.testing import clear_cache_before_run class TestModule(torch.nn.Module): - def forward(self, x): size = x.size() return size @@ -21,7 +19,7 @@ class TestModule(torch.nn.Module): def insert_narrow(gm, x_node): graph = gm.graph with graph.inserting_after(x_node): - shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={}) + shard_node = graph.create_node("call_method", "narrow", args=(x_node, 0, 0, 2), kwargs={}) size_node = list(x_node.users.keys())[0] size_node.args = (shard_node,) return gm @@ -36,20 +34,20 @@ def recover_narrow(gm, narrow_node): return gm -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") @clear_cache_before_run() def test_size_value_converting_pass(): model = TestModule() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - meta_args = {'x': torch.rand(4, 8).to('meta')} + meta_args = {"x": torch.rand(4, 8).to("meta")} input = torch.rand(4, 8) tracer = ColoTracer(bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_args) x_node = list(graph.nodes)[0] x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) - setattr(x_node, 'sharding_spec', x_sharding_spec) + setattr(x_node, "sharding_spec", x_sharding_spec) gm = ColoGraphModule(model, graph) gm = insert_narrow(gm, x_node) shape_prop_pass(gm, *meta_args.values()) @@ -66,5 +64,5 @@ def test_size_value_converting_pass(): assert size == torch.Size([4, 8]) -if __name__ == '__main__': +if __name__ == "__main__": test_size_value_converting_pass() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py index 9fbe674ef4f4a9609e6ff1a8a5b507b64d8be7f1..c41c66745012cc2fdb09e541703bbae2226fbe75 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -1,10 +1,9 @@ -from functools import partial - import pytest import torch try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -16,7 +15,6 @@ from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_ class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features): super().__init__() self.linear = torch.nn.Linear(in_features, out_features) @@ -29,13 +27,11 @@ class LinearModel(torch.nn.Module): class ConvModel(torch.nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True): super().__init__() - self.conv = torch.nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - bias=bias) + self.conv = torch.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias + ) def forward(self, x): x = self.conv(x) @@ -46,7 +42,7 @@ class ConvModel(torch.nn.Module): def check_linear_module(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModel(4, 8).cuda() input = torch.rand(4, 4).cuda() output_compare = model(input) @@ -55,7 +51,7 @@ def check_linear_module(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 4).to('meta')} + meta_args = {"x": torch.rand(4, 4).to("meta")} gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh) output = gm(input) assert_close(output, output_compare) @@ -63,7 +59,7 @@ def check_linear_module(rank, world_size, port): def check_conv_module(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = ConvModel(3, 6, 2).cuda() input = torch.rand(4, 3, 64, 64).cuda() output_compare = model(input) @@ -72,14 +68,14 @@ def check_conv_module(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 3, 64, 64).to('meta')} + meta_args = {"x": torch.rand(4, 3, 64, 64).to("meta")} gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh) output = gm(input) assert_close(output, output_compare) -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_bias_addition_module(): @@ -87,5 +83,5 @@ def test_bias_addition_module(): spawn(check_conv_module, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_bias_addition_module() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py index 5607587496f354a9b5a9623a3ea54c539dfd7ade..5cc1820837bb7338ac7dcf02f67901814f5c33e2 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py @@ -48,17 +48,15 @@ def test_recover_sharding_spec_for_broadcast_shape(): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) broadcast_shape = get_broadcast_shape(x1.shape, x2.shape) - logical_sharding_spec_for_x1 = ShardingSpec(device_mesh=device_mesh, - dim_partition_dict={ - 0: [0], - 1: [1] - }, - entire_shape=broadcast_shape) + logical_sharding_spec_for_x1 = ShardingSpec( + device_mesh=device_mesh, dim_partition_dict={0: [0], 1: [1]}, entire_shape=broadcast_shape + ) physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape( - logical_sharding_spec_for_x1, broadcast_shape, x1.shape) + logical_sharding_spec_for_x1, broadcast_shape, x1.shape + ) print(physical_sharding_spec_for_x1) assert physical_sharding_spec_for_x1.entire_shape == x1.shape # dim 1 for the physical tensor is of broadcast type MULTIPLE, so should ignore assert physical_sharding_spec_for_x1.dim_partition_dict == {0: [0]} - assert physical_sharding_spec_for_x1.sharding_sequence == ['S0', 'R', 'R'] + assert physical_sharding_spec_for_x1.sharding_sequence == ["S0", "R", "R"] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py index 398458306e3d34c9dceba3dfc3030d98a2a3e5ad..c800f54da66cb3d7194ee3927d2a820be115ce0f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py @@ -8,6 +8,7 @@ from transformers.pytorch_utils import Conv1D try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -21,7 +22,6 @@ HIDDEN_SIZE = 16 class GPT2MLPWithCkpt(nn.Module): - def __init__(self, intermediate_size, hidden_size): super().__init__() embed_dim = hidden_size @@ -39,11 +39,11 @@ class GPT2MLPWithCkpt(nn.Module): def check_act_ckpt(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE) - input = torch.rand(1, 64, HIDDEN_SIZE) + torch.rand(1, 64, HIDDEN_SIZE) input_sample = { - 'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'), + "hidden_states": torch.rand(1, 64, HIDDEN_SIZE).to("meta"), } physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -51,18 +51,24 @@ def check_act_ckpt(rank, world_size, port): # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) gm = initialize_model(model, input_sample, device_mesh) - code = gm.module.graph.python_code('self').src - assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code - assert "view_3 = torch.utils.checkpoint.checkpoint(self.checkpoint_0, view_1, comm_actions_dict, use_reentrant=False)" in code + code = gm.module.graph.python_code("self").src + assert ( + "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" + in code + ) + assert ( + "view_3 = torch.utils.checkpoint.checkpoint(self.checkpoint_0, view_1, comm_actions_dict, use_reentrant=False)" + in code + ) -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_mlp_layer(): spawn(check_act_ckpt, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_mlp_layer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py index 6908a17818691ecae401f5ada98f94d6c9c84621..e8f175326bb149552ec596be52dec5ab890da595 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py @@ -6,6 +6,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -17,7 +18,6 @@ from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_ class MLP(torch.nn.Module): - def __init__(self, in_features): super().__init__() self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False) @@ -32,7 +32,7 @@ class MLP(torch.nn.Module): def check_compatibility_with_ddp(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MLP(4).cuda() if rank in [0, 1]: input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda() @@ -49,26 +49,28 @@ def check_compatibility_with_ddp(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 4).to('meta')} - gm, solution = initialize_model(model, - meta_args=meta_args, - device_mesh=device_mesh, - return_solution=True, - solver_preference='tp', - shard_option='shard_last_axis') - - msg = '| TP strategy combination chosen by auto-parallel solver |' + meta_args = {"x": torch.rand(4, 4).to("meta")} + gm, solution = initialize_model( + model, + meta_args=meta_args, + device_mesh=device_mesh, + return_solution=True, + solver_preference="tp", + shard_option="shard_last_axis", + ) + + msg = "| TP strategy combination chosen by auto-parallel solver |" msg_length = len(msg) if rank == 0: - print('=' * msg_length) + print("=" * msg_length) print(msg) - print('=' * msg_length) + print("=" * msg_length) for strategy in solution: print(strategy) - print('=' * msg_length) + print("=" * msg_length) dp_process_group = None - for (ranks, process_group_handle) in device_mesh.process_groups_dict[0]: + for ranks, process_group_handle in device_mesh.process_groups_dict[0]: if rank in ranks: dp_process_group = process_group_handle assert dp_process_group is not None @@ -79,7 +81,7 @@ def check_compatibility_with_ddp(rank, world_size, port): assert_close(output, output_compare.narrow(0, 0, 4)) else: assert_close(output, output_compare.narrow(0, 4, 4)) - print(f'output on rank{rank} is correct') + print(f"output on rank{rank} is correct") loss = output.sum() loss.backward() @@ -90,16 +92,16 @@ def check_compatibility_with_ddp(rank, world_size, port): if rank in (1, 3): assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 8, 8)) - print(f'gradient on rank{rank} is correct') + print(f"gradient on rank{rank} is correct") -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_compatibility_with_ddp(): spawn(check_compatibility_with_ddp, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_compatibility_with_ddp() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index 05704acbf7fdb9e801ab47d138026f9885e8c92f..aba746f1992dcd451d3b0df39d27f03759412181 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -5,6 +5,7 @@ import torch try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -13,14 +14,12 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor.process_group import ProcessGroup from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn from colossalai.utils import get_current_device -from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper +from colossalai.zero import zero_model_wrapper, zero_optim_wrapper class MLP(torch.nn.Module): - def __init__(self, in_features): super().__init__() self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False) @@ -35,7 +34,7 @@ class MLP(torch.nn.Module): def check_auto_parallel_with_gemini(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MLP(4).half().cuda() if rank in [0, 1]: input = torch.arange(0, 16).reshape(4, 4).half().cuda() @@ -52,32 +51,30 @@ def check_auto_parallel_with_gemini(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 4).half().to('meta')} - gm, solution = initialize_model(model, - meta_args=meta_args, - device_mesh=device_mesh, - return_solution=True, - solver_preference='tp', - shard_option='shard_last_axis') + meta_args = {"x": torch.rand(4, 4).half().to("meta")} + gm, solution = initialize_model( + model, + meta_args=meta_args, + device_mesh=device_mesh, + return_solution=True, + solver_preference="tp", + shard_option="shard_last_axis", + ) if rank == 0: - msg = '| TP strategy combination chosen by auto-parallel solver |' + msg = "| TP strategy combination chosen by auto-parallel solver |" msg_length = len(msg) - print('=' * msg_length) + print("=" * msg_length) print(msg) - print('=' * msg_length) + print("=" * msg_length) for strategy in solution: print(strategy) - print('=' * msg_length) + print("=" * msg_length) - dp_process_group = ProcessGroup(rank=rank, ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2) - gemini_config = dict(strict_ddp_mode=False, - device=get_current_device(), - placement_policy='cpu', - pin_memory=True, - search_range_mb=128) + gemini_config = dict( + strict_ddp_mode=False, device=get_current_device(), placement_policy="cpu", pin_memory=True, search_range_m=128 + ) - post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group) gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config) optimizer = HybridAdam(gm.parameters(), betas=(0, 0)) optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1) @@ -86,28 +83,28 @@ def check_auto_parallel_with_gemini(rank, world_size, port): assert_close(output, output_compare.narrow(0, 0, 4)) else: assert_close(output, output_compare.narrow(0, 4, 4)) - print(f'output on rank{rank} is correct') + print(f"output on rank{rank} is correct") loss = output.sum() optimizer.zero_grad() optimizer.backward(loss) optimizer.step() if rank in (0, 2): - assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 0, 8).flatten()) + assert_close(list(optimizer.optim.state.values())[0]["exp_avg"].half(), grad_compare.narrow(0, 0, 8).flatten()) if rank in (1, 3): - assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 8, 8).flatten()) + assert_close(list(optimizer.optim.state.values())[0]["exp_avg"].half(), grad_compare.narrow(0, 8, 8).flatten()) - print(f'gradient on rank{rank} is correct') + print(f"gradient on rank{rank} is correct") -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_auto_parallel_with_gemini(): spawn(check_auto_parallel_with_gemini, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_auto_parallel_with_gemini() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py index a0b407b240e1fc4432002ed50a1e0132348e3ded..a0276acc42935afd99dff0516cbccf87feabd72b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py @@ -5,8 +5,8 @@ import torch.nn as nn from torch.fx import GraphModule from transformers.pytorch_utils import Conv1D -from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks @@ -19,7 +19,6 @@ HIDDEN_DIM = 384 class RepeatBlock(nn.Module): - def __init__(self, intermediate_size, hidden_size): super().__init__() self.c_fc = Conv1D(intermediate_size, hidden_size) @@ -35,13 +34,11 @@ class RepeatBlock(nn.Module): class RepeatModel(nn.Module): - def __init__(self, intermediate_size, hidden_size, num_layers): super().__init__() self.blocks = nn.ModuleList([RepeatBlock(intermediate_size, hidden_size) for i in range(num_layers)]) def forward(self, x): - for block in self.blocks: x = block(x) @@ -49,10 +46,9 @@ class RepeatModel(nn.Module): class NonRepeatBlock(nn.Module): - def __init__(self, intermediate_size, hidden_size, layer_index): super().__init__() - intermediate_size //= (layer_index + 1) + intermediate_size //= layer_index + 1 self.c_fc = Conv1D(intermediate_size, hidden_size) self.c_proj = Conv1D(hidden_size, intermediate_size) self.act = torch.nn.ReLU() @@ -66,28 +62,25 @@ class NonRepeatBlock(nn.Module): class NonRepeatModel(nn.Module): - def __init__(self, intermediate_size, hidden_size, num_layers): super().__init__() self.blocks = nn.ModuleList([NonRepeatBlock(intermediate_size, hidden_size, i) for i in range(num_layers)]) def forward(self, x): - for block in self.blocks: x = block(x) return x -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() -@parameterize('model_cls', [RepeatModel, NonRepeatModel]) +@parameterize("model_cls", [RepeatModel, NonRepeatModel]) def test_repeat_blocks(model_cls): - model = model_cls(4 * HIDDEN_DIM, HIDDEN_DIM, NUM_REPEAT_BLOCKS) tracer = ColoTracer(bias_addition_split=True) - input_sample = {'x': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta')} + input_sample = {"x": torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to("meta")} graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) @@ -110,5 +103,5 @@ def test_repeat_blocks(model_cls): assert len(common_blocks) == 0 -if __name__ == '__main__': +if __name__ == "__main__": test_repeat_blocks() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py index 22a2371311f9125cbc6854f2dbb1141d906c1ebb..3bb7cc409938a9a7b296a67c03c056a839aae5e4 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py @@ -8,7 +8,6 @@ from transformers.pytorch_utils import Conv1D class GPT2MLP(nn.Module): - def __init__(self, intermediate_size, config): super().__init__() embed_dim = config.hidden_size @@ -34,15 +33,15 @@ class GPT2MLP(nn.Module): # 2. The order of split and view op has been changed in the customized GPT2Attention class, the new # order is same as megatron-lm gpt model. class GPT2Attention(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), - dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), ) self.register_buffer("masked_bias", torch.tensor(-1e4)) @@ -68,7 +67,7 @@ class GPT2Attention(nn.Module): attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / (value.size(-1)**0.5) + attn_weights = attn_weights / (value.size(-1) ** 0.5) # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: @@ -76,7 +75,7 @@ class GPT2Attention(nn.Module): # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) if attention_mask is not None: @@ -100,7 +99,7 @@ class GPT2Attention(nn.Module): def _split_heads(self, tensor, num_heads, attn_head_size): new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): tensor = tensor.permute(0, 2, 1, 3).contiguous() @@ -113,7 +112,6 @@ class GPT2Attention(nn.Module): attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - # query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) qkv = self.c_attn(hidden_states) @@ -121,7 +119,7 @@ class GPT2Attention(nn.Module): # key = self._split_heads(key, self.num_heads, self.head_dim) # value = self._split_heads(value, self.num_heads, self.head_dim) query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) - present = (key, value) + (key, value) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) @@ -131,7 +129,6 @@ class GPT2Attention(nn.Module): class GPT2Block(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -205,11 +202,9 @@ class GPT2Model(GPT2PreTrainedModel): # GPT2Attention mask. attention_mask = attention_mask.view(batch_size, -1) attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 - encoder_attention_mask = None - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -267,7 +262,6 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py index 48d2672c65714797c53e5e94442e6161bf5550e9..24968e670e3fabf293b753bd6a4b97a13d17fb68 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -9,6 +9,7 @@ import transformers from torch.fx import GraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer @@ -19,6 +20,7 @@ try: solve_solution, transform_to_sharded_model, ) + NO_CODEGEN = False except: NO_CODEGEN = True @@ -45,14 +47,17 @@ torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False -def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, torch.Tensor], - best_sharding_spec_dict: Dict[str, ShardingSpec]): +def _check_module_grad( + module: torch.nn.Module, + origin_param_dict: Dict[str, torch.Tensor], + best_sharding_spec_dict: Dict[str, ShardingSpec], +): for name, param in module.named_parameters(): param_grad = param.grad - name = name.replace('module.', '') + name = name.replace("module.", "") origin_param_grad = origin_param_dict[name].grad - atoms = name.split('.') - new_name = '_'.join(atoms) + atoms = name.split(".") + new_name = "_".join(atoms) if new_name in best_sharding_spec_dict: param_sharding_spec = best_sharding_spec_dict[new_name] grad_to_compare = copy.deepcopy(param_grad) @@ -63,19 +68,19 @@ def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, tor difference = param_grad_global - origin_param_grad avg_diff = difference.abs().sum() / difference.numel() assert avg_diff < 0.001 - print(f'{name} param has {avg_diff} average difference') + print(f"{name} param has {avg_diff} average difference") def check_attention_layer(rank, model_cls, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: - model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda') + model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to("cuda") else: - model = model_cls(config=config).to('cuda') + model = model_cls(config=config).to("cuda") test_model = copy.deepcopy(model) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) @@ -84,30 +89,30 @@ def check_attention_layer(rank, model_cls, world_size, port): hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32) if model_cls == GPT2MLP: - input_sample = (hidden_states.to('cuda'),) + input_sample = (hidden_states.to("cuda"),) test_input_sample = copy.deepcopy(input_sample) meta_input_sample = { - 'hidden_states': hidden_states.to('meta'), + "hidden_states": hidden_states.to("meta"), } elif model_cls in (GPT2Attention, GPT2Block): input_sample = ( - hidden_states.to('cuda'), - attention_mask.to('cuda'), + hidden_states.to("cuda"), + attention_mask.to("cuda"), ) test_input_sample = copy.deepcopy(input_sample) meta_input_sample = { - 'hidden_states': hidden_states.to('meta'), - 'attention_mask': attention_mask.to('meta'), + "hidden_states": hidden_states.to("meta"), + "attention_mask": attention_mask.to("meta"), } else: input_sample = ( - input_ids.to('cuda'), - attention_mask.to('cuda'), + input_ids.to("cuda"), + attention_mask.to("cuda"), ) test_input_sample = copy.deepcopy(input_sample) meta_input_sample = { - 'input_ids': input_ids.to('meta'), - 'attention_mask': attention_mask.to('meta'), + "input_ids": input_ids.to("meta"), + "attention_mask": attention_mask.to("meta"), } physical_mesh_id = torch.arange(0, 4) @@ -122,10 +127,11 @@ def check_attention_layer(rank, model_cls, world_size, port): shape_prop_pass(gm, *meta_input_sample.values()) gm.recompile() - strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard') + strategies_constructor = build_strategy_constructor(graph, device_mesh, "standard", "replicated", "standard") solution = solve_solution(gm, strategies_constructor, memory_budget=-1) - gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh, - strategies_constructor) + gm, sharding_spec_dicts = transform_to_sharded_model( + gm, meta_input_sample, solution, device_mesh, strategies_constructor + ) gm = ModuleWrapper(gm, *sharding_spec_dicts) nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] @@ -141,7 +147,7 @@ def check_attention_layer(rank, model_cls, world_size, port): output = gm(*input_sample) assert_close(output, origin_output, rtol=1e-03, atol=1e-03) - #*******************backward starting******************* + # *******************backward starting******************* cuda_rng_state = torch.cuda.get_rng_state() cpu_rng_state = torch.get_rng_state() output.sum().backward() @@ -158,9 +164,9 @@ def check_attention_layer(rank, model_cls, world_size, port): if rank == 0: print("*******************backward finished*******************") - #*******************backward finished******************* + # *******************backward finished******************* - #*******************strategy selected******************* + # *******************strategy selected******************* if rank == 0: print("*******************strategy selected*******************") nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] @@ -176,19 +182,19 @@ def check_attention_layer(rank, model_cls, world_size, port): node_memory_cost = node_memory_cost[0] memory_cost += node_memory_cost.activation + node_memory_cost.parameter - print(f'computation cost is {computation_cost}') - print(f'communication cost is {communication_cost}') - print(f'memory cost is {memory_cost}') + print(f"computation cost is {computation_cost}") + print(f"communication cost is {communication_cost}") + print(f"memory cost is {memory_cost}") -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.skipif(NO_CODEGEN, reason="no codegen module") @pytest.mark.dist -@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) +@parameterize("model_cls", [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) @rerun_if_address_is_in_use() def test_mlp_layer(model_cls): spawn(check_attention_layer, 4, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_mlp_layer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py index 5a8c3c4bf5a08c0d2c07acce6ff70f17b76212e9..b61cbe1708207a9dd24c5e5eced8285becb707be 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py @@ -4,7 +4,6 @@ from torch.fx import GraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass from colossalai._analyzer.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh @@ -18,9 +17,9 @@ SEQ_LENGTH = 32 HIDDEN_DIM = 384 -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() -@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) +@parameterize("model_cls", [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) def test_self_attention_block(model_cls): config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: @@ -32,23 +31,23 @@ def test_self_attention_block(model_cls): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - shape_consistency_manager = ShapeConsistencyManager() + ShapeConsistencyManager() tracer = ColoTracer(bias_addition_split=True) if model_cls == GPT2MLP: input_sample = { - 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), + "hidden_states": torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to("meta"), } elif model_cls in (GPT2Attention, GPT2Block): input_sample = { - 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), - 'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'), + "hidden_states": torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to("meta"), + "attention_mask": torch.rand(1, SEQ_LENGTH).to("meta"), } else: input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) - input_sample = {k: v.to('meta') for k, v in kwargs.items()} + input_sample = {k: v.to("meta") for k, v in kwargs.items()} graph = tracer.trace(root=model, meta_args=input_sample) @@ -63,7 +62,7 @@ def test_self_attention_block(model_cls): cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph.simplify_graph() solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1) - ret = solver.call_solver_serialized_args() + solver.call_solver_serialized_args() strategies_list = solver.last_s_val nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] @@ -79,10 +78,10 @@ def test_self_attention_block(model_cls): node_memory_cost = node_memory_cost[0] memory_cost += node_memory_cost.activation + node_memory_cost.parameter - print(f'computation cost is {computation_cost}') - print(f'communication cost is {communication_cost}') - print(f'memory cost is {memory_cost}') + print(f"computation cost is {computation_cost}") + print(f"communication cost is {communication_cost}") + print(f"memory cost is {memory_cost}") -if __name__ == '__main__': +if __name__ == "__main__": test_self_attention_block() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py index d10b222c060d7d96e98c05d4b3107db8405c4785..4dd04c69c8a53422490d46d98eb97e5e8b0c9968 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py @@ -11,7 +11,6 @@ from colossalai.testing import clear_cache_before_run class LinearModel(nn.Module): - def __init__(self): super().__init__() self.linear1 = nn.Linear(4, 4) @@ -27,12 +26,12 @@ class LinearModel(nn.Module): return out -@pytest.mark.skip('meta tensor has some bugs in 1.11') +@pytest.mark.skip("meta tensor has some bugs in 1.11") @clear_cache_before_run() def test_liveness_analysis(): model = LinearModel() tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 4, device='meta'), 'x2': torch.rand(4, 4, device='meta')} + meta_args = {"x1": torch.rand(4, 4, device="meta"), "x2": torch.rand(4, 4, device="meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__) shape_prop_pass(gm, *meta_args.values()) @@ -46,8 +45,8 @@ def test_liveness_analysis(): # a variable named `relu` must exist # and this live var must have inplace = True - assert liveness_list[0].all_live_vars.exists('relu') - relu_var = liveness_list[0].all_live_vars.get('relu') + assert liveness_list[0].all_live_vars.exists("relu") + relu_var = liveness_list[0].all_live_vars.get("relu") assert relu_var.is_inplace # the unique vars must be fewer than the all vars since in-place ops exist @@ -56,5 +55,5 @@ def test_liveness_analysis(): assert len(unique_live_vars) + 1 == len(all_live_vars) -if __name__ == '__main__': +if __name__ == "__main__": test_liveness_analysis() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py index e0a2133e654eafd3bb3719155a101e0a937f38b8..8831a208cb2f6d172c7a324bf4b837ed65f9c4eb 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -7,14 +7,17 @@ from colossalai.testing.utils import clear_cache_before_run, parameterize from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() -@parameterize('func', [ - torch.nn.functional.softmax, - torch.nn.functional.relu, - torch.tanh, - torch.nn.functional.dropout, -]) +@parameterize( + "func", + [ + torch.nn.functional.softmax, + torch.nn.functional.relu, + torch.tanh, + torch.nn.functional.dropout, + ], +) def test_activation_meta_info(func): meta_func = meta_register.get(func) # construct meta tensors @@ -23,13 +26,13 @@ def test_activation_meta_info(func): softmax_dim = 0 # construct operation data - input_data = OperationData(name='input', type=OperationDataType.ARG, data=input_tensor) - output_data = OperationData(name='output', type=OperationDataType.OUTPUT, data=output_tensor) - softmax_dim_data = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim) + input_data = OperationData(name="input", type=OperationDataType.ARG, data=input_tensor) + output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=output_tensor) + softmax_dim_data = OperationData(name="softmax_dim", type=OperationDataType.ARG, data=softmax_dim) # construct args and kwargs args = [input_data, softmax_dim_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -54,9 +57,17 @@ def test_activation_meta_info(func): bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 - print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_activation_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py index 68ccc7835bc354ca9979964b144e6a42f0edf560..ba9e282144b7e93942816ed1086d11c0bfaec9d1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag @@ -12,7 +11,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_t class BinaryElementwiseOpModule(nn.Module): - def __init__(self, token=torch.add, shape=64) -> None: super().__init__() self.token = token @@ -33,7 +31,7 @@ def _binary_elementwise_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = BinaryElementwiseOpModule(token=torch.add, shape=1024).cuda() input = torch.rand(32, 1024).cuda() input.requires_grad = True @@ -45,21 +43,23 @@ def _binary_elementwise_mem_test(rank, world_size, port): node_index = 2 # total number of target node strategies strategy_number = 9 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_meta_concrete_info_match(): spawn(_binary_elementwise_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_binary_elementwise_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py index c6f7b88f44a50042cd96ce0548ca47906bf167fd..45558154547fe8d7933a762e2c55aa8df8d7d10e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py @@ -11,7 +11,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_t class ConvFunctionModule(nn.Module): - def __init__(self, in_channels=4, out_channels=64, kernel_size=3): super().__init__() self.conv_weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) @@ -32,7 +31,7 @@ def _conv_module_mem_test(rank, world_size, port, bias): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1, bias=bias)).cuda() input = torch.rand(4, 4, 64, 64).cuda() input.requires_grad = True @@ -44,16 +43,18 @@ def _conv_module_mem_test(rank, world_size, port, bias): node_index = 1 # total number of target node strategies strategy_number = 16 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_conv_meta_concrete_info_match(bias=False): @@ -71,7 +72,7 @@ def _conv_function_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = ConvFunctionModule().cuda() input = torch.rand(4, 4, 64, 64).cuda() input.requires_grad = True @@ -83,22 +84,24 @@ def _conv_function_mem_test(rank, world_size, port): node_index = 2 # total number of target node strategies strategy_number = 16 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_conv_function_concrete_info_match(): spawn(_conv_function_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": # test_conv_meta_concrete_info_match() test_conv_function_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py index e3f76a95c4a5f4d7832b5916e4172276d2c43a41..5d830d769c2dfa1f4ae5584c5dfb9f3b712f67ab 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py @@ -5,11 +5,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': +if torch.__version__ >= "1.12.0": from colossalai.auto_parallel.meta_profiler import meta_register -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() def test_embedding_meta_info(): meta_func = meta_register.get(torch.nn.Embedding) @@ -28,7 +28,7 @@ def test_embedding_meta_info(): # construct args and kwargs args = [input_data, weight_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -52,9 +52,17 @@ def test_embedding_meta_info(): bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 - print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_embedding_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py index fb3ded339ddf17d6fa2c9f0927850bf16631035d..639870c89a8229720577e13f69a54f38822d1f6e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -11,7 +11,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_t class MyModule(nn.Module): - def __init__(self, in_features=64, out_features=128): super().__init__() self.fc_weight = nn.Parameter(torch.randn(out_features, in_features)) @@ -31,7 +30,7 @@ def _linear_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda() input = torch.rand(8, 8, 16, 64).cuda() input.requires_grad = True @@ -40,16 +39,18 @@ def _linear_module_mem_test(rank, world_size, port): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) # memory test - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=1, - strategy_number=13, - input_args=[input], - meta_arg_names=["input"]) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=1, + strategy_number=13, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_module_meta_concrete_info_match(): @@ -67,7 +68,7 @@ def _linear_function_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MyModule().cuda() input = torch.rand(8, 8, 16, 64).cuda() input.requires_grad = True @@ -76,22 +77,24 @@ def _linear_function_mem_test(rank, world_size, port): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) # memory test - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=2, - strategy_number=24, - input_args=[input], - meta_arg_names=["input"]) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=2, + strategy_number=24, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_function_meta_concrete_info_match(): spawn(_linear_function_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": # test_linear_module_meta_concrete_info_match() test_linear_function_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py index 2d2d77f0c637092664174ab8600f55bdfc4c1a46..b182dd02ca762fd3ee513bb900506f8827e043e9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py @@ -5,26 +5,27 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.testing.utils import clear_cache_before_run, parameterize from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register +if torch.__version__ >= "1.12.0": + from colossalai.auto_parallel.meta_profiler import meta_register -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() @parameterize( - 'tensor_shapes', + "tensor_shapes", [ - [[128], [128]], # dot product - [[64, 128], [128]], # mat-vec - [[128], [128, 64]], # vec-mat - [[64, 64, 128], [128]], # batched mat-vec - [[128], [64, 128, 64]], # vec-batched mat - [[64, 128], [128, 192]], # mat-mat - [[64, 64, 128], [128, 192]], # batched mat-mat - [[64, 128], [64, 128, 192]], # mat-batched mat - [[64, 64, 128], [64, 128, 192]], # batched mat-batched mat (matched batch dims) - [[64, 1, 64, 128], [64, 128, 192]], # batched mat-batched mat (unmatched batch dims) - ]) + [[128], [128]], # dot product + [[64, 128], [128]], # mat-vec + [[128], [128, 64]], # vec-mat + [[64, 64, 128], [128]], # batched mat-vec + [[128], [64, 128, 64]], # vec-batched mat + [[64, 128], [128, 192]], # mat-mat + [[64, 64, 128], [128, 192]], # batched mat-mat + [[64, 128], [64, 128, 192]], # mat-batched mat + [[64, 64, 128], [64, 128, 192]], # batched mat-batched mat (matched batch dims) + [[64, 1, 64, 128], [64, 128, 192]], # batched mat-batched mat (unmatched batch dims) + ], +) def test_matmul_function_meta_info(tensor_shapes): meta_func = meta_register.get(torch.matmul) @@ -55,7 +56,7 @@ def test_matmul_function_meta_info(tensor_shapes): # construct args and kwargs args = [input_data, other_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -85,9 +86,17 @@ def test_matmul_function_meta_info(tensor_shapes): compute_cost: TrainCycleItem memory_cost: TrainCycleItem - print_results([input_real_tensor, other_real_tensor], [output_real_tensor], compute_cost, memory_cost, - fwd_allocated, fwd_peak, bwd_allocated, bwd_peak) + print_results( + [input_real_tensor, other_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_matmul_function_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py index 808172977b6046dd55a6e1ee8c8a251774cb5a1f..ed809a758dfdcb738e05c48acdea77584fb70396 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py @@ -10,7 +10,7 @@ from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results -if torch.__version__ >= '1.12.0': +if torch.__version__ >= "1.12.0": from colossalai.auto_parallel.meta_profiler import meta_register @@ -25,7 +25,7 @@ def _batchnorm_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.BatchNorm2d(128)).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True @@ -37,27 +37,32 @@ def _batchnorm_module_mem_test(rank, world_size, port): node_index = 1 # total number of target node strategies strategy_number = 9 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_batchnorm_meta_concrete_info_match(): spawn(_batchnorm_module_mem_test, 4) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations') -@parameterize('tensor_shape', [ - [256, 1024], - [1024, 256], -]) +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") +@parameterize( + "tensor_shape", + [ + [256, 1024], + [1024, 256], + ], +) def test_layernorm_meta_info(tensor_shape): meta_func = meta_register.get(torch.nn.LayerNorm) @@ -78,7 +83,7 @@ def test_layernorm_meta_info(tensor_shape): # construct args and kwargs args = [input_data, output_data, weight_data, bias_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -108,10 +113,18 @@ def test_layernorm_meta_info(tensor_shape): compute_cost: TrainCycleItem memory_cost: TrainCycleItem - print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_batchnorm_meta_concrete_info_match() test_layernorm_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py index 4cddf4e19fcabd1174854beebd89be6bafe64b74..bd1deb40ca7ba923ddf914ca18c4b4befb2f995b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py @@ -21,7 +21,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.AdaptiveAvgPool2d((16, 16))).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True @@ -33,16 +33,18 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port): node_index = 1 # total number of target strategies strategy_number = 1 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_adaptiveavgpool_meta_concrete_info_match(): @@ -60,7 +62,7 @@ def _maxpool_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.MaxPool2d((16, 16))).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True @@ -72,22 +74,24 @@ def _maxpool_module_mem_test(rank, world_size, port): node_index = 1 # total number of target node strategies strategy_number = 9 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_maxpool_meta_concrete_info_match(): spawn(_maxpool_module_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_adaptiveavgpool_meta_concrete_info_match() test_maxpool_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py index 6e8145885d67dc9596c71015ae8c0e134fe29ad6..a29291e9b4d95231b4f5a2a1e25bca30914fb7c5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py @@ -6,12 +6,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register +if torch.__version__ >= "1.12.0": + from colossalai.auto_parallel.meta_profiler import meta_register class SplitModule(nn.Module): - def __init__(self) -> None: super().__init__() @@ -19,7 +18,7 @@ class SplitModule(nn.Module): return x.split(512, dim=0) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() def test_tensor_meta_info(): """test tensor related meta information @@ -45,7 +44,7 @@ def test_tensor_meta_info(): logical_shape=input_tensor.shape, ) split_info_data = OperationData( - name='split_info', + name="split_info", type=OperationDataType.ARG, data=0, logical_shape=None, @@ -53,7 +52,7 @@ def test_tensor_meta_info(): # construct args args = [input_data, output_data, split_info_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -79,8 +78,16 @@ def test_tensor_meta_info(): bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 - print_results([input_real_tensor], output_real_tensor, compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + output_real_tensor, + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) if __name__ == "__main__": diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py index b4564312eeb4698e045eab5263d7f68147484b56..64d9ccd3def2dc3b3c2d0eda8424c2b0288e917f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py @@ -5,11 +5,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register +if torch.__version__ >= "1.12.0": + from colossalai.auto_parallel.meta_profiler import meta_register -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() def test_where_meta_info(): meta_func = meta_register.get(torch.where) @@ -49,7 +49,7 @@ def test_where_meta_info(): # construct args and kwargs args = [condition_data, x_data, y_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -81,9 +81,17 @@ def test_where_meta_info(): compute_cost: TrainCycleItem memory_cost: TrainCycleItem - print_results([condition_real_tensor, x_real_tensor, y_real_tensor], [output_real_tensor], compute_cost, - memory_cost, fwd_allocated, fwd_peak, bwd_allocated, bwd_peak) + print_results( + [condition_real_tensor, x_real_tensor, y_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_where_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py index 4ca85d34da3097a54d1c568c09224e7087bc45cc..e58d15cec50b44d56fef938dfeb74f3406c67ab0 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py @@ -7,6 +7,7 @@ from torch.fx import GraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass @@ -16,29 +17,34 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh -if torch.__version__ >= '1.12.0': +if torch.__version__ >= "1.12.0": from colossalai.auto_parallel.meta_profiler import ShardMetaInfo -def mem_test_for_node_strategy(rank: int, - model: torch.nn.Module, - device_mesh: DeviceMesh, - node_index: int, - strategy_number: int, - input_args: List[torch.Tensor], - meta_arg_names: List[str], - input_kwargs: Dict[str, torch.Tensor] = {}): +def mem_test_for_node_strategy( + rank: int, + model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}, +): for strategy_index in range(strategy_number): # We need to copy the model to avoid do backward more than once in same graph - model_to_shard, args_to_shard, kwargs_to_shard = copy.deepcopy(model), copy.deepcopy(input_args), copy.deepcopy( - input_kwargs) + model_to_shard, args_to_shard, kwargs_to_shard = ( + copy.deepcopy(model), + copy.deepcopy(input_args), + copy.deepcopy(input_kwargs), + ) tracer = ColoTracer(bias_addition_split=True) input_sample = {} for input_arg, meta_arg_name in zip(input_args, meta_arg_names): - input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') + input_sample[meta_arg_name] = torch.rand(input_arg.shape).to("meta") for meta_kwarg_name, input_kwarg in input_kwargs.items(): - input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') + input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to("meta") graph = tracer.trace(root=model_to_shard, meta_args=input_sample) gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) shape_prop_pass(gm, *input_sample.values()) @@ -57,13 +63,18 @@ def mem_test_for_node_strategy(rank: int, # construct the strategy for the output node placeholder_strategy = list(graph.nodes)[-1].strategies_vector[0] - output_key = next(key for key in target_node.strategies_vector[strategy_index].sharding_specs.keys() - if key.type == OperationDataType.OUTPUT) + output_key = next( + key + for key in target_node.strategies_vector[strategy_index].sharding_specs.keys() + if key.type == OperationDataType.OUTPUT + ) placeholder_strategy.sharding_specs[output_key] = target_node.strategies_vector[strategy_index].sharding_specs[ - output_key] + output_key + ] gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( - gm, solution, device_mesh, strategies_constructor) + gm, solution, device_mesh, strategies_constructor + ) gm = runtime_apply_pass(gm) gm.recompile() gm: GraphModule @@ -76,22 +87,26 @@ def mem_test_for_node_strategy(rank: int, # warmup with torch.no_grad(): - output = gm(*args_to_shard, - sharding_spec_convert_dict=sharding_spec_dict, - origin_node_sharding_spec_dict=origin_spec_dict, - comm_actions_dict=comm_actions_dict, - **kwargs_to_shard) + output = gm( + *args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard, + ) del output # forward memory compare if rank == 0: torch.cuda.reset_peak_memory_stats() mem_stamp0 = torch.cuda.memory_allocated() - output = gm(*args_to_shard, - sharding_spec_convert_dict=sharding_spec_dict, - origin_node_sharding_spec_dict=origin_spec_dict, - comm_actions_dict=comm_actions_dict, - **kwargs_to_shard) + output = gm( + *args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard, + ) if rank == 0: # print forward memory allocated and peak memory stats in kb @@ -113,8 +128,10 @@ def mem_test_for_node_strategy(rank: int, # estimated memory if target_node.op == "call_module": - metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], - target_node.graph.owning_module.get_submodule(target_node.target)) + metainfo = ShardMetaInfo( + target_node.strategies_vector[strategy_index], + target_node.graph.owning_module.get_submodule(target_node.target), + ) else: metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], target_node.target) @@ -134,8 +151,16 @@ def mem_test_for_node_strategy(rank: int, print("=======================") -def print_results(input: List[torch.Tensor], output: List[torch.Tensor], compute_cost: TrainCycleItem, - memory_cost: TrainCycleItem, fwd_allocated, fwd_peak, bwd_allocated, bwd_peak): +def print_results( + input: List[torch.Tensor], + output: List[torch.Tensor], + compute_cost: TrainCycleItem, + memory_cost: TrainCycleItem, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, +): """Print the results of the meta information test. Args: diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py index 80e6a6c1460c756c7212ad5dac7fcf24a78ab489..73a15f3ba4de9eb9e4a1318e8246a63afb67caac 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py @@ -13,7 +13,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n class AddBMMTensorMethodModule(nn.Module): - def __init__(self, using_kwargs): super().__init__() self.using_kwargs = using_kwargs @@ -27,7 +26,6 @@ class AddBMMTensorMethodModule(nn.Module): class AddBMMTorchFunctionModule(nn.Module): - def __init__(self, using_kwargs): super().__init__() self.using_kwargs = using_kwargs @@ -42,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module): def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module(using_kwargs).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -57,13 +55,15 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg # construct input args input_args = [bias, x1, x2] # construct meta arg names - meta_arg_names = ['bias', 'x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["bias", "x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer() # graph(): # %bias : torch.Tensor [#users=1] = placeholder[target=bias] @@ -73,13 +73,15 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) # return add - graph = tracer.trace(model, - meta_args={ - 'bias': torch.rand(*bias_shape).to('meta'), - "x1": torch.rand(4, 8, 16).to('meta'), - 'x2': torch.rand(4, 16, 8).to('meta') - }) - gm = ColoGraphModule(model, graph) + graph = tracer.trace( + model, + meta_args={ + "bias": torch.rand(*bias_shape).to("meta"), + "x1": torch.rand(4, 8, 16).to("meta"), + "x2": torch.rand(4, 16, 8).to("meta"), + }, + ) + ColoGraphModule(model, graph) bmm_mod_node = list(graph.nodes)[3] strategies_vector = StrategiesVector(bmm_mod_node) @@ -96,49 +98,49 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] for name in strategy_name_list: print(name) # one batch dim - assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list + assert "Sb0 = Sb0 x Sb0" not in strategy_name_list # two batch dim - assert 'Sb01 = Sb01 x Sb01' in strategy_name_list + assert "Sb01 = Sb01 x Sb01" in strategy_name_list # SbSi = SbSi x Sb - assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list - assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list + assert "Sb0Si1 = Sb0Si1 x Sb0" in strategy_name_list + assert "Sb1Si0 = Sb1Si0 x Sb1" in strategy_name_list # SbSj = SbR x SbSj - assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list - assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list + assert "Sb0Sj1 = Sb0R x Sb0Sj1" in strategy_name_list + assert "Sb1Sj0 = Sb1R x Sb1Sj0" in strategy_name_list # SbR = SbSk x SbSk - assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list - assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + assert "Sb0R = Sb0Sk1 x Sb0Sk1" in strategy_name_list + assert "Sb1R = Sb1Sk0 x Sb1Sk0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] @@ -148,7 +150,7 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (1, 4) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) @@ -163,13 +165,15 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por # construct input args input_args = [bias, x1, x2] # construct meta arg names - meta_arg_names = ['bias', 'x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["bias", "x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer() # graph(): @@ -180,13 +184,15 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) # return add - graph = tracer.trace(model, - meta_args={ - 'bias': torch.rand(*bias_shape).to('meta'), - "x1": torch.rand(4, 8, 16).to('meta'), - 'x2': torch.rand(4, 16, 8).to('meta') - }) - gm = ColoGraphModule(model, graph) + graph = tracer.trace( + model, + meta_args={ + "bias": torch.rand(*bias_shape).to("meta"), + "x1": torch.rand(4, 8, 16).to("meta"), + "x2": torch.rand(4, 16, 8).to("meta"), + }, + ) + ColoGraphModule(model, graph) bmm_mod_node = list(graph.nodes)[3] strategies_vector = StrategiesVector(bmm_mod_node) @@ -202,33 +208,33 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] assert len(strategy_name_list) == 1 # one batch dim - assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + assert "Sb0 = Sb0 x Sb0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] @@ -237,11 +243,11 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por @pytest.mark.skip("skip due to bias cases not ready") -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist -@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) -@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) -@parameterize('using_kwargs', [True, False]) +@parameterize("module", [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize("bias_shape", [[8], [1, 8], [8, 8]]) +@parameterize("using_kwargs", [True, False]) @rerun_if_address_is_in_use() def test_2d_device_mesh(module, bias_shape, using_kwargs): spawn( @@ -254,11 +260,11 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs): @pytest.mark.skip("skip due to bias cases not ready") -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist -@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) -@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) -@parameterize('using_kwargs', [True, False]) +@parameterize("module", [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize("bias_shape", [[8], [1, 8], [8, 8]]) +@parameterize("using_kwargs", [True, False]) @rerun_if_address_is_in_use() def test_1d_device_mesh(module, bias_shape, using_kwargs): spawn( @@ -270,6 +276,6 @@ def test_1d_device_mesh(module, bias_shape, using_kwargs): ) -if __name__ == '__main__': +if __name__ == "__main__": test_1d_device_mesh() test_2d_device_mesh() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py index fe6554cd81eed38976fab9a73b68b8b8bd16b026..26f9c4ab1e3cd65a0f790abb1f7769661763224d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py @@ -19,7 +19,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n class AddmmModel(nn.Module): - def __init__(self): super().__init__() @@ -29,7 +28,6 @@ class AddmmModel(nn.Module): class AddmmModel_with_param(nn.Module): - def __init__(self, weight_shape, bias_shape): super().__init__() self.weight = torch.nn.Parameter(torch.rand(weight_shape)) @@ -42,7 +40,7 @@ class AddmmModel_with_param(nn.Module): def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") if model_cls == AddmmModel: model = AddmmModel().cuda() else: @@ -58,10 +56,10 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) # construct input args input_args = [input, m1, m2] # construct meta arg names - meta_arg_names = ['input', 'm1', 'm2'] + meta_arg_names = ["input", "m1", "m2"] meta_args_for_tracer = {} for meta_arg, input_arg in zip(meta_arg_names, input_args): - meta_args_for_tracer[meta_arg] = input_arg.to('meta') + meta_args_for_tracer[meta_arg] = input_arg.to("meta") # the index of addmm node in computation graph node_index = 4 @@ -72,22 +70,24 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) # construct input args input_args = [m1] # construct meta arg names - meta_arg_names = ['m1'] + meta_arg_names = ["m1"] # the index of addmm node in computation graph meta_args_for_tracer = {} for meta_arg, input_arg in zip(meta_arg_names, input_args): - meta_args_for_tracer[meta_arg] = input_arg.to('meta') + meta_args_for_tracer[meta_arg] = input_arg.to("meta") node_index = 4 # strategy number of linear node strategy_number = 14 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - node_type='bias_module') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type="bias_module", + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -117,60 +117,60 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) # check operation data mapping mapping = handler.get_operation_data_mapping() - assert mapping['input'].name == "m1" - assert mapping['input'].data.shape == torch.Size([4, 8]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8]) + assert mapping["input"].name == "m1" + assert mapping["input"].data.shape == torch.Size([4, 8]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8]) - assert mapping['other'].name == "transpose" - assert mapping['other'].data.shape == torch.Size([16, 8]) + assert mapping["other"].name == "transpose" + assert mapping["other"].data.shape == torch.Size([16, 8]) if model_cls == AddmmModel: - assert mapping['other'].type == OperationDataType.ARG + assert mapping["other"].type == OperationDataType.ARG else: - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([8, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([8, 16]) - assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([4, 16]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "linear" + assert mapping["output"].data.shape == torch.Size([4, 16]) + assert mapping["output"].type == OperationDataType.OUTPUT # SS = SR x RS - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('m1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('transpose') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("m1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("transpose") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -178,14 +178,14 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist -@parameterize('input_shape', [(16,), (4, 16)]) -@parameterize('model_cls', [AddmmModel, AddmmModel_with_param]) +@parameterize("input_shape", [(16,), (4, 16)]) +@parameterize("model_cls", [AddmmModel, AddmmModel_with_param]) @rerun_if_address_is_in_use() def test_addmm_handler(input_shape, model_cls): spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_addmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py index b47b3508ad1b9047e492ba8b10ba87c1d96b311a..86df7237a21991a54c2a6f35d666befb510252e3 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py @@ -16,7 +16,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n def check_bn_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.BatchNorm2d(16)).cuda() physical_mesh_id = torch.arange(0, 4) @@ -27,20 +27,22 @@ def check_bn_module_handler(rank, world_size, port): # the index of bn node in computation graph node_index = 1 # the total number of bn strategies without sync bn mode - # TODO: add sync bn stategies after related passes ready + # TODO: add sync bn strategies after related passes ready strategy_number = 4 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) tracer = ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # return _0 - meta_args = {"input": torch.rand(4, 16, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 16, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -59,37 +61,37 @@ def check_bn_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 16, 64, 64]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16]) - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type == OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "_0" - assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "_0" + assert mapping["output"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # RS = RS x S - assert 'RS0 = RS0 x S0' in strategy_name_list - assert 'RS1 = RS1 x S1' in strategy_name_list + assert "RS0 = RS0 x S0" in strategy_name_list + assert "RS1 = RS1 x S1" in strategy_name_list # RR = RR x R - assert 'RR = RR x R' in strategy_name_list + assert "RR = RR x R" in strategy_name_list # RS01 = RS01 x S01 - assert 'RS01 = RS01 x S01' in strategy_name_list + assert "RS01 = RS01 x S01" in strategy_name_list # temporarily skip the sync bn test # TODO: test sync bn after the implicit runtime pass completed @@ -105,12 +107,12 @@ def check_bn_module_handler(rank, world_size, port): # assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_bn_module_handler(): spawn(check_bn_module_handler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_bn_module_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py index 800bc11a50e443eeb59598ce1ef0317b96589899..e06625e1c42c587392710fd06068c4e554d23e61 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass from colossalai._analyzer.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, OperationDataType, @@ -22,7 +22,6 @@ WEIGHT_SHAPE = (32, 16) class LinearModule(torch.nn.Module): - def __init__(self, weight_shape): super().__init__() self.weight = torch.nn.Parameter(torch.rand(*weight_shape)) @@ -35,7 +34,7 @@ class LinearModule(torch.nn.Module): def check_linear_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModule(weight_shape=WEIGHT_SHAPE).cuda() physical_mesh_id = torch.arange(0, 4) @@ -49,14 +48,16 @@ def check_linear_module_handler(rank, world_size, port): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['x'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - node_type='bias_module') + meta_arg_names = ["x"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type="bias_module", + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -66,7 +67,7 @@ def check_linear_module_handler(rank, world_size, port): # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {}) # return add - meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')} + meta_args = {"x": torch.rand(4, 4, 4, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -85,72 +86,72 @@ def check_linear_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x" - assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([64, 16]) + assert mapping["input"].name == "x" + assert mapping["input"].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([64, 16]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16, 32]) - assert 'bias' not in mapping + assert "bias" not in mapping - assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "linear" + assert mapping["output"].data.shape == torch.Size([4, 4, 4, 32]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('x') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("x") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -158,12 +159,12 @@ def check_linear_module_handler(rank, world_size, port): assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(): spawn(check_linear_module_handler) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py index c29a065d10baf4faee039bd1965650abdbfc6062..690f0c12387c542cff424b3fe5ff885ed1d10cec 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py @@ -19,7 +19,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n class LinearModule(torch.nn.Module): - def __init__(self, in_features, out_features, bias): super().__init__() self.linear = torch.nn.Linear(in_features, out_features, bias=bias) @@ -31,7 +30,7 @@ class LinearModule(torch.nn.Module): def check_linear_module_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModule(16, 32, bias=bias).cuda() physical_mesh_id = torch.arange(0, 4) @@ -45,17 +44,19 @@ def check_linear_module_handler(rank, world_size, port, bias): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['x'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - node_type='bias_module') + meta_arg_names = ["x"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type="bias_module", + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')} + meta_args = {"x": torch.rand(4, 4, 4, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -74,72 +75,72 @@ def check_linear_module_handler(rank, world_size, port, bias): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x" - assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([64, 16]) + assert mapping["input"].name == "x" + assert mapping["input"].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([64, 16]) - assert mapping['other'].name == "linear_weight" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "linear_weight" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16, 32]) - assert 'bias' not in mapping + assert "bias" not in mapping - assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "linear" + assert mapping["output"].data.shape == torch.Size([4, 4, 4, 32]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('x') - weight_sharding_spec = strategy.get_sharding_spec_by_name('linear_weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("x") + weight_sharding_spec = strategy.get_sharding_spec_by_name("linear_weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -147,12 +148,12 @@ def check_linear_module_handler(rank, world_size, port, bias): assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(bias=True): spawn(check_linear_module_handler, bias=bias) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py index 83f3aafe220eaadb31fbada624cf6bb65c899df2..5b2e2ab49f6d5421fbbd694a70f3247c979642ff 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py @@ -16,10 +16,9 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") class BinaryElementwiseOpModel(nn.Module): - def __init__(self, op): super().__init__() self.op = op @@ -41,16 +40,18 @@ def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, oth # construct input args input_args = [x1, x2] # construct meta arg names - meta_arg_names = ['x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')} + meta_args = {"x1": torch.rand(4, 4).to("meta"), "x2": torch.rand([4] * other_dim).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -70,23 +71,23 @@ def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, oth assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4] * other_dim) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 4]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4] * other_dim) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 4]) - assert mapping['output'].name == str(op_node) - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([4, 4]) + assert mapping["output"].name == str(op_node) + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([4, 4]) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -95,19 +96,19 @@ def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, oth assert len(strategy_name_list) == 9 # check if the sharding strategy is correct - assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list - assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list - assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list - assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list - assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list - assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list - assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list - assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list - assert '[R, R] = [R, R] [R, R]' in strategy_name_list + assert "[S0, S1] = [S0, S1] [S0, S1]" in strategy_name_list + assert "[S1, S0] = [S1, S0] [S1, S0]" in strategy_name_list + assert "[S01, R] = [S01, R] [S01, R]" in strategy_name_list + assert "[R, S01] = [R, S01] [R, S01]" in strategy_name_list + assert "[S0, R] = [S0, R] [S0, R]" in strategy_name_list + assert "[R, S0] = [R, S0] [R, S0]" in strategy_name_list + assert "[S1, R] = [S1, R] [S1, R]" in strategy_name_list + assert "[R, S1] = [R, S1] [R, S1]" in strategy_name_list + assert "[R, R] = [R, R] [R, R]" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) # make sure the sharding spec is the same for input and output @@ -121,7 +122,6 @@ def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, oth class BEOpModelWithNodeConst(nn.Module): - def __init__(self, op): super().__init__() self.op = op @@ -133,7 +133,6 @@ class BEOpModelWithNodeConst(nn.Module): class BEOpModelWithIntConst(nn.Module): - def __init__(self, op, const): super().__init__() self.op = op @@ -146,7 +145,7 @@ class BEOpModelWithIntConst(nn.Module): def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -163,15 +162,17 @@ def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_ # construct input args input_args = [x1] # construct meta arg names - meta_arg_names = ['x1'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 4).to('meta')} + meta_args = {"x1": torch.rand(4, 4).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -188,17 +189,17 @@ def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_ # check operation data mapping mapping = handler.get_operation_data_mapping() - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4]) - assert mapping['output'].name == str(op_node) - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([4, 4]) + assert mapping["output"].name == str(op_node) + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([4, 4]) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -207,27 +208,27 @@ def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_ assert len(strategy_name_list) == 9 # check if the sharding strategy is correct - assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list - assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list - assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list - assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list - assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list - assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list - assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list - assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list - assert '[R, R] = [R, R] [R, R]' in strategy_name_list + assert "[S0, S1] = [S0, S1] [S0, S1]" in strategy_name_list + assert "[S1, S0] = [S1, S0] [S1, S0]" in strategy_name_list + assert "[S01, R] = [S01, R] [S01, R]" in strategy_name_list + assert "[R, S01] = [R, S01] [R, S01]" in strategy_name_list + assert "[S0, R] = [S0, R] [S0, R]" in strategy_name_list + assert "[R, S0] = [R, S0] [R, S0]" in strategy_name_list + assert "[S1, R] = [S1, R] [S1, R]" in strategy_name_list + assert "[R, S1] = [R, S1] [R, S1]" in strategy_name_list + assert "[R, R] = [R, R] [R, R]" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) # make sure the sharding spec is the same for input and output assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('op', [torch.add]) -@parameterize('other_dim', [1, 2]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("op", [torch.add]) +@parameterize("other_dim", [1, 2]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_handler_with_tensor(op, other_dim): @@ -239,10 +240,10 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim): ) -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('op', [torch.add]) -@parameterize('other_dim', [1, 2]) -@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("op", [torch.add]) +@parameterize("other_dim", [1, 2]) +@parameterize("model_cls", [BEOpModelWithNodeConst, BEOpModelWithIntConst]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_handler_with_int(op, model_cls, other_dim): @@ -255,6 +256,6 @@ def test_binary_elementwise_handler_with_int(op, model_cls, other_dim): ) -if __name__ == '__main__': +if __name__ == "__main__": test_binary_elementwise_handler_with_tensor() test_binary_elementwise_handler_with_int() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py index f4fdc458f80ea874bc6a3916e2974558e0663c72..29df128322414f9da9ea0aee4067869c9d3ea9a6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -15,20 +15,18 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n class BMMTensorMethodModule(nn.Module): - def forward(self, x1, x2): return x1.bmm(x2) class BMMTorchFunctionModule(nn.Module): - def forward(self, x1, x2): return torch.bmm(x1, x2) def check_2d_device_mesh(rank, module, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -42,15 +40,17 @@ def check_2d_device_mesh(rank, module, world_size, port): # construct input args input_args = [x1, x2] # construct meta arg names - meta_arg_names = ['x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')} + meta_args = {"x1": torch.rand(4, 8, 16).to("meta"), "x2": torch.rand(4, 16, 8).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -70,48 +70,48 @@ def check_2d_device_mesh(rank, module, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # one batch dim - assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list + assert "Sb0 = Sb0 x Sb0" not in strategy_name_list # two batch dim - assert 'Sb01 = Sb01 x Sb01' in strategy_name_list + assert "Sb01 = Sb01 x Sb01" in strategy_name_list # SbSi = SbSi x Sb - assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list - assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list + assert "Sb0Si1 = Sb0Si1 x Sb0" in strategy_name_list + assert "Sb1Si0 = Sb1Si0 x Sb1" in strategy_name_list # SbSj = SbR x SbSj - assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list - assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list + assert "Sb0Sj1 = Sb0R x Sb0Sj1" in strategy_name_list + assert "Sb1Sj0 = Sb1R x Sb1Sj0" in strategy_name_list # SbR = SbSk x SbSk - assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list - assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + assert "Sb0R = Sb0Sk1 x Sb0Sk1" in strategy_name_list + assert "Sb1R = Sb1Sk0 x Sb1Sk0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -121,7 +121,7 @@ def check_2d_device_mesh(rank, module, world_size, port): def check_1d_device_mesh(rank, module, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (1, 4) @@ -135,15 +135,17 @@ def check_1d_device_mesh(rank, module, world_size, port): # construct input args input_args = [x1, x2] # construct meta arg names - meta_arg_names = ['x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')} + meta_args = {"x1": torch.rand(4, 8, 16).to("meta"), "x2": torch.rand(4, 16, 8).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -162,33 +164,33 @@ def check_1d_device_mesh(rank, module, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] assert len(strategy_name_list) == 1 # one batch dim - assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + assert "Sb0 = Sb0 x Sb0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -196,9 +198,9 @@ def check_1d_device_mesh(rank, module, world_size, port): assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) -@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("module", [BMMTensorMethodModule, BMMTorchFunctionModule]) +@parameterize("module", [BMMTensorMethodModule, BMMTorchFunctionModule]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_bmm_handler(module): @@ -206,5 +208,5 @@ def test_bmm_handler(module): spawn(check_1d_device_mesh, 4, module=module) -if __name__ == '__main__': +if __name__ == "__main__": test_bmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py index f9632b1cd8f9bb02299718d0917379cbcdbde78b..8a37dd9256ddd0913e0efbc89cd28852f8fd0211 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -16,7 +16,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n def check_conv_module_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -32,14 +32,16 @@ def check_conv_module_handler(rank, world_size, port, bias): node_index = 1 # total number of conv strategies strategy_number = 16 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -58,76 +60,76 @@ def check_conv_module_handler(rank, world_size, port, bias): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" + assert mapping["input"].name == "input_1" # assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4, 64, 64]) - assert mapping['other'].name == "weight" + assert mapping["other"].name == "weight" # assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + assert mapping["other"].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([4, 16, 3, 3]) if bias: - assert mapping['bias'].name == "bias" + assert mapping["bias"].name == "bias" # assert mapping['bias'].data.is_meta - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type == OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "_0" + assert mapping["output"].name == "_0" # assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert "S0S1 = S0R x RS1" in strategy_name_list + assert "S1S0 = S1R x RS0" in strategy_name_list # SR = SR x RR - assert 'S0R = S0R x RR' in strategy_name_list - assert 'S1R = S1R x RR' in strategy_name_list + assert "S0R = S0R x RR" in strategy_name_list + assert "S1R = S1R x RR" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert "S0R = S0S1 x S1R" in strategy_name_list + assert "S1R = S1S0 x S0R" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR' in strategy_name_list + assert "S01R = S01R x RR" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("_0") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0] @@ -141,7 +143,6 @@ def check_conv_module_handler(rank, world_size, port, bias): class ConvModel(nn.Module): - def __init__(self): super().__init__() @@ -152,7 +153,7 @@ class ConvModel(nn.Module): def check_conv_function_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = ConvModel().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -160,22 +161,24 @@ def check_conv_function_handler(rank, world_size, port, bias): input = torch.rand(4, 4, 64, 64).cuda() others = torch.rand(16, 4, 3, 3).cuda() input_args = [input, others] - meta_arg_names = ['input', 'others'] + meta_arg_names = ["input", "others"] input_kwargs = {} # total number of conv strategies strategy_number = 16 node_index = 2 if bias: bias_tensor = torch.rand(16).cuda() - input_kwargs['bias'] = bias_tensor + input_kwargs["bias"] = bias_tensor node_index += 1 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - input_kwargs=input_kwargs) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs, + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -183,9 +186,9 @@ def check_conv_function_handler(rank, world_size, port, bias): # %others : torch.Tensor [#users=1] = placeholder[target=others] # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {}) # return conv2d - meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta"), "others": torch.rand(16, 4, 3, 3).to("meta")} if bias: - meta_args['bias'] = torch.rand(16).to('meta') + meta_args["bias"] = torch.rand(16).to("meta") graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -208,76 +211,76 @@ def check_conv_function_handler(rank, world_size, port, bias): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4, 64, 64]) - assert mapping['other'].name == "others" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + assert mapping["other"].name == "others" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 3, 3]) if bias: - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.ARG - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.is_meta + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type == OperationDataType.ARG + assert mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "conv2d" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "conv2d" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert "S0S1 = S0R x RS1" in strategy_name_list + assert "S1S0 = S1R x RS0" in strategy_name_list # SR = SR x RR - assert 'S0R = S0R x RR' in strategy_name_list - assert 'S1R = S1R x RR' in strategy_name_list + assert "S0R = S0R x RR" in strategy_name_list + assert "S1R = S1R x RR" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert "S0R = S0S1 x S1R" in strategy_name_list + assert "S1R = S1S0 x S0R" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR' in strategy_name_list + assert "S01R = S01R x RR" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('others') - output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("others") + output_sharding_spec = strategy.get_sharding_spec_by_name("conv2d") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0] @@ -290,7 +293,7 @@ def check_conv_function_handler(rank, world_size, port, bias): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist # We temporarily ban the bias option before doing bias add # before all reduce communication may encounter correctness issue. @@ -300,7 +303,7 @@ def test_conv_module_handler(bias=False): spawn(check_conv_module_handler, 4, bias=bias) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist # We temporarily ban the bias option before doing bias add # before all reduce communication may encounter correctness issue. @@ -310,6 +313,6 @@ def test_conv_function_handler(bias=False): spawn(check_conv_function_handler, 4, bias=bias) -if __name__ == '__main__': +if __name__ == "__main__": test_conv_module_handler() test_conv_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py index 64f56ba98e2b9a2725f2a1bd1a8b3a9a18c4d0c2..ce2ae4248fce6f8f7ec364dda1e47c74d8704b6c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py @@ -12,7 +12,6 @@ from colossalai.testing import clear_cache_before_run, run_on_environment_flag class ReshapeModel(nn.Module): - def __init__(self): super().__init__() @@ -22,7 +21,7 @@ class ReshapeModel(nn.Module): return reshape_node -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_reshape_handler(): model = ReshapeModel() @@ -34,8 +33,8 @@ def test_reshape_handler(): # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) # return view meta_args = { - "input": torch.rand(4, 4, 64, 64).to('meta'), - "other": torch.rand(16, 4, 3, 3).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), + "other": torch.rand(16, 4, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -50,14 +49,14 @@ def test_reshape_handler(): conv_strategies_vector = StrategiesVector(conv_mod_node) # build handler - conv_handler = ConvFunctionHandler(node=conv_mod_node, - device_mesh=device_mesh, - strategies_vector=conv_strategies_vector) + conv_handler = ConvFunctionHandler( + node=conv_mod_node, device_mesh=device_mesh, strategies_vector=conv_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) - reshape_handler = DefaultReshapeHandler(node=reshape_node, - device_mesh=device_mesh, - strategies_vector=reshape_strategies_vector) + setattr(conv_mod_node, "strategies_vector", conv_strategies_vector) + reshape_handler = DefaultReshapeHandler( + node=reshape_node, device_mesh=device_mesh, strategies_vector=reshape_strategies_vector + ) reshape_handler.register_strategy(compute_resharding_cost=False) @@ -69,20 +68,20 @@ def test_reshape_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "conv2d" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].name == "conv2d" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 16, 62, 62]) - assert mapping['output'].name == "view" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([2, 123008]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "view" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([2, 123008]) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(reshape_strategies_vector) == len(conv_strategies_vector) -if __name__ == '__main__': +if __name__ == "__main__": test_reshape_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py index 4fa0313b1cb5da80c53ace4be0cae5012ea68f0e..9ac6ba95da480c2423bcbf5ec50fa82ca46d4e5c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py @@ -22,7 +22,6 @@ EMBEDDING_DIMS = 32 class EmbeddingModule(nn.Module): - def __init__(self, num_embeddings, embedding_dims): super().__init__() self.embedding = nn.Embedding(num_embeddings, embedding_dims) @@ -34,7 +33,7 @@ class EmbeddingModule(nn.Module): def check_embedding_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = EmbeddingModule(num_embeddings=NUM_EMBEDDINGS, embedding_dims=EMBEDDING_DIMS).cuda() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -51,15 +50,17 @@ def check_embedding_module_handler(rank, world_size, port): node_index = 1 # total number of embedding strategies strategy_number = 19 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta')} + meta_args = {"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -78,60 +79,60 @@ def check_embedding_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" + assert mapping["input"].name == "input_1" # assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([1024]) + assert mapping["input"].data.shape == torch.Size([4, 16, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([1024]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['output'].name == "embedding" - assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + assert mapping["output"].name == "embedding" + assert mapping["output"].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # RR = RR x RR - assert 'RR = R x RR' in strategy_name_list + assert "RR = R x RR" in strategy_name_list # SR = SR x RR - assert 'S0R = S0 x RR_0' in strategy_name_list - assert 'S0R = S0 x RR_1' in strategy_name_list - assert 'S0R = S0 x RR_2' in strategy_name_list - assert 'S1R = S1 x RR_0' in strategy_name_list - assert 'S1R = S1 x RR_1' in strategy_name_list - assert 'S1R = S1 x RR_2' in strategy_name_list + assert "S0R = S0 x RR_0" in strategy_name_list + assert "S0R = S0 x RR_1" in strategy_name_list + assert "S0R = S0 x RR_2" in strategy_name_list + assert "S1R = S1 x RR_0" in strategy_name_list + assert "S1R = S1 x RR_1" in strategy_name_list + assert "S1R = S1 x RR_2" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0 x RS1_0' in strategy_name_list - assert 'S0S1 = S0 x RS1_1' in strategy_name_list - assert 'S0S1 = S0 x RS1_2' in strategy_name_list - assert 'S1S0 = S1 x RS0_0' in strategy_name_list - assert 'S1S0 = S1 x RS0_1' in strategy_name_list - assert 'S1S0 = S1 x RS0_2' in strategy_name_list + assert "S0S1 = S0 x RS1_0" in strategy_name_list + assert "S0S1 = S0 x RS1_1" in strategy_name_list + assert "S0S1 = S0 x RS1_2" in strategy_name_list + assert "S1S0 = S1 x RS0_0" in strategy_name_list + assert "S1S0 = S1 x RS0_1" in strategy_name_list + assert "S1S0 = S1 x RS0_2" in strategy_name_list # RS= RR x RS - assert 'RS0 = R x RS0' in strategy_name_list - assert 'RS1 = R x RS1' in strategy_name_list + assert "RS0 = R x RS0" in strategy_name_list + assert "RS1 = R x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01 x RR_0' in strategy_name_list - assert 'S01R = S01 x RR_1' in strategy_name_list - assert 'S01R = S01 x RR_2' in strategy_name_list + assert "S01R = S01 x RR_0" in strategy_name_list + assert "S01R = S01 x RR_1" in strategy_name_list + assert "S01R = S01 x RR_2" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = R x RS01' in strategy_name_list + assert "RS01 = R x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("embedding") # make sure the sharding matches across different operation data assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] @@ -139,7 +140,6 @@ def check_embedding_module_handler(rank, world_size, port): class EmbeddingFunction(nn.Module): - def __init__(self): super().__init__() @@ -150,7 +150,7 @@ class EmbeddingFunction(nn.Module): def check_embedding_function_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = EmbeddingFunction().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -159,18 +159,20 @@ def check_embedding_function_handler(rank, world_size, port): input = input.to(torch.int64).cuda() others = torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).cuda() input_args = [input, others] - meta_arg_names = ['input', 'others'] + meta_arg_names = ["input", "others"] input_kwargs = {} # total number of embedding strategies strategy_number = 19 node_index = 2 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - input_kwargs=input_kwargs) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs, + ) tracer = ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -178,8 +180,8 @@ def check_embedding_function_handler(rank, world_size, port): # %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False}) # return embedding meta_args = { - "input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta'), - "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta') + "input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to("meta"), + "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -189,9 +191,9 @@ def check_embedding_function_handler(rank, world_size, port): strategies_vector = StrategiesVector(embedding_node) # build handler - handler = EmbeddingFunctionHandler(node=embedding_node, - device_mesh=device_mesh, - strategies_vector=strategies_vector) + handler = EmbeddingFunctionHandler( + node=embedding_node, device_mesh=device_mesh, strategies_vector=strategies_vector + ) # check operation data mapping mapping = handler.get_operation_data_mapping() @@ -202,82 +204,82 @@ def check_embedding_function_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([1024]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 16, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([1024]) - assert mapping['other'].name == "others" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping["other"].name == "others" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['output'].name == "embedding" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + assert mapping["output"].name == "embedding" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # RR = RR x RR - assert 'RR = R x RR' in strategy_name_list + assert "RR = R x RR" in strategy_name_list # SR = SR x RR - assert 'S0R = S0 x RR_0' in strategy_name_list - assert 'S0R = S0 x RR_1' in strategy_name_list - assert 'S0R = S0 x RR_2' in strategy_name_list - assert 'S1R = S1 x RR_0' in strategy_name_list - assert 'S1R = S1 x RR_1' in strategy_name_list - assert 'S1R = S1 x RR_2' in strategy_name_list + assert "S0R = S0 x RR_0" in strategy_name_list + assert "S0R = S0 x RR_1" in strategy_name_list + assert "S0R = S0 x RR_2" in strategy_name_list + assert "S1R = S1 x RR_0" in strategy_name_list + assert "S1R = S1 x RR_1" in strategy_name_list + assert "S1R = S1 x RR_2" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0 x RS1_0' in strategy_name_list - assert 'S0S1 = S0 x RS1_1' in strategy_name_list - assert 'S0S1 = S0 x RS1_2' in strategy_name_list - assert 'S1S0 = S1 x RS0_0' in strategy_name_list - assert 'S1S0 = S1 x RS0_1' in strategy_name_list - assert 'S1S0 = S1 x RS0_2' in strategy_name_list + assert "S0S1 = S0 x RS1_0" in strategy_name_list + assert "S0S1 = S0 x RS1_1" in strategy_name_list + assert "S0S1 = S0 x RS1_2" in strategy_name_list + assert "S1S0 = S1 x RS0_0" in strategy_name_list + assert "S1S0 = S1 x RS0_1" in strategy_name_list + assert "S1S0 = S1 x RS0_2" in strategy_name_list # RS= RR x RS - assert 'RS0 = R x RS0' in strategy_name_list - assert 'RS1 = R x RS1' in strategy_name_list + assert "RS0 = R x RS0" in strategy_name_list + assert "RS1 = R x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01 x RR_0' in strategy_name_list - assert 'S01R = S01 x RR_1' in strategy_name_list - assert 'S01R = S01 x RR_2' in strategy_name_list + assert "S01R = S01 x RR_0" in strategy_name_list + assert "S01R = S01 x RR_1" in strategy_name_list + assert "S01R = S01 x RR_2" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = R x RS01' in strategy_name_list + assert "RS01 = R x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('others') - output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("others") + output_sharding_spec = strategy.get_sharding_spec_by_name("embedding") # make sure the sharding matches across different operation data assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_embedding_module_handler(): spawn(check_embedding_module_handler, 4) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_embedding_function_handler(): spawn(check_embedding_function_handler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_embedding_module_handler() test_embedding_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py index a089df743ec0156f5905988f36f37b5afdb387a5..2c464f64d8cab165ba6329e71ee9a02a292dd3c6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py @@ -12,7 +12,6 @@ from colossalai.testing import clear_cache_before_run class GetattrModel(nn.Module): - def __init__(self): super().__init__() self.conv = nn.Conv2d(4, 16, 3, padding=1, bias=False) @@ -22,7 +21,7 @@ class GetattrModel(nn.Module): return weight -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") @clear_cache_before_run() def test_getattr_handler(): model = GetattrModel() @@ -31,7 +30,7 @@ def test_getattr_handler(): # %input_1 : torch.Tensor [#users=0] = placeholder[target=input] # %conv_weight : [#users=1] = get_attr[target=conv.weight] # return conv_weight - meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -42,9 +41,9 @@ def test_getattr_handler(): getattr_strategies_vector = StrategiesVector(getattr_node) # build handler - getattr_handler = GetattrHandler(node=getattr_node, - device_mesh=device_mesh, - strategies_vector=getattr_strategies_vector) + getattr_handler = GetattrHandler( + node=getattr_node, device_mesh=device_mesh, strategies_vector=getattr_strategies_vector + ) getattr_handler.register_strategy(compute_resharding_cost=False) @@ -56,20 +55,20 @@ def test_getattr_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['output'].name == "conv_weight" - assert mapping['output'].data.shape == torch.Size((16, 4, 3, 3)) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "conv_weight" + assert mapping["output"].data.shape == torch.Size((16, 4, 3, 3)) + assert mapping["output"].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in getattr_handler.strategies_vector] - assert 'get_attr [S0, S1, R, R]' in strategy_name_list - assert 'get_attr [S1, S0, R, R]' in strategy_name_list - assert 'get_attr [S01, R, R, R]' in strategy_name_list - assert 'get_attr [R, S01, R, R]' in strategy_name_list - assert 'get_attr [S0, R, R, R]' in strategy_name_list - assert 'get_attr [R, S0, R, R]' in strategy_name_list - assert 'get_attr [S1, R, R, R]' in strategy_name_list - assert 'get_attr [R, S1, R, R]' in strategy_name_list - assert 'get_attr [R, R, R, R]' in strategy_name_list + assert "get_attr [S0, S1, R, R]" in strategy_name_list + assert "get_attr [S1, S0, R, R]" in strategy_name_list + assert "get_attr [S01, R, R, R]" in strategy_name_list + assert "get_attr [R, S01, R, R]" in strategy_name_list + assert "get_attr [S0, R, R, R]" in strategy_name_list + assert "get_attr [R, S0, R, R]" in strategy_name_list + assert "get_attr [S1, R, R, R]" in strategy_name_list + assert "get_attr [R, S1, R, R]" in strategy_name_list + assert "get_attr [R, R, R, R]" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_getattr_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py index a2e0968b18bb2dc4f2d3cbe27bc5e88bae20f69e..cf802a22803415352c2ab7339b046812306ceee5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -1,5 +1,3 @@ -from functools import partial - import pytest import torch import torch.nn as nn @@ -21,7 +19,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n class GetItemFromTensorModel(nn.Module): - def __init__(self, getitem_index): super().__init__() self.getitem_index = getitem_index @@ -34,12 +31,12 @@ class GetItemFromTensorModel(nn.Module): def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = GetItemFromTensorModel(getitem_index=getitem_index) - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -49,18 +46,20 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) meta_args = { - "input": torch.rand(8, 16, 64, 32).to('meta'), - "other": torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -72,14 +71,14 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): linear_strategies_vector = StrategiesVector(linear_mod_node) # build handler - linear_handler = LinearFunctionHandler(node=linear_mod_node, - device_mesh=device_mesh, - strategies_vector=linear_strategies_vector) + linear_handler = LinearFunctionHandler( + node=linear_mod_node, device_mesh=device_mesh, strategies_vector=linear_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(linear_mod_node, 'strategies_vector', linear_strategies_vector) - getitem_handler = GetItemHandler(node=getitem_mod_node, - device_mesh=device_mesh, - strategies_vector=getitem_strategies_vector) + setattr(linear_mod_node, "strategies_vector", linear_strategies_vector) + getitem_handler = GetItemHandler( + node=getitem_mod_node, device_mesh=device_mesh, strategies_vector=getitem_strategies_vector + ) getitem_handler.register_strategy(compute_resharding_cost=False) # check operation data mapping @@ -94,17 +93,16 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): assert len(getitem_strategies_vector) == len(linear_strategies_vector) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() # @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))]) -@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) +@parameterize("getitem_index", [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) def test_getitem_from_tensor_handler(getitem_index): spawn(check_getitem_from_tensor_handler, 4) class GetItemFromTupleModel(nn.Module): - def __init__(self): super().__init__() @@ -114,7 +112,7 @@ class GetItemFromTupleModel(nn.Module): return x -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_getitem_from_tuple_handler(): model = GetItemFromTupleModel() @@ -125,7 +123,7 @@ def test_getitem_from_tuple_handler(): # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {}) # return getitem meta_args = { - "input": torch.rand(4, 4, 64, 64).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -146,20 +144,20 @@ def test_getitem_from_tuple_handler(): node=input_node, device_mesh=device_mesh, strategies_vector=input_strategies_vector, - placeholder_option='replicated', + placeholder_option="replicated", ) input_handler.register_strategy(compute_resharding_cost=False) - setattr(input_node, 'strategies_vector', input_strategies_vector) - split_handler = DefaultReshapeHandler(node=split_node, - device_mesh=device_mesh, - strategies_vector=split_strategies_vector) + setattr(input_node, "strategies_vector", input_strategies_vector) + split_handler = DefaultReshapeHandler( + node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector + ) split_handler.register_strategy(compute_resharding_cost=False) - setattr(split_node, 'strategies_vector', split_strategies_vector) - getitem_handler = GetItemHandler(node=getitem_node, - device_mesh=device_mesh, - strategies_vector=getitem_strategies_vector) + setattr(split_node, "strategies_vector", split_strategies_vector) + getitem_handler = GetItemHandler( + node=getitem_node, device_mesh=device_mesh, strategies_vector=getitem_strategies_vector + ) getitem_handler.register_strategy(compute_resharding_cost=False) - setattr(getitem_node, 'strategies_vector', getitem_strategies_vector) + setattr(getitem_node, "strategies_vector", getitem_strategies_vector) # check operation data mapping mapping = getitem_handler.get_operation_data_mapping() @@ -169,23 +167,23 @@ def test_getitem_from_tuple_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "split" - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64])) + assert mapping["input"].name == "split" + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64])) - assert mapping['index'].name == "index" - assert isinstance(mapping['index'].data, int) - assert mapping['index'].type == OperationDataType.ARG + assert mapping["index"].name == "index" + assert isinstance(mapping["index"].data, int) + assert mapping["index"].type == OperationDataType.ARG - assert mapping['output'].name == "getitem" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([2, 4, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "getitem" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([2, 4, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(getitem_strategies_vector) == len(split_strategies_vector) -if __name__ == '__main__': +if __name__ == "__main__": test_getitem_from_tensor_handler() test_getitem_from_tuple_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py index ad72c2026b9aea03b83bfdf0db41c915a433618a..59a66bc6a5d6d990c39546196a6aa73e976ef885 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py @@ -17,7 +17,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n def check_ln_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.LayerNorm(16)).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -30,19 +30,21 @@ def check_ln_module_handler(rank, world_size, port): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['input'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["input"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # return _0 - meta_args = {"input": torch.rand(4, 16).to('meta')} + meta_args = {"input": torch.rand(4, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -62,45 +64,45 @@ def check_ln_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size([4, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size([4, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 16]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16]) - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type == OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "_0" - assert mapping['output'].data.shape == torch.Size([4, 16]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "_0" + assert mapping["output"].data.shape == torch.Size([4, 16]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SR = SR x R - assert '[S0, R] = [S0, R] x [R]' in strategy_name_list - assert '[S1, R] = [S1, R] x [R]' in strategy_name_list + assert "[S0, R] = [S0, R] x [R]" in strategy_name_list + assert "[S1, R] = [S1, R] x [R]" in strategy_name_list # RR = RR x R - assert 'RR = RR x R' in strategy_name_list + assert "RR = RR x R" in strategy_name_list # S01R = S01R x R - assert '[S01, R] = [S01, R] x [R]' in strategy_name_list + assert "[S01, R] = [S01, R] x [R]" in strategy_name_list -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_ln_module_handler(): spawn(check_ln_module_handler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_ln_module_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index ec695cd8f7b9e26704f28da68862844b6c75d4a4..da88b735f7c14c65f008eaa43f5bd119eeee3473 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -23,7 +23,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n def check_linear_module_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -39,13 +39,15 @@ def check_linear_module_handler(rank, world_size, port, bias, input_shape): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['input'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["input"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) meta_args = {"input": torch.rand(input_shape).cuda()} @@ -68,86 +70,86 @@ def check_linear_module_handler(rank, world_size, port, bias, input_shape): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size(input_shape) - assert mapping['input'].type == OperationDataType.ARG - input_logical_shape = mapping['input'].data.view(-1, 16).shape - assert mapping['input'].logical_shape == input_logical_shape + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size(input_shape) + assert mapping["input"].type == OperationDataType.ARG + input_logical_shape = mapping["input"].data.view(-1, 16).shape + assert mapping["input"].logical_shape == input_logical_shape - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16, 32]) if bias: - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([32]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([32]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([32]) + assert mapping["bias"].type == OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([32]) - assert mapping['output'].name == "_0" + assert mapping["output"].name == "_0" output_shape = input_shape[:-1] + (32,) - assert mapping['output'].data.shape == torch.Size(output_shape) - assert mapping['output'].type == OperationDataType.OUTPUT - output_logical_shape = mapping['output'].data.view(-1, 32).shape - assert mapping['output'].logical_shape == torch.Size(output_logical_shape) + assert mapping["output"].data.shape == torch.Size(output_shape) + assert mapping["output"].type == OperationDataType.OUTPUT + output_logical_shape = mapping["output"].data.view(-1, 32).shape + assert mapping["output"].logical_shape == torch.Size(output_logical_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # First dimension cannot be shard if input shape is (1, 4, 4, 16) if input_shape != (1, 4, 4, 16): - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S01R = S01R x RR_0' in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("_0") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -159,7 +161,6 @@ def check_linear_module_handler(rank, world_size, port, bias, input_shape): class LinearModel(nn.Module): - def __init__(self): super().__init__() @@ -170,7 +171,7 @@ class LinearModel(nn.Module): def check_linear_function_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModel().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -188,16 +189,18 @@ def check_linear_function_handler(rank, world_size, port, bias, input_shape): # construct input args input_args = [input, other] # construct meta arg names - meta_arg_names = ['input', 'others'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["input", "others"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'input': torch.rand(input_shape).to('meta'), 'others': torch.rand(32, 16).to('meta')} + meta_args = {"input": torch.rand(input_shape).to("meta"), "others": torch.rand(32, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -214,86 +217,86 @@ def check_linear_function_handler(rank, world_size, port, bias, input_shape): # # check operation data mapping mapping = handler.get_operation_data_mapping() - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size(input_shape) - assert mapping['input'].type == OperationDataType.ARG - input_logical_shape = mapping['input'].data.view(-1, 16).shape - assert mapping['input'].logical_shape == torch.Size(input_logical_shape) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size(input_shape) + assert mapping["input"].type == OperationDataType.ARG + input_logical_shape = mapping["input"].data.view(-1, 16).shape + assert mapping["input"].logical_shape == torch.Size(input_logical_shape) - assert mapping['other'].name == "others" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "others" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([16, 32]) if bias: - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([32]) - assert mapping['bias'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([32]) + assert mapping["bias"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([16, 32]) - assert mapping['output'].name == "linear" + assert mapping["output"].name == "linear" output_shape = input_shape[:-1] + (32,) - assert mapping['output'].data.shape == torch.Size(output_shape) - assert mapping['output'].type == OperationDataType.OUTPUT - output_logical_shape = mapping['output'].data.view(-1, 32).shape - assert mapping['output'].logical_shape == torch.Size(output_logical_shape) + assert mapping["output"].data.shape == torch.Size(output_shape) + assert mapping["output"].type == OperationDataType.OUTPUT + output_logical_shape = mapping["output"].data.view(-1, 32).shape + assert mapping["output"].logical_shape == torch.Size(output_logical_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # First dimension cannot be shard if input shape is (1, 4, 4, 16) if input_shape != (1, 4, 4, 16): - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S01R = S01R x RR_0' in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('others') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("others") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -304,8 +307,8 @@ def check_linear_function_handler(rank, world_size, port, bias, input_shape): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("input_shape", [(1, 4, 4, 16), (4, 4, 4, 16)]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(input_shape, bias=False): @@ -323,5 +326,5 @@ def test_linear_handler(input_shape, bias=False): ) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py index 938acd3d1eeacef457cc887ef92cb8624884a7e3..5fb4985e2f3cd91c6d98eb0aba6e06b43cb79cff 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py @@ -22,31 +22,31 @@ from colossalai.testing.utils import clear_cache_before_run, parameterize class MatMulModule(nn.Module): - def forward(self, x1, x2): return torch.matmul(x1, x2) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() @parameterize( - 'tensor_shapes', + "tensor_shapes", [ - [[8], [8]], # dot product - [[4, 8], [8]], # mat-vec product - [[4, 8], [8, 16]], # mat-mat product - [[8], [8, 16]], # mat-mat product - [[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting - [[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting - [[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting - [[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting - [[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting - ]) + [[8], [8]], # dot product + [[4, 8], [8]], # mat-vec product + [[4, 8], [8, 16]], # mat-mat product + [[8], [8, 16]], # mat-mat product + [[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting + [[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting + ], +) def test_matmul_node_handler(tensor_shapes): input_shape, other_shape = tensor_shapes @@ -61,7 +61,7 @@ def test_matmul_node_handler(tensor_shapes): model = MatMulModule() tracer = ColoTracer(bias_addition_split=True) - meta_args = {"x1": x1.to('meta'), 'x2': x2.to('meta')} + meta_args = {"x1": x1.to("meta"), "x2": x2.to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -92,30 +92,31 @@ def test_matmul_node_handler(tensor_shapes): logical_input_shape = [1] + input_shape elif matmul_type == MatMulType.BMM: logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape( - input_shape, other_shape, handler.transforms) + input_shape, other_shape, handler.transforms + ) else: logical_input_shape = input_shape # check input operation data - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size(input_shape) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size(logical_input_shape) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size(input_shape) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size(logical_input_shape) # check other operation data - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size(other_shape) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size(logical_other_shape) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size(other_shape) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size(logical_other_shape) # check output - assert mapping['output'].name == "matmul" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size(output_shape) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size(logical_output_shape) + assert mapping["output"].name == "matmul" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size(output_shape) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size(logical_output_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -126,9 +127,9 @@ def test_matmul_node_handler(tensor_shapes): for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('matmul') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("matmul") if matmul_type == MatMulType.DOT: # dot product will produce a scaler # results should fulfill: @@ -171,5 +172,5 @@ def test_matmul_node_handler(tensor_shapes): assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1] -if __name__ == '__main__': +if __name__ == "__main__": test_matmul_node_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py index 6bff9f9648e2aca20c5c480726bfeb7038041712..6b7ac766ff184c1af9362baf9c56413c6cb7ed8c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py @@ -10,16 +10,16 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.testing import clear_cache_before_run, run_on_environment_flag -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_norm_pool_handler(): - model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) + model = nn.Sequential(nn.MaxPool2d(4, padding=1).to("meta")) tracer = ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # return _0 - meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -41,21 +41,21 @@ def test_norm_pool_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4, 64, 64]) - assert mapping['output'].name == "_0" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "_0" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4, 16, 16]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] assert len(strategy_name_list) == 9 -if __name__ == '__main__': +if __name__ == "__main__": test_norm_pool_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py index 5259455d2179cae56865bba73718d8e818a39c41..4da986181f8969c94ede10bb08e02a8bdcbc56f4 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py @@ -12,7 +12,6 @@ from colossalai.testing import clear_cache_before_run, parameterize class OutputModel(nn.Module): - def __init__(self): super().__init__() @@ -21,8 +20,8 @@ class OutputModel(nn.Module): return x, y -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') -@parameterize('output_option', ['distributed', 'replicated']) +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") +@parameterize("output_option", ["distributed", "replicated"]) @clear_cache_before_run() def test_output_handler(output_option): model = OutputModel() @@ -31,7 +30,7 @@ def test_output_handler(output_option): # %x : torch.Tensor [#users=2] = placeholder[target=x] # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) # return (x, mul) - meta_args = {'x': torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"x": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -43,28 +42,30 @@ def test_output_handler(output_option): output_strategies_vector = StrategiesVector(output_node) # build handler - otuput_handler = OutputHandler(node=output_node, - device_mesh=device_mesh, - strategies_vector=output_strategies_vector, - output_option=output_option) + output_handler = OutputHandler( + node=output_node, + device_mesh=device_mesh, + strategies_vector=output_strategies_vector, + output_option=output_option, + ) - otuput_handler.register_strategy(compute_resharding_cost=False) + output_handler.register_strategy(compute_resharding_cost=False) # check operation data mapping - mapping = otuput_handler.get_operation_data_mapping() + mapping = output_handler.get_operation_data_mapping() for name, op_data in mapping.items(): op_data: OperationData # make sure they have valid values assert op_data.data is not None - assert mapping['output'].name == "output" - assert mapping['output'].type == OperationDataType.OUTPUT - strategy_name_list = [val.name for val in otuput_handler.strategies_vector] - if output_option == 'distributed': + assert mapping["output"].name == "output" + assert mapping["output"].type == OperationDataType.OUTPUT + strategy_name_list = [val.name for val in output_handler.strategies_vector] + if output_option == "distributed": assert "Distributed Output" in strategy_name_list else: assert "Replica Output" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_output_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py index f071cd120fb719e832b691a7de37d811a4235708..958dc288fa16efbfc20dd91d63ce24ebcf1ed1c9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -1,5 +1,3 @@ -from functools import partial - import pytest import torch import torch.nn as nn @@ -20,7 +18,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n class ConvReshapeModel(nn.Module): - def __init__(self, reshape_dims, call_function): super().__init__() self.reshape_dims = reshape_dims @@ -37,7 +34,6 @@ class ConvReshapeModel(nn.Module): class LinearReshapeModel(nn.Module): - def __init__(self, reshape_dims, call_function): super().__init__() self.reshape_dims = reshape_dims @@ -55,23 +51,23 @@ class LinearReshapeModel(nn.Module): def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") if call_function == torch.permute: reshape_dims = reshape_dims[0] elif call_function == torch.transpose: reshape_dims = reshape_dims[1] model = model_cls(reshape_dims, call_function).cuda() - if model_cls.__name__ == 'ConvReshapeModel': - input = torch.rand(8, 8, 66, 66).to('cuda') - other = torch.rand(16, 8, 3, 3).to('cuda') + if model_cls.__name__ == "ConvReshapeModel": + input = torch.rand(8, 8, 66, 66).to("cuda") + other = torch.rand(16, 8, 3, 3).to("cuda") # index of conv node in computation graph node_index = 2 # total number of conv strategies strategy_number = 16 - if model_cls.__name__ == 'LinearReshapeModel': - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + if model_cls.__name__ == "LinearReshapeModel": + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -81,15 +77,17 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) - if model_cls.__name__ == 'ConvReshapeModel': + if model_cls.__name__ == "ConvReshapeModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -97,12 +95,12 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode # %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {}) # return permute meta_args = { - 'input': torch.rand(8, 8, 66, 66).to('meta'), - 'other': torch.rand(16, 8, 3, 3).to('meta'), + "input": torch.rand(8, 8, 66, 66).to("meta"), + "other": torch.rand(16, 8, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) - if model_cls.__name__ == 'LinearReshapeModel': + if model_cls.__name__ == "LinearReshapeModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -110,8 +108,8 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode # %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) # return permute meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), - 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -124,30 +122,29 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode previous_strategies_vector = StrategiesVector(previous_mod_node) # build handler - if model_cls.__name__ == 'ConvReshapeModel': - - conv_handler = ConvFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + if model_cls.__name__ == "ConvReshapeModel": + conv_handler = ConvFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - if model_cls.__name__ == 'LinearReshapeModel': + if model_cls.__name__ == "LinearReshapeModel": assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) if call_function == torch.permute: - reshape_handler = PermuteHandler(node=reshape_node, - device_mesh=device_mesh, - strategies_vector=view_strategies_vector) + reshape_handler = PermuteHandler( + node=reshape_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector + ) else: - reshape_handler = TransposeHandler(node=reshape_node, - device_mesh=device_mesh, - strategies_vector=view_strategies_vector) + reshape_handler = TransposeHandler( + node=reshape_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector + ) reshape_handler.register_strategy(compute_resharding_cost=False) @@ -159,25 +156,25 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode # make sure they have valid values assert op_data.data is not None - if model_cls.__name__ == 'ConvReshapeModel': - assert mapping['input'].name == "conv2d" + if model_cls.__name__ == "ConvReshapeModel": + assert mapping["input"].name == "conv2d" else: - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) if call_function == torch.permute: - assert mapping['output'].name == "permute" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "permute" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape + assert mapping["output"].type == OperationDataType.OUTPUT else: - assert mapping['output'].name == "transpose" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "transpose" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(view_strategies_vector) == len(previous_strategies_vector) @@ -185,146 +182,144 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode if rank == 0: for name in strategy_name_list: print(name) - if model_cls.__name__ == 'ConvReshapeModel': - + if model_cls.__name__ == "ConvReshapeModel": if reshape_dims in ((0, 2, 1, 3), (1, 2)): - assert '[S0, S1, R, R] -> [S0, R, S1, R]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [S1, R, S0, R]_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + assert "[S0, S1, R, R] -> [S0, R, S1, R]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [S1, R, S0, R]_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_15" in strategy_name_list if reshape_dims == (2, 0, 1, 3): - assert '[S0, S1, R, R] -> [R, S0, S1, R]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [R, S1, S0, R]_1' in strategy_name_list - assert '[S0, R, R, R] -> [R, S0, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [R, S1, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [R, S0, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [R, S1, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [R, S01, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + assert "[S0, S1, R, R] -> [R, S0, S1, R]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [R, S1, S0, R]_1" in strategy_name_list + assert "[S0, R, R, R] -> [R, S0, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [R, S1, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [R, S0, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [R, S1, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [R, S01, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_15" in strategy_name_list if reshape_dims == (1, 3): - assert '[S0, S1, R, R] -> [S0, R, R, S1]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [S1, R, R, S0]_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, R, S1]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, R, S0]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, R, S01]_15' in strategy_name_list - - if model_cls.__name__ == 'LinearReshapeModel': - + assert "[S0, S1, R, R] -> [S0, R, R, S1]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [S1, R, R, S0]_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, R, S1]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, R, S0]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, R, S0]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, R, S01]_15" in strategy_name_list + + if model_cls.__name__ == "LinearReshapeModel": if reshape_dims == ((0, 2, 1, 3), (1, 2)): - assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S0, R, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S1, R, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, S0, R, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, S1, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, S01, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, R, S0, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, S0, R, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, R, S1, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, S1, R, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, S0, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, S1, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, S01, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list if reshape_dims == (2, 0, 1, 3): - assert '[S0, R, R, S1] -> [R, S0, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [S0, R, R, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [R, S1, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [S1, R, R, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [R, S0, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [S0, R, R, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [R, S1, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [S1, R, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R] -> [R, S01, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [S01, R, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + assert "[S0, R, R, S1] -> [R, S0, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, R, S0, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [S0, R, R, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [R, S1, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, R, S1, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [S1, R, R, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [R, S0, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [S0, R, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [R, S1, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [S1, R, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R] -> [R, S01, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [S01, R, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list if reshape_dims == (1, 3): - assert '[S0, R, R, S1] -> [S0, S1, R, R]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S1, R, S0]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S1, S0, R]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, S0, R, R]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S0, R, S1]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S0, S1, R]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, R, S0]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, R, S1]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1, R, R]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0, R, R]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0, R, R]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1, R, R]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, R, S01]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, S01, R, R]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1] -> [S0, S1, R, R]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S1, R, S0]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, S1, S0, R]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, S0, R, R]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S0, R, S1]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, S0, S1, R]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, R, S0]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, R, S1]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1, R, R]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0, R, R]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0, R, R]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1, R, R]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, R, S01]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, S01, R, R]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('call_function', [torch.permute, torch.transpose]) -@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) -@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel]) +@parameterize("call_function", [torch.permute, torch.transpose]) +@parameterize("reshape_dims", [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) +@parameterize("model_cls", [ConvReshapeModel, LinearReshapeModel]) def test_view_handler(call_function, reshape_dims, model_cls): spawn( check_view_handler, @@ -335,5 +330,5 @@ def test_view_handler(call_function, reshape_dims, model_cls): ) -if __name__ == '__main__': +if __name__ == "__main__": test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py index 6d02b0e0ba7407147fc603cbec9910ff79e58743..60c090429c6c0fa483047b7baa6bd802dd31a255 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py @@ -12,7 +12,6 @@ from colossalai.testing import clear_cache_before_run, parameterize class PlaceholderModel(nn.Module): - def __init__(self): super().__init__() @@ -20,8 +19,8 @@ class PlaceholderModel(nn.Module): return input -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') -@parameterize('placeholder_option', ['distributed', 'replicated']) +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") +@parameterize("placeholder_option", ["distributed", "replicated"]) @clear_cache_before_run() def test_placeholder_handler(placeholder_option): model = PlaceholderModel() @@ -30,7 +29,7 @@ def test_placeholder_handler(placeholder_option): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # return input_1 meta_args = { - "input": torch.rand(4, 4, 64, 64).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -42,10 +41,12 @@ def test_placeholder_handler(placeholder_option): placeholder_node = list(graph.nodes)[0] placeholder_strategies_vector = StrategiesVector(placeholder_node) # build handler - placeholder_handler = PlaceholderHandler(node=placeholder_node, - device_mesh=device_mesh, - strategies_vector=placeholder_strategies_vector, - placeholder_option=placeholder_option) + placeholder_handler = PlaceholderHandler( + node=placeholder_node, + device_mesh=device_mesh, + strategies_vector=placeholder_strategies_vector, + placeholder_option=placeholder_option, + ) placeholder_handler.register_strategy(compute_resharding_cost=False) @@ -53,28 +54,28 @@ def test_placeholder_handler(placeholder_option): mapping = placeholder_handler.get_operation_data_mapping() strategy = placeholder_strategies_vector[0] - strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping['output'].name) + strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping["output"].name) - if placeholder_option == 'distributed': - assert str(strategy_sharding_spec.sharding_sequence) == '[S01, R, R, R]' + if placeholder_option == "distributed": + assert str(strategy_sharding_spec.sharding_sequence) == "[S01, R, R, R]" else: - assert str(strategy_sharding_spec.sharding_sequence) == '[R, R, R, R]' + assert str(strategy_sharding_spec.sharding_sequence) == "[R, R, R, R]" for name, op_data in mapping.items(): op_data: OperationData # make sure they have valid values assert op_data.data is not None - assert mapping['output'].name == "input_1" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64)) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "input_1" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size((4, 4, 64, 64)) + assert mapping["output"].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in placeholder_handler.strategies_vector] - if placeholder_option == 'replicated': + if placeholder_option == "replicated": assert "Replica Placeholder" in strategy_name_list else: assert "Distributed Placeholder" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_placeholder_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py index 14c364c45fc437c08a69948eea63802184fc04e9..6836a882242f6689174f74ac02769b91a23a013d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py @@ -12,7 +12,6 @@ from colossalai.testing import clear_cache_before_run, run_on_environment_flag class LinearModel(nn.Module): - def __init__(self): super().__init__() @@ -28,7 +27,7 @@ def check_shard_option(shard_option): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'input': torch.rand(4, 4, 4, 16).to('meta'), 'others': torch.rand(32, 16).to('meta')} + meta_args = {"input": torch.rand(4, 4, 4, 16).to("meta"), "others": torch.rand(32, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -36,77 +35,76 @@ def check_shard_option(shard_option): strategies_vector = StrategiesVector(linear_func_node) # build handler - handler = LinearFunctionHandler(node=linear_func_node, - device_mesh=device_mesh, - strategies_vector=strategies_vector, - shard_option=shard_option) + handler = LinearFunctionHandler( + node=linear_func_node, device_mesh=device_mesh, strategies_vector=strategies_vector, shard_option=shard_option + ) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] if shard_option == ShardOption.SHARD_LAST_AXIS: # RR = RS x SR - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list return # SS = SR x RS - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list if shard_option == ShardOption.SHARD: # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list if shard_option == ShardOption.STANDARD: # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_shard_option(): # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]: @@ -114,5 +112,5 @@ def test_shard_option(): check_shard_option(shard_option) -if __name__ == '__main__': +if __name__ == "__main__": test_shard_option() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py index 75ae0416ef9876795a7966c8f4f1d685583058ce..1a99c32ebcb9ed86c40aa72a8c45ee9759ac61f0 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -17,7 +17,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n class LinearSplitModel(nn.Module): - def __init__(self, softmax_dim): super().__init__() self.softmax_dim = softmax_dim @@ -30,11 +29,11 @@ class LinearSplitModel(nn.Module): def check_split_handler(rank, world_size, port, softmax_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(softmax_dim=softmax_dim).cuda() - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -44,13 +43,15 @@ def check_split_handler(rank, world_size, port, softmax_dim, model_cls): mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -60,8 +61,8 @@ def check_split_handler(rank, world_size, port, softmax_dim, model_cls): # %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) # return split meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), - 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -75,15 +76,15 @@ def check_split_handler(rank, world_size, port, softmax_dim, model_cls): # build handler assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - softmax_handler = SoftmaxHandler(node=split_node, - device_mesh=device_mesh, - strategies_vector=split_strategies_vector) + softmax_handler = SoftmaxHandler( + node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector + ) softmax_handler.register_strategy(compute_resharding_cost=False) @@ -95,84 +96,84 @@ def check_split_handler(rank, world_size, port, softmax_dim, model_cls): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['softmax_dim'].name == "softmax_dim" - assert mapping['softmax_dim'].data == softmax_dim - assert mapping['softmax_dim'].type == OperationDataType.ARG + assert mapping["softmax_dim"].name == "softmax_dim" + assert mapping["softmax_dim"].data == softmax_dim + assert mapping["softmax_dim"].type == OperationDataType.ARG - assert mapping['output'].name == "softmax" - assert mapping['output'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "softmax" + assert mapping["output"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["output"].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(split_strategies_vector) == len(previous_strategies_vector) strategy_name_list = [strategy.name for strategy in split_strategies_vector] if softmax_dim == 0: - assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S0, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1]_13" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S1, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0]_16" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list if softmax_dim == 1: - assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1] -> [S0, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('softmax_dim', [0, 1, 2, 3]) -@parameterize('model_cls', [LinearSplitModel]) +@parameterize("softmax_dim", [0, 1, 2, 3]) +@parameterize("model_cls", [LinearSplitModel]) def test_split_handler(softmax_dim, model_cls): spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py index f860c629b0a0ca8fb4211e1f7b4efd4d5ccec71c..0318023c858ddb646a397ce29de5c9fb2df4ad06 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -17,7 +17,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n class ConvSplitModel(nn.Module): - def __init__(self, split_size, split_dim): super().__init__() self.split_size = split_size @@ -30,7 +29,6 @@ class ConvSplitModel(nn.Module): class LinearSplitModel(nn.Module): - def __init__(self, split_size, split_dim): super().__init__() self.split_size = split_size @@ -44,19 +42,19 @@ class LinearSplitModel(nn.Module): def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(split_size=split_size, split_dim=split_dim).cuda() - if model_cls.__name__ == 'ConvSplitModel': - input = torch.rand(8, 8, 66, 66).to('cuda') - other = torch.rand(16, 8, 3, 3).to('cuda') + if model_cls.__name__ == "ConvSplitModel": + input = torch.rand(8, 8, 66, 66).to("cuda") + other = torch.rand(16, 8, 3, 3).to("cuda") # index of conv node in computation graph node_index = 2 # total number of conv strategies strategy_number = 16 - if model_cls.__name__ == 'LinearSplitModel': - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + if model_cls.__name__ == "LinearSplitModel": + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -66,15 +64,17 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) - if model_cls.__name__ == 'ConvSplitModel': + if model_cls.__name__ == "ConvSplitModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -82,12 +82,12 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls # %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {}) # return split meta_args = { - 'input': torch.rand(8, 8, 66, 66).to('meta'), - 'other': torch.rand(16, 8, 3, 3).to('meta'), + "input": torch.rand(8, 8, 66, 66).to("meta"), + "other": torch.rand(16, 8, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) - if model_cls.__name__ == 'LinearSplitModel': + if model_cls.__name__ == "LinearSplitModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -95,8 +95,8 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls # %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) # return split meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), - 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -109,21 +109,20 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls previous_strategies_vector = StrategiesVector(previous_mod_node) # build handler - if model_cls.__name__ == 'ConvSplitModel': - - conv_handler = ConvFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + if model_cls.__name__ == "ConvSplitModel": + conv_handler = ConvFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - if model_cls.__name__ == 'LinearSplitModel': + if model_cls.__name__ == "LinearSplitModel": assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) split_handler = SplitHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector) @@ -137,124 +136,122 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls # make sure they have valid values assert op_data.data is not None - if model_cls.__name__ == 'ConvSplitModel': - assert mapping['input'].name == "conv2d" + if model_cls.__name__ == "ConvSplitModel": + assert mapping["input"].name == "conv2d" else: - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].name == "split" + assert mapping["output"].name == "split" split_items = torch.empty([8, 16, 64, 64]).split(split_size, split_dim) - assert mapping['output'].logical_shape == tuple([item.shape for item in split_items]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == tuple([item.shape for item in split_items]) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(split_strategies_vector) == len(previous_strategies_vector) strategy_name_list = [strategy.name for strategy in split_strategies_vector] - if model_cls.__name__ == 'ConvSplitModel': - + if model_cls.__name__ == "ConvSplitModel": if split_dim == 0: - assert '[R, S1, R, R]_0' in strategy_name_list - assert '[R, S0, R, R]_1' in strategy_name_list - assert '[R, R, R, R]_2' in strategy_name_list - assert '[R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, R]_4' in strategy_name_list - assert '[R, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R]_6' in strategy_name_list - assert '[R, S0, R, R]_7' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R]_10' in strategy_name_list - assert '[R, S1, R, R]_11' in strategy_name_list - assert '[R, R, R, R]_12' in strategy_name_list - assert '[R, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R]_15' in strategy_name_list + assert "[R, S1, R, R]_0" in strategy_name_list + assert "[R, S0, R, R]_1" in strategy_name_list + assert "[R, R, R, R]_2" in strategy_name_list + assert "[R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, R]_4" in strategy_name_list + assert "[R, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R]_6" in strategy_name_list + assert "[R, S0, R, R]_7" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R]_10" in strategy_name_list + assert "[R, S1, R, R]_11" in strategy_name_list + assert "[R, R, R, R]_12" in strategy_name_list + assert "[R, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R]_15" in strategy_name_list if split_dim == 1: - assert '[S0, R, R, R]_0' in strategy_name_list - assert '[S1, R, R, R]_1' in strategy_name_list - assert '[S0, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R]_5' in strategy_name_list - assert '[R, R, R, R]_6' in strategy_name_list - assert '[R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_9' in strategy_name_list - assert '[R, R, R, R]_10' in strategy_name_list - assert '[R, R, R, R]_11' in strategy_name_list - assert '[R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R]_15' in strategy_name_list - - if model_cls.__name__ == 'LinearSplitModel': - + assert "[S0, R, R, R]_0" in strategy_name_list + assert "[S1, R, R, R]_1" in strategy_name_list + assert "[S0, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R]_5" in strategy_name_list + assert "[R, R, R, R]_6" in strategy_name_list + assert "[R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_9" in strategy_name_list + assert "[R, R, R, R]_10" in strategy_name_list + assert "[R, R, R, R]_11" in strategy_name_list + assert "[R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R]_14" in strategy_name_list + assert "[R, R, R, R]_15" in strategy_name_list + + if model_cls.__name__ == "LinearSplitModel": if split_dim == 0: - assert '[R, R, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1]_13' in strategy_name_list - assert '[R, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0]_16' in strategy_name_list - assert '[R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R]_18' in strategy_name_list - assert '[R, R, S0, R]_19' in strategy_name_list - assert '[R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R]_21' in strategy_name_list - assert '[R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R]_1' in strategy_name_list - assert '[R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01]_4' in strategy_name_list + assert "[R, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1]_13" in strategy_name_list + assert "[R, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0]_16" in strategy_name_list + assert "[R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R]_18" in strategy_name_list + assert "[R, R, S0, R]_19" in strategy_name_list + assert "[R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R]_21" in strategy_name_list + assert "[R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R]_1" in strategy_name_list + assert "[R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01]_4" in strategy_name_list if split_dim == 1: - assert '[S0, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R]_17' in strategy_name_list - assert '[R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R]_20' in strategy_name_list - assert '[R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R]_0' in strategy_name_list - assert '[R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1]_13" in strategy_name_list + assert "[S1, R, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R]_17" in strategy_name_list + assert "[R, R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R]_20" in strategy_name_list + assert "[R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R]_0" in strategy_name_list + assert "[R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('split_size', [2]) -@parameterize('split_dim', [0, 1, 2]) -@parameterize('model_cls', [ConvSplitModel, LinearSplitModel]) +@parameterize("split_size", [2]) +@parameterize("split_dim", [0, 1, 2]) +@parameterize("model_cls", [ConvSplitModel, LinearSplitModel]) def test_split_handler(split_size, split_dim, model_cls): spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py index c11291ecac969f4f2a93b12646baa546ad231f13..cbd3e47044b39e91f5bac9926d7d2ce9a9a405f8 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py @@ -16,7 +16,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n class LinearSumModel(nn.Module): - def __init__(self, sum_dims, keepdim): super().__init__() self.sum_dims = sum_dims @@ -33,26 +32,28 @@ class LinearSumModel(nn.Module): def check_sum_handler(rank, world_size, port, sum_dims, keepdim): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies strategy_number = 24 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) @@ -63,8 +64,8 @@ def check_sum_handler(rank, world_size, port, sum_dims, keepdim): # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {}) # return sum_1 meta_args = { - "input": torch.rand(8, 16, 64, 32).to('meta'), - "other": torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -78,11 +79,11 @@ def check_sum_handler(rank, world_size, port, sum_dims, keepdim): # build handler assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) sum_handler = SumHandler(node=sum_node, device_mesh=device_mesh, strategies_vector=sum_strategies_vector) @@ -100,131 +101,131 @@ def check_sum_handler(rank, world_size, port, sum_dims, keepdim): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].name == "sum_1" + assert mapping["output"].name == "sum_1" sum_node_shape = torch.empty([8, 16, 64, 64]).sum(sum_dims, keepdim=keepdim).shape - assert mapping['output'].logical_shape == sum_node_shape - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == sum_node_shape + assert mapping["output"].type == OperationDataType.OUTPUT # check strategy name if sum_dims == (0, 2) and keepdim == False: - assert '[R, R, R, R] -> [R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [S01, R]_1' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_10' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [S0, S1]_12' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_13' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [S1, S0]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [S0, R]_18' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_19' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [S1, R]_21' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list + assert "[R, R, R, R] -> [R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [S01, R]_1" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_9" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_10" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [S0, S1]_12" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_13" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [S1, S0]_15" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_16" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [S0, R]_18" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_19" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [S1, R]_21" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_23" in strategy_name_list if sum_dims == (0, 2) and keepdim == True: - assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_13' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R]_1" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S0, R, S1]_12" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_13" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S1, R, S0]_15" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_16" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R]_18" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_19" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R]_21" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_23" in strategy_name_list if sum_dims == 1 and keepdim == False: - assert '[S01, R, R, R] -> [S01, R, R]_0' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, S1]_10' in strategy_name_list - assert '[S0, R, R, S1] -> [S0, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R]_17' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, S1, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R]_0" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, S0]_9" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, S1]_10" in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, S0, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R]_17" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R]_20" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, S1, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_23" in strategy_name_list if sum_dims == 1 and keepdim == True: - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_23" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('sum_dims', [(0, 2), 1]) -@parameterize('keepdim', [False, True]) +@parameterize("sum_dims", [(0, 2), 1]) +@parameterize("keepdim", [False, True]) def test_sum_handler(sum_dims, keepdim): spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim) -if __name__ == '__main__': +if __name__ == "__main__": test_sum_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py index 5b6ac051a8ef7944f35bb3e94a1170e3bfd8efc2..29089183165d177d8177db51c47bc882abe387f4 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py @@ -11,7 +11,6 @@ from colossalai.testing import clear_cache_before_run, run_on_environment_flag class TensorConstructorModel(nn.Module): - def __init__(self): super().__init__() @@ -21,7 +20,7 @@ class TensorConstructorModel(nn.Module): return x -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_where_handler(): model = TensorConstructorModel() @@ -33,7 +32,7 @@ def test_where_handler(): # %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {}) # return add - meta_args = {'x': torch.rand(10).to('meta')} + meta_args = {"x": torch.rand(10).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -56,16 +55,16 @@ def test_where_handler(): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['output'].name == "arange" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([10]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "arange" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([10]) + assert mapping["output"].type == OperationDataType.OUTPUT handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] - assert 'Replica Tensor Constructor' in strategy_name_list + assert "Replica Tensor Constructor" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py index f4e6dafdfd692e5c1ef5a5862c70f61ac31d6994..271d55ae917a7dba977edd128424be9385839804 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py @@ -12,7 +12,6 @@ from colossalai.testing import clear_cache_before_run, run_on_environment_flag class ReLuModel(nn.Module): - def __init__(self): super().__init__() self.act = torch.nn.ReLU() @@ -23,7 +22,7 @@ class ReLuModel(nn.Module): return relu_node -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_elementwise_handler(): model = ReLuModel() @@ -35,8 +34,8 @@ def test_elementwise_handler(): # %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {}) # return act meta_args = { - 'input': torch.rand(4, 4, 64, 64).to('meta'), - 'other': torch.rand(16, 4, 3, 3).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), + "other": torch.rand(16, 4, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -51,14 +50,14 @@ def test_elementwise_handler(): conv_strategies_vector = StrategiesVector(conv_mod_node) # build handler - conv_handler = ConvFunctionHandler(node=conv_mod_node, - device_mesh=device_mesh, - strategies_vector=conv_strategies_vector) + conv_handler = ConvFunctionHandler( + node=conv_mod_node, device_mesh=device_mesh, strategies_vector=conv_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) - relu_handler = UnaryElementwiseHandler(node=relu_mod_node, - device_mesh=device_mesh, - strategies_vector=relu_strategies_vector) + setattr(conv_mod_node, "strategies_vector", conv_strategies_vector) + relu_handler = UnaryElementwiseHandler( + node=relu_mod_node, device_mesh=device_mesh, strategies_vector=relu_strategies_vector + ) relu_handler.register_strategy(compute_resharding_cost=False) @@ -70,20 +69,20 @@ def test_elementwise_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "conv2d" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].name == "conv2d" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 16, 62, 62]) - assert mapping['output'].name == "act" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 62, 62]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "act" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping["output"].type == OperationDataType.OUTPUT # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(relu_strategies_vector) == len(conv_strategies_vector) -if __name__ == '__main__': +if __name__ == "__main__": test_elementwise_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py index fbb194d8e0b8d119cec8eb27f09657ee95dbeaaf..466168c79a0ba8655b043f16f59157ca267e89de 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -18,7 +18,6 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n class ConvViewModel(nn.Module): - def __init__(self, tgt_shape): super().__init__() self.tgt_shape = tgt_shape @@ -30,7 +29,6 @@ class ConvViewModel(nn.Module): class LinearViewModel(nn.Module): - def __init__(self, tgt_shape): super().__init__() self.tgt_shape = tgt_shape @@ -43,19 +41,19 @@ class LinearViewModel(nn.Module): def check_view_handler(rank, tgt_shape, model_cls, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(tgt_shape).cuda() - if model_cls.__name__ == 'ConvViewModel': - input = torch.rand(8, 8, 66, 66).to('cuda') - other = torch.rand(16, 8, 3, 3).to('cuda') + if model_cls.__name__ == "ConvViewModel": + input = torch.rand(8, 8, 66, 66).to("cuda") + other = torch.rand(16, 8, 3, 3).to("cuda") # index of conv node in computation graph node_index = 2 # total number of conv strategies strategy_number = 16 - if model_cls.__name__ == 'LinearViewModel': - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + if model_cls.__name__ == "LinearViewModel": + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -65,25 +63,27 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) - if model_cls.__name__ == 'ConvViewModel': + if model_cls.__name__ == "ConvViewModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) # return view - meta_args = {'input': torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta')} + meta_args = {"input": torch.rand(8, 8, 66, 66).to("meta"), "other": torch.rand(16, 8, 3, 3).to("meta")} graph = tracer.trace(model, meta_args=meta_args) - if model_cls.__name__ == 'LinearViewModel': + if model_cls.__name__ == "LinearViewModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -91,8 +91,8 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): # %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) # return view meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), - 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -105,21 +105,20 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): previous_strategies_vector = StrategiesVector(previous_mod_node) # build handler - if model_cls.__name__ == 'ConvViewModel': - - conv_handler = ConvFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + if model_cls.__name__ == "ConvViewModel": + conv_handler = ConvFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - if model_cls.__name__ == 'LinearViewModel': + if model_cls.__name__ == "LinearViewModel": assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector) @@ -133,126 +132,124 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): # make sure they have valid values assert op_data.data is not None - if model_cls.__name__ == 'ConvViewModel': - assert mapping['input'].name == "conv2d" + if model_cls.__name__ == "ConvViewModel": + assert mapping["input"].name == "conv2d" else: - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].name == "view" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size(tgt_shape) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "view" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size(tgt_shape) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(view_strategies_vector) == len(previous_strategies_vector) strategy_name_list = [strategy.name for strategy in view_strategies_vector] - if model_cls.__name__ == 'ConvViewModel': - + if model_cls.__name__ == "ConvViewModel": if tgt_shape == (32, 4, 64, 16, 4): - assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list - assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list - assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list - assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list + assert "[S0, S1, R, R] -> FULLY REPLICATED_0" in strategy_name_list + assert "[S1, S0, R, R] -> FULLY REPLICATED_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> FULLY REPLICATED_6" in strategy_name_list + assert "[R, S0, R, R] -> FULLY REPLICATED_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> FULLY REPLICATED_10" in strategy_name_list + assert "[R, S1, R, R] -> FULLY REPLICATED_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> FULLY REPLICATED_15" in strategy_name_list if tgt_shape == (8, 4, 4, 64, 16, 4): - assert '[S0, S1, R, R] -> [S0, S1, R, R, R, R]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [S1, S0, R, R, R, R]_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_15' in strategy_name_list - - if model_cls.__name__ == 'LinearViewModel': - + assert "[S0, S1, R, R] -> [S0, S1, R, R, R, R]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [S1, S0, R, R, R, R]_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R, R, R]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R, R, R]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R, R, R]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R, R, R]_15" in strategy_name_list + + if model_cls.__name__ == "LinearViewModel": if tgt_shape == (32, 4, 64, 16, 4): for strategy in strategy_name_list: print(strategy) # print(strategy_name_list) - assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_11' in strategy_name_list - assert '[R, S0, R, S1] -> FULLY REPLICATED_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_14' in strategy_name_list - assert '[R, S1, R, S0] -> FULLY REPLICATED_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> FULLY REPLICATED_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> FULLY REPLICATED_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1, R]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0, R]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0, R]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1, R]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> FULLY REPLICATED_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01, R]_4' in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, R, S1, R]_11" in strategy_name_list + assert "[R, S0, R, S1] -> FULLY REPLICATED_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1, R]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, S0, R]_14" in strategy_name_list + assert "[R, S1, R, S0] -> FULLY REPLICATED_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0, R]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> FULLY REPLICATED_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> FULLY REPLICATED_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1, R]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0, R]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0, R]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1, R]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> FULLY REPLICATED_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01, R]_4" in strategy_name_list if tgt_shape == (8, 4, 4, 64, 16, 4): - assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, R, S0, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, R, S1, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, R, S1, R]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, R, S0, R]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, R, S0, R]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, R, S1, R]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, R, S01, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, R, S01, R]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) -@parameterize('model_cls', [ConvViewModel, LinearViewModel]) +@parameterize("tgt_shape", [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) +@parameterize("model_cls", [ConvViewModel, LinearViewModel]) def test_view_handler(tgt_shape, model_cls): spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py index bd7635ac1737424d1f78c2155f2d3308d7d343b0..10ca644cddc24b0986a5be370b4f67a05ed5db65 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py @@ -12,7 +12,6 @@ from colossalai.testing import clear_cache_before_run class ConvModel(nn.Module): - def __init__(self): super().__init__() @@ -21,7 +20,7 @@ class ConvModel(nn.Module): return output -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") @clear_cache_before_run() def test_where_handler(): model = ConvModel() @@ -33,9 +32,9 @@ def test_where_handler(): # %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {}) # return where meta_args = { - 'condition': torch.rand(4, 4, 64, 64).to('meta'), - 'x': torch.rand(4, 1, 64, 64).to('meta'), - 'y': torch.rand(1, 4, 64, 64).to('meta') + "condition": torch.rand(4, 4, 64, 64).to("meta"), + "x": torch.rand(4, 1, 64, 64).to("meta"), + "y": torch.rand(1, 4, 64, 64).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -59,28 +58,28 @@ def test_where_handler(): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['condition'].name == "condition" - assert mapping['condition'].data.is_meta - assert mapping['condition'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['condition'].type == OperationDataType.ARG - assert mapping['condition'].logical_shape == torch.Size([4, 4, 64, 64]) - - assert mapping['x'].name == "x" - assert mapping['x'].data.is_meta - assert mapping['x'].data.shape == torch.Size([4, 1, 64, 64]) - assert mapping['x'].type == OperationDataType.ARG - assert mapping['x'].logical_shape == torch.Size([4, 4, 64, 64]) - - assert mapping['y'].name == "y" - assert mapping['y'].data.is_meta - assert mapping['y'].data.shape == torch.Size([1, 4, 64, 64]) - assert mapping['y'].type == OperationDataType.ARG - assert mapping['y'].logical_shape == torch.Size([4, 4, 64, 64]) - - assert mapping['output'].name == "where" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["condition"].name == "condition" + assert mapping["condition"].data.is_meta + assert mapping["condition"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["condition"].type == OperationDataType.ARG + assert mapping["condition"].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping["x"].name == "x" + assert mapping["x"].data.is_meta + assert mapping["x"].data.shape == torch.Size([4, 1, 64, 64]) + assert mapping["x"].type == OperationDataType.ARG + assert mapping["x"].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping["y"].name == "y" + assert mapping["y"].data.is_meta + assert mapping["y"].data.shape == torch.Size([1, 4, 64, 64]) + assert mapping["y"].type == OperationDataType.ARG + assert mapping["y"].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping["output"].name == "where" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -88,5 +87,5 @@ def test_where_handler(): assert len(strategy_name_list) == 25 -if __name__ == '__main__': +if __name__ == "__main__": test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py index 28a8bbd9a4c11df2a2fbf324f6873c6d2c75ae4f..3591c663897c391c072cc4f16b26a96c27c1802b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py @@ -2,7 +2,6 @@ import copy from typing import Dict, List import torch -from torch.fx import GraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass @@ -18,16 +17,18 @@ from colossalai.tensor.shape_consistency import to_global from colossalai.testing.comparison import assert_close -def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor], - input_kwargs: Dict[str, torch.Tensor], grad_dict: Dict[any, torch.Tensor]): - +def _build_model_to_compare( + model: torch.nn.Module, + input_args: List[torch.Tensor], + input_kwargs: Dict[str, torch.Tensor], + grad_dict: Dict[any, torch.Tensor], +): model_to_compare = copy.deepcopy(model) args_to_compare = [] kwargs_to_compare = {} for arg_index, input_tensor in enumerate(input_args): def wrapper(param, index): - def hook_fn(grad): grad_dict[index] = grad @@ -45,7 +46,6 @@ def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tenso for name, input_kwarg in input_kwargs.items(): def wrapper(param, name): - def hook_fn(grad): grad_dict[name] = grad @@ -63,30 +63,34 @@ def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tenso return model_to_compare, args_to_compare, kwargs_to_compare -def numerical_test_for_node_strategy(model: torch.nn.Module, - device_mesh: DeviceMesh, - node_index: int, - strategy_number: int, - input_args: List[torch.Tensor], - meta_arg_names: List[str], - input_kwargs: Dict[str, torch.Tensor] = {}, - node_type: str = 'normal'): +def numerical_test_for_node_strategy( + model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}, + node_type: str = "normal", +): for strategy_index in range(strategy_number): - print(f'#strategy_index: {strategy_index}') + print(f"#strategy_index: {strategy_index}") # We need to copy the model to avoid do backward more than once in same graph grad_to_compare_dict = {} grad_to_shard_dict = {} model_to_compare, args_to_compare, kwargs_to_compare = _build_model_to_compare( - model, input_args, input_kwargs, grad_to_compare_dict) - model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs, - grad_to_shard_dict) + model, input_args, input_kwargs, grad_to_compare_dict + ) + model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare( + model, input_args, input_kwargs, grad_to_shard_dict + ) tracer = ColoTracer(bias_addition_split=True) input_sample = {} for input_arg, meta_arg_name in zip(input_args, meta_arg_names): - input_sample[meta_arg_name] = torch.empty(input_arg.shape, dtype=input_arg.dtype).to('meta') + input_sample[meta_arg_name] = torch.empty(input_arg.shape, dtype=input_arg.dtype).to("meta") for meta_kwarg_name, input_kwarg in input_kwargs.items(): - input_sample[meta_kwarg_name] = torch.empty(input_kwarg.shape, dtype=input_kwarg.dtype).to('meta') + input_sample[meta_kwarg_name] = torch.empty(input_kwarg.shape, dtype=input_kwarg.dtype).to("meta") graph = tracer.trace(root=model_to_shard, meta_args=input_sample) gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) shape_prop_pass(gm, *input_sample.values()) @@ -94,13 +98,14 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() - target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies - ][node_index] - if node_type == 'normal': + target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies][ + node_index + ] + if node_type == "normal": solution_len = len(strategies_constructor.leaf_strategies) solution = [0] * solution_len solution[node_index] = strategy_index - elif node_type == 'following': + elif node_type == "following": solution_len = len(strategies_constructor.leaf_strategies) solution = [0] * solution_len solution[node_index] = strategy_index @@ -116,18 +121,21 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, ret = solver.call_solver_serialized_args() solution = list(ret[0]) gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( - gm, solution, device_mesh, strategies_constructor) + gm, solution, device_mesh, strategies_constructor + ) gm = runtime_apply_pass(gm) gm.recompile() # forward result compare - output = gm(*args_to_shard, - sharding_spec_convert_dict=sharding_spec_dict, - origin_node_sharding_spec_dict=origin_spec_dict, - comm_actions_dict=comm_actions_dict, - **kwargs_to_shard) + output = gm( + *args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard, + ) output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare) - assert_close_helper(output, output_to_compare, strategy_index=strategy_index, type='forward output') + assert_close_helper(output, output_to_compare, strategy_index=strategy_index, type="forward output") # backward result compare if isinstance(output, (tuple, list)): @@ -142,43 +150,45 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, for key in grad_to_shard_dict.keys(): grad_to_shard = grad_to_shard_dict[key] grad_to_compare = grad_to_compare_dict[key] - assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type='input grad') + assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type="input grad") # extract the strategy used in this iter strategy_in_use = target_node.strategies_vector[strategy_index] param_to_shard_dict = dict(gm.named_parameters()) param_to_compare_dict = dict(model_to_compare.named_parameters()) for name in param_to_shard_dict.keys(): - param_name = name.split('.')[-1] - if node_type == 'normal': + param_name = name.split(".")[-1] + if node_type == "normal": param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name) else: - if 'weight' in name: + if "weight" in name: param_sharding_spec = None for node in list(graph.nodes): - if 'weight' in node.name: + if "weight" in node.name: param_sharding_spec = node.sharding_spec - elif 'bias' in name: + elif "bias" in name: param_sharding_spec = None for node in list(graph.nodes): - if 'bias' in node.name: + if "bias" in node.name: param_sharding_spec = node.sharding_spec assert param_sharding_spec is not None grad_sharded = param_to_shard_dict[name].grad grad_to_compare = param_to_compare_dict[name].grad global_grad = to_global(grad_sharded, param_sharding_spec) - assert_close_helper(global_grad, grad_to_compare, strategy_index=strategy_index, type='param grad') + assert_close_helper(global_grad, grad_to_compare, strategy_index=strategy_index, type="param grad") -def assert_close_helper(first: torch.Tensor, - second: torch.Tensor, - rtol: float = 1e-2, - atol: float = 1e-2, - strategy_index: int = -1, - type: str = 'not defined'): +def assert_close_helper( + first: torch.Tensor, + second: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + strategy_index: int = -1, + type: str = "not defined", +): """ This method is used to check whether the average difference between two tensors is as close as expected. """ @@ -189,4 +199,4 @@ def assert_close_helper(first: torch.Tensor, else: assert_close(first, second, rtol=rtol, atol=atol) except: - print(f'strategy index {strategy_index} encounter assert_close error on {type}') + print(f"strategy index {strategy_index} encounter assert_close error on {type}") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index 0d93e4e4052792158729b382f4f679996bcc48e1..e7b8c696e62e48f1bf7ad32bf6ac6185412ca5e1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -3,17 +3,18 @@ from torch.fx import GraphModule from torchvision.models import resnet50 from colossalai._analyzer.fx.passes import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.testing import clear_cache_before_run, run_on_environment_flag -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_cost_graph(): physical_mesh_id = torch.arange(0, 8) @@ -21,11 +22,11 @@ def test_cost_graph(): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - shape_consistency_manager = ShapeConsistencyManager() + ShapeConsistencyManager() tracer = ColoTracer(bias_addition_split=True) model = resnet50(num_classes=100000) - input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')} + input_sample = {"x": torch.rand(128, 3, 224, 224).to("meta")} graph = tracer.trace(root=model, meta_args=input_sample) # graph(): @@ -74,7 +75,7 @@ def test_cost_graph(): communication_cost_bn = 0 memory_cost = 0 for index, node in enumerate(graph.nodes): - if node.op == 'call_module': + if node.op == "call_module": submod = node.graph.owning_module.get_submodule(node.target) if type(submod) in BATCHNORM_MODULE_OP: communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost.total @@ -86,11 +87,11 @@ def test_cost_graph(): node_memory_cost = node_memory_cost[0] memory_cost += node_memory_cost.activation + node_memory_cost.parameter - print(f'computation cost is {computation_cost}') - print(f'communication cost is {communication_cost}') - print(f'memory cost is {memory_cost}') - print(f'bn communication cost is {communication_cost_bn}') + print(f"computation cost is {computation_cost}") + print(f"communication cost is {communication_cost}") + print(f"memory cost is {memory_cost}") + print(f"bn communication cost is {communication_cost_bn}") -if __name__ == '__main__': +if __name__ == "__main__": test_cost_graph() diff --git a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py index d07145e48e1f5e86a9c1dcf9e1ed5644bee513aa..07fd0ad582e99c833ffdf940cc378e759ff0fdf7 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py +++ b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, List +from typing import Any import torch import torch.fx @@ -111,13 +111,14 @@ def _benchmark_speed(model, inputs, loop=5): def benchmark_evoformer_stack(data_args): from test_autochunk_evoformer_stack import get_data, get_model + print("\nmsa len: %d, pair len: %d" % (data_args[0], data_args[1])) max_mem = _benchmark_evoformer_stack_origin(data_args, get_model, get_data) for ratio in [0.5, 0.4, 0.3, 0.2, 0.1]: try: _benchmark_evoformer_stack_gm(data_args, max_mem * ratio, get_model, get_data) except RuntimeError as e: - if e.args[0] == 'Search failed. Try a larger memory threshold.': + if e.args[0] == "Search failed. Try a larger memory threshold.": break except Exception as e: raise e diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py index 15610e2b50dcbdb24431b7add4cd9fae6a98619e..3d3f212a68d0f97a13125010bdbc2d7e32b19897 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py @@ -6,7 +6,6 @@ import torch.fx import colossalai from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE from colossalai.autochunk.utils import flat_list -from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.testing import free_port @@ -80,9 +79,9 @@ def assert_codegen_run( out_gm = flat_list(out_gm) out_model = flat_list(out_model) for out_gm_i, out_model_i in zip(out_gm, out_model): - assert torch.allclose(out_gm_i, out_model_i, - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_gm_i - out_model_i)) + assert torch.allclose( + out_gm_i, out_model_i, atol=1e-4 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(out_gm_i - out_model_i)) return chunks diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py index 9e4cb7ee9f95b212821809e3ab25906ab82540aa..1a4ababda30d3a2d8fa30281a227c94381ddfbe5 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py @@ -6,6 +6,7 @@ import torch.fx try: from fastfold.model.nn.evoformer import EvoformerBlock + HAS_REPO = True except: HAS_REPO = False @@ -17,22 +18,26 @@ from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_model(): - model = EvoformerBlock( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - no_heads_msa=8, - no_heads_pair=4, - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.15, - inf=1e4, - eps=1e-4, - is_multimer=False, - ).eval().cuda() + model = ( + EvoformerBlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + is_multimer=False, + ) + .eval() + .cuda() + ) return model @@ -54,8 +59,20 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: def get_chunk_target() -> Dict: return { - None: [(120, 126), (225, 244), (270, 289), (306, 311), (70, 106), (23, 46), (146, 152), (187, 193), (181, 184), - (140, 145), (162, 163), (203, 204)], + None: [ + (120, 126), + (225, 244), + (270, 289), + (306, 311), + (70, 106), + (23, 46), + (146, 152), + (187, 193), + (181, 184), + (140, 145), + (162, 163), + (203, 204), + ], 20: [(120, 123), (232, 237), (277, 282), (305, 306)], 24: [(122, 123)], } diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py index 6b47033e199f0ac47ef6df90bb4d28730eae1bd9..0b04ba5257b6689dd132771511ca6beb248e3440 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py @@ -6,6 +6,7 @@ import torch.fx try: from fastfold.model.nn.evoformer import EvoformerStack + HAS_REPO = True except: HAS_REPO = False @@ -17,26 +18,30 @@ from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_model(): - model = EvoformerStack( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - c_s=384, - no_heads_msa=8, - no_heads_pair=4, - no_blocks=2, # 48 - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.25, - blocks_per_ckpt=None, - inf=1000000000.0, - eps=1e-08, - clear_cache_between_blocks=False, - is_multimer=False, - ).eval().cuda() + model = ( + EvoformerStack( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + c_s=384, + no_heads_msa=8, + no_heads_pair=4, + no_blocks=2, # 48 + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.25, + blocks_per_ckpt=None, + inf=1000000000.0, + eps=1e-08, + clear_cache_between_blocks=False, + is_multimer=False, + ) + .eval() + .cuda() + ) return model @@ -62,7 +67,7 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: ) @clear_cache_before_run() @parameterize("max_memory", [None, 20, 24]) -@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) def test_evoformer_stack(data_args, max_memory): spawn( run_test, diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py index b4c577c18ee602fcf49724b413c2e323931530f5..585a9e3381c447bf159660d3cffdd804d4e639b0 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import List, Tuple import pytest import torch @@ -6,6 +6,7 @@ import torch.fx try: from fastfold.model.nn.evoformer import ExtraMSABlock + HAS_REPO = True except: HAS_REPO = False @@ -16,23 +17,27 @@ from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_model(): - model = ExtraMSABlock( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - no_heads_msa=8, - no_heads_pair=4, - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.15, - inf=1e4, - eps=1e-4, - ckpt=False, - is_multimer=False, - ).eval().cuda() + model = ( + ExtraMSABlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + ckpt=False, + is_multimer=False, + ) + .eval() + .cuda() + ) return model @@ -58,7 +63,7 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: ) @clear_cache_before_run() @parameterize("max_memory", [None, 20, 24]) -@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) def test_extramsa_block(data_args, max_memory): spawn( run_test, diff --git a/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py index 6fb7efa7a8fc986b6811fe9958d2365b3fa7a1f4..b75cbe67590cbccf48577a54e01b05359f9657a9 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py +++ b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, List +from typing import Any import torch import torch.fx @@ -64,8 +64,10 @@ def _benchmark_autochunk_unet_gm( para_mem = float(parameter_size(model)) / 1024**2 act_mem = _benchmark_memory(gm, inputs) speed = _benchmark_speed(gm, inputs) - print("unet autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "unet autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) def _benchmark_autochunk_unet_origin( @@ -86,8 +88,10 @@ def _benchmark_autochunk_unet_origin( para_mem = float(parameter_size(model)) / 1024**2 act_mem = _benchmark_memory(model, inputs) speed = _benchmark_speed(model, inputs) - print("unet origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "unet origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) return act_mem @@ -115,6 +119,7 @@ def _benchmark_speed(model, inputs, loop=5): def benchmark_autochunk_unet(batch=1, height=448, width=448): from test_autochunk_unet import UNet2DModel, get_data + model = UNet2DModel() latent_shape = (batch, 3, height // 7, width // 7) @@ -124,7 +129,7 @@ def benchmark_autochunk_unet(batch=1, height=448, width=448): try: _benchmark_autochunk_unet_gm(model, get_data(latent_shape), max_mem * ratio) except RuntimeError as e: - if e.args[0] == 'Search failed. Try a larger memory threshold.': + if e.args[0] == "Search failed. Try a larger memory threshold.": break except Exception as e: raise e diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py index e245f10d4576bed99bc322ba602dfb8e04c315b3..32034992090fa0d0d43aefb9ccd9473d76ec936c 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py @@ -5,10 +5,9 @@ import torch.fx import colossalai from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE -from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.testing import free_port +from colossalai.legacy.core import global_context as gpc if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen @@ -84,15 +83,19 @@ def assert_codegen_run( max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2 print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm)) - assert torch.allclose(out_gm["sample"], out_model["sample"], - atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_gm["sample"] - out_model["sample"])) + assert torch.allclose( + out_gm["sample"], out_model["sample"], atol=1e-3 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(out_gm["sample"] - out_model["sample"]) + ) return chunks def run_test( rank: int, + world_size: int, + port: int, model: Any, data: tuple, max_memory: int, @@ -106,9 +109,9 @@ def run_test( colossalai.launch( config={}, rank=rank, - world_size=1, + world_size=world_size, host="localhost", - port=free_port(), + port=port, backend="nccl", ) @@ -128,7 +131,7 @@ def run_test( if get_chunk_target is not None: chunk_found = [i["region"] for i in chunks] chunk_target = get_chunk_target()[max_memory] - assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % ( str(chunk_found), str(chunk_target), ) diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py index ff0d4a1b53f58a22adc1bf820e25bdcf920728bb..ad50874c92a339ac095ca26a95fe1a4ef1c7df2a 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py @@ -4,12 +4,17 @@ import pytest import torch try: - from diffusers import UNet2DModel - MODELS = [UNet2DModel] + import diffusers + + MODELS = [diffusers.UNet2DModel] HAS_REPO = True + from packaging import version + + SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2") except: MODELS = [] HAS_REPO = False + SKIP_UNET_TEST = False from test_autochunk_diffuser_utils import run_test @@ -32,6 +37,10 @@ def get_data(shape: tuple) -> Tuple[List, List]: return meta_args, concrete_args +@pytest.mark.skipif( + SKIP_UNET_TEST, + reason="diffusers version > 0.10.2", +) @pytest.mark.skipif( not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", @@ -51,13 +60,4 @@ def test_evoformer_block(model, shape, max_memory): if __name__ == "__main__": - run_test( - rank=0, - data=get_data(LATENTS_SHAPE), - max_memory=None, - model=UNet2DModel, - print_code=False, - print_mem=True, - print_est_mem=False, - print_progress=False, - ) + test_evoformer_block() diff --git a/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py index 63490aaee7ff0803b7c0aded3fca7d79084f9b27..e70e501750325d00ba478319110c84d264e9d3f9 100644 --- a/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py +++ b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, List +from typing import Any import torch import torch.fx @@ -64,8 +64,10 @@ def _benchmark_autochunk_gpt_gm( para_mem = float(parameter_size(model)) / 1024**2 * 6 act_mem = _benchmark_memory(gm, inputs) speed = _benchmark_speed(gm, inputs) - print("gpt autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "gpt autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) def _benchmark_autochunk_gpt_origin( @@ -86,8 +88,10 @@ def _benchmark_autochunk_gpt_origin( para_mem = float(parameter_size(model)) / 1024**2 * 6 act_mem = _benchmark_memory(model, inputs) speed = _benchmark_speed(model, inputs) - print("gpt origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "gpt origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) return act_mem @@ -115,6 +119,7 @@ def _benchmark_speed(model, inputs, loop=5): def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): from test_autochunk_gpt import GPT2Config, GPT2Model, get_data + model = GPT2Model config = GPT2Config(n_embd=n_embd, n_positions=seq, n_layer=2, n_head=n_head) model = model(config=config) @@ -125,7 +130,7 @@ def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): try: _benchmark_autochunk_gpt_gm(model, get_data(shape), max_mem * ratio) except RuntimeError as e: - if e.args[0] == 'Search failed. Try a larger memory threshold.': + if e.args[0] == "Search failed. Try a larger memory threshold.": break except Exception as e: raise e diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py index 384706639e108c067839b878cca3bb5ba142b604..b2d842ee6a7b1c63f77e47524a6c221527a955f0 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py @@ -5,6 +5,7 @@ import torch try: from transformers import GPT2Config, GPT2Model + MODELS = [GPT2Model] HAS_REPO = True except: @@ -30,6 +31,8 @@ def get_data(shape: tuple) -> Tuple[List, List]: return meta_args, concrete_args, sequence +@pytest.mark.skip("full op is not implemented now") +# FIXME(ver217, oahzxl): implement full op @pytest.mark.skipif( not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", @@ -50,13 +53,15 @@ def test_autochunk_gpt(model, shape, max_memory): if __name__ == "__main__": - run_test(rank=0, - data=get_data((BATCH_SIZE, SEQ_LENGTH)), - max_memory=None, - model=GPT2Model, - config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), - print_code=False, - print_est_mem=False, - print_mem=False, - print_progress=False, - eval_mem=False) + run_test( + rank=0, + data=get_data((BATCH_SIZE, SEQ_LENGTH)), + max_memory=None, + model=GPT2Model, + config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), + print_code=False, + print_est_mem=False, + print_mem=False, + print_progress=False, + eval_mem=False, + ) diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py index faba138cd42cae467532587338dda5bf9a6b33fb..77c11db71a5c48d1d7de276f510d460c77df8008 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py @@ -5,10 +5,8 @@ import torch.fx import colossalai from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE -from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen @@ -40,11 +38,9 @@ def assert_codegen_run( meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors] interp.propagate(*meta_tensors) - codegen = AutoChunkCodeGen(meta_graph, - max_memory=max_memory, - print_mem=print_est_mem, - print_progress=print_progress, - eval_mem=eval_mem) + codegen = AutoChunkCodeGen( + meta_graph, max_memory=max_memory, print_mem=print_est_mem, print_progress=print_progress, eval_mem=eval_mem + ) chunks = codegen.chunk_infos # trace and recompile @@ -87,9 +83,9 @@ def assert_allclose(out_model: Any, out_gm: Any) -> None: assert allclose for out """ if isinstance(out_model, torch.Tensor): - assert torch.allclose(out_model, out_gm, - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_model - out_gm)) + assert torch.allclose( + out_model, out_gm, atol=1e-4 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(out_model - out_gm)) elif isinstance(out_model, dict): for k in out_model.keys(): assert_allclose(out_model[k], out_gm[k]) @@ -100,6 +96,8 @@ def assert_allclose(out_model: Any, out_gm: Any) -> None: def run_test( rank: int, + world_size: int, + port: int, model: Any, config: Any, data: tuple, @@ -116,26 +114,28 @@ def run_test( colossalai.launch( config={}, rank=rank, - world_size=1, + world_size=world_size, host="localhost", - port=free_port(), + port=port, backend="nccl", ) # build model and input - chunks = assert_codegen_run(model, - data=data, - max_memory=max_memory, - print_code=print_code, - print_est_mem=print_est_mem, - print_mem=print_mem, - print_progress=print_progress, - eval_mem=eval_mem) + chunks = assert_codegen_run( + model, + data=data, + max_memory=max_memory, + print_code=print_code, + print_est_mem=print_est_mem, + print_mem=print_mem, + print_progress=print_progress, + eval_mem=eval_mem, + ) if get_chunk_target is not None: chunk_found = [i["region"] for i in chunks] chunk_target = get_chunk_target()[max_memory] - assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % ( str(chunk_found), str(chunk_target), ) diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py index a98aa0e03954e49c9dfd361a1ed885e86a9c0771..aa868d683f069f8e2b25841ca53cc7cd04affc44 100644 --- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py @@ -5,6 +5,7 @@ import torch try: from timm.models.vision_transformer import vit_large_patch16_384 as vit + MODELS = [vit] HAS_REPO = True except: @@ -19,7 +20,7 @@ from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_data() -> Tuple[List, List]: data = torch.rand(1, 3, 384, 384) - meta_args = {'x': data} + meta_args = {"x": data} return data, meta_args diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py index 317606fc478184c5f659797042dec5702c9e488e..ca919fb7e4feace1c5c4b9fd468cea872b81e662 100644 --- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py @@ -5,10 +5,9 @@ import torch.fx import colossalai from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE -from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.testing import free_port +from colossalai.legacy.core import global_context as gpc if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen @@ -76,15 +75,17 @@ def assert_codegen_run( max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2 print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm)) - assert torch.allclose(out_gm, out_model, - atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_gm - out_model)) + assert torch.allclose( + out_gm, out_model, atol=1e-3 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(out_gm - out_model)) return chunks def run_test( rank: int, + world_size: int, + port: int, model: Any, data: tuple, max_memory: int, @@ -98,9 +99,9 @@ def run_test( colossalai.launch( config={}, rank=rank, - world_size=1, + world_size=world_size, host="localhost", - port=free_port(), + port=port, backend="nccl", ) @@ -120,7 +121,7 @@ def run_test( if get_chunk_target is not None: chunk_found = [i["region"] for i in chunks] chunk_target = get_chunk_target()[max_memory] - assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % ( str(chunk_found), str(chunk_target), ) diff --git a/tests/test_booster/test_accelerator.py b/tests/test_booster/test_accelerator.py index 895c494d0c17f9ad4822154376dcaf045a310ac8..777589299d137dcace20d25211ded69c048c3ce3 100644 --- a/tests/test_booster/test_accelerator.py +++ b/tests/test_booster/test_accelerator.py @@ -5,10 +5,10 @@ from colossalai.testing import clear_cache_before_run, parameterize @clear_cache_before_run() -@parameterize('device', ['cpu', 'cuda']) +@parameterize("device", ["cpu", "cuda"]) def test_accelerator(device): - acceleartor = Accelerator(device) + accelerator = Accelerator(device) model = nn.Linear(8, 8) - model = acceleartor.configure_model(model) + model = accelerator.configure_model(model) assert next(model.parameters()).device.type == device - del model, acceleartor + del model, accelerator diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py index 963387da262bffea1ee555f8614ed2f60157b3ea..3aefb37974f05a86755a40a16b40050db239619c 100644 --- a/tests/test_booster/test_mixed_precision/test_fp16_torch.py +++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py @@ -9,11 +9,11 @@ from tests.kit.model_zoo import model_zoo def run_torch_amp(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - sub_model_zoo = model_zoo.get_sub_registry('timm') - for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items(): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + sub_model_zoo = model_zoo.get_sub_registry("timm") + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items(): # dlrm_interactionarch has not parameters, so skip - if name == 'dlrm_interactionarch': + if name == "dlrm_interactionarch": continue model = model_fn().cuda() @@ -21,7 +21,7 @@ def run_torch_amp(rank, world_size, port): criterion = lambda x: x.mean() data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } mixed_precision = FP16TorchMixedPrecision() model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion) diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..ad878fb0c86a42a0d88e4d0f3f4eb95e8b311c55 --- /dev/null +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -0,0 +1,100 @@ +from contextlib import nullcontext +from typing import Optional + +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.fx import is_compatible_with_meta +from colossalai.lazy.lazy_init import LazyInitContext +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: + try: + if init_method == "lazy": + ctx = LazyInitContext() + else: + ctx = nullcontext() + plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision="bf16") + booster = Booster(plugin=plugin) + with ctx: + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = { + k: v.to("cuda").repeat(4, 1) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v + for k, v in data.items() + } + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data_iter = iter([data]) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + output_key = list(outputs.keys())[0] + loss = criterion(outputs[output_key]) + return loss + + booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True, return_outputs=False) + optimizer.step() + + except Exception as e: + return repr(e) + + +@parameterize("init_method", ["none", "lazy"]) +def check_3d_plugin(init_method: str = "none", early_stop: bool = True): + """check gemini plugin over model zoo + + Args: + early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. + """ + is_support_meta = is_compatible_with_meta() + if not is_support_meta and init_method == "lazy": + return + + passed_models = [] + failed_info = {} # (model_name, error) pair + + # TODO(ver217): add more models + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry( + "transformers_llama_for_casual_lm" + ).items(): + err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) + torch.cuda.empty_cache() + + if err is None: + passed_models.append(name) + else: + failed_info[name] = err + if early_stop: + break + + if dist.get_rank() == 0: + print(f"Init method: {init_method}") + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) + + +def run_dist(rank, world_size, port, early_stop: bool = True): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + check_3d_plugin(early_stop=early_stop) + + +@rerun_if_address_is_in_use() +def test_gemini_plugin(early_stop: bool = True): + spawn(run_dist, 4, early_stop=early_stop) + + +if __name__ == "__main__": + test_gemini_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py new file mode 100644 index 0000000000000000000000000000000000000000..0ac9d0f6d40904e3bce2e888800ddd959f0f02bd --- /dev/null +++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py @@ -0,0 +1,88 @@ +from typing import Callable, Iterator, List, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader, TensorDataset + +import colossalai +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase +from colossalai.checkpoint_io import CheckpointIO +from colossalai.interface import OptimizerWrapper +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +class DPPluginWrapper(DPPluginBase): + """This is a wrapper class for testing DP plugin initialization and dataloader creation.""" + + def configure( + self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None, + dataloader: DataLoader = None, + lr_scheduler: LRScheduler = None, + ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + pass + + def control_checkpoint_io(self) -> bool: + pass + + def control_device(self) -> bool: + pass + + def control_precision(self) -> bool: + pass + + def get_checkpoint_io(self) -> CheckpointIO: + pass + + def support_no_sync(self) -> bool: + pass + + def supported_devices(self) -> List[str]: + pass + + def supported_precisions(self) -> List[str]: + pass + + def no_sync(self, model: nn.Module) -> Iterator[None]: + pass + + +def check_dataloader_sharding(): + plugin = DPPluginWrapper() + + # create a custom dataset with 0 to 10 + dataset = TensorDataset(torch.arange(0, 10)) + train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2) + + # get the first batch of data + batch = next(iter(train_dataloader))[0].cuda() + is_rank_0 = dist.get_rank() == 0 + + if is_rank_0: + batch_to_compare = batch.clone() + else: + batch_to_compare = batch + # pass to the rank 1 value to rank 0 + dist.broadcast(batch_to_compare, src=1) + + # compare on rank 0 + if is_rank_0: + assert not torch.equal( + batch, batch_to_compare + ), "Same number was found across ranks but expected it to be different" + + +def run_dist(rank, world_size, port): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + check_dataloader_sharding() + + +@rerun_if_address_is_in_use() +def test_dp_plugin_dataloader(): + spawn(run_dist, 2) diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 985d7989fc9dfa7df2523899b7912e6e776d35dc..00ff6cb37d2a365f35e5dcc0769a2ea029af19cc 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -8,23 +8,20 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin from colossalai.fx import is_compatible_with_meta +from colossalai.lazy.lazy_init import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.model.experimental import LazyInitContext -from colossalai.zero import ColoInitContext from tests.kit.model_zoo import model_zoo def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: try: - if init_method == 'colo': - ctx = ColoInitContext() - elif init_method == 'lazy': + if init_method == "lazy": ctx = LazyInitContext() else: ctx = nullcontext() - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5) + plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5) booster = Booster(plugin=plugin) with ctx: model = model_fn() @@ -33,13 +30,13 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) for n, p in model.named_parameters(): - assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter' + assert isinstance(p, ColoParameter), f"{n} is not a ColoParameter" output = model(**data) output = output_transform_fn(output) @@ -50,6 +47,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ optimizer.step() except Exception as e: + # raise e return repr(e) @@ -57,52 +55,69 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ # @parameterize('init_method', ['lazy', 'none', 'colo']) -@parameterize('init_method', ['none']) -def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): +@parameterize("subset", ["torchvision", "transformers", "diffusers"]) +@parameterize("init_method", ["none"]) +def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True): """check gemini plugin over model zoo Args: early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. """ is_support_meta = is_compatible_with_meta() - if not is_support_meta and init_method == 'lazy': + if not is_support_meta and init_method == "lazy": return passed_models = [] - failed_info = {} # (model_name, error) pair + failed_info = {} # (model_name, error) pair - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items(): # These models lead to CUDA error - if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp', - 'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'): + if name in ( + "diffusers_auto_encoder_kl", + "diffusers_vq_model", + "diffusers_unet2d_model", + "timm_resmlp", + "timm_gmixer_12_224", + "timm_gmlp_b16_224", + "timm_mixer_b16_224", + "timm_convnext", + "torchvision_convnext_base", + ): continue # These models are not compatible with gemini if name in [ - 'diffusers_clip_vision_model', 'timm_resnet', 'timm_beit', 'timm_beitv2', 'timm_eca_nfnet', - 'timm_efficientformer', 'timm_hrnet_w18_small', 'timm_nf_ecaresnet101', 'timm_nf_regnet_b0', - 'timm_skresnet18', 'timm_wide_resnet50_2', 'timm_convit', 'timm_dm_nfnet', 'timm_swin_transformer', - 'torchaudio_conformer', 'torchaudio_deepspeech', 'torchaudio_wavernn', 'torchaudio_tacotron', - 'deepfm_interactionarch', 'deepfm_simpledeepfmnn', 'dlrm', 'dlrm_interactionarch', - 'torchvision_googlenet', 'torchvision_inception_v3', 'torchvision_mobilenet_v3_small', - 'torchvision_resnet18', 'torchvision_resnext50_32x4d', 'torchvision_wide_resnet50_2', - 'torchvision_vit_b_16', 'torchvision_convnext_base', 'torchvision_swin_s', 'transformers_albert', - 'transformers_albert_for_pretraining', 'transformers_bert', 'transformers_bert_for_pretraining', - 'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base', - 'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model' + "timm_convit", + "timm_dm_nfnet", + "torchvision_vit_b_16", + "transformers_t5", + "transformers_t5_for_conditional_generation", + "transformers_t5_encoder_model", # does not support apex rmsnorm + "transformers_chatglm", + "transformers_sam", + "transformers_vit", + "transformers_gpt_double_heads", # TODO check why does the model fail to run using Gemini ]: continue - if init_method == 'lazy' and name in [ - 'timm_convmixer', 'timm_vision_transformer', 'timm_deit', 'timm_deit3', 'timm_inception_v3', - 'timm_tnt_b_patch16_224', 'timm_rexnet', 'torchvision_densenet121', 'torchvision_efficientnet_b0', - 'torchvision_mobilenet_v2', 'torchvision_mnasnet0_5', 'torchvision_regnet_x_16gf', - 'torchvision_shufflenet_v2_x0_5', 'torchvision_efficientnet_v2_s' + if init_method == "lazy" and name in [ + "timm_convmixer", + "timm_vision_transformer", + "timm_deit", + "timm_deit3", + "timm_inception_v3", + "timm_tnt_b_patch16_224", + "timm_rexnet", + "torchvision_densenet121", + "torchvision_efficientnet_b0", + "torchvision_mobilenet_v2", + "torchvision_mnasnet0_5", + "torchvision_regnet_x_16gf", + "torchvision_shufflenet_v2_x0_5", + "torchvision_efficientnet_v2_s", ]: continue - err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() - if err is None: passed_models.append(name) else: @@ -111,40 +126,15 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): break if dist.get_rank() == 0: - print(f'Init method: {init_method}') - print(f'Passed models({len(passed_models)}): {passed_models}\n\n') - print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') - assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) - - -def check_dataloader_sharding(): - plugin = GeminiPlugin() - - # create a custom dasetset with 0 to 10 - dataset = torch.utils.data.TensorDataset(torch.arange(0, 10)) - train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2) - - # get the first batch of data - batch = next(iter(train_dataloader))[0].cuda() - is_rank_0 = dist.get_rank() == 0 - - if is_rank_0: - batch_to_compare = batch.clone() - else: - batch_to_compare = batch - # pass to the rank 1 value to rank 0 - dist.broadcast(batch_to_compare, src=1) - - # compare on rank 0 - if is_rank_0: - assert not torch.equal(batch, - batch_to_compare), 'Same number was found across ranks but expected it to be different' + print(f"Init method: {init_method}") + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - check_dataloader_sharding() + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_gemini_plugin(early_stop=early_stop) @@ -153,5 +143,5 @@ def test_gemini_plugin(early_stop: bool = True): spawn(run_dist, 4, early_stop=early_stop) -if __name__ == '__main__': +if __name__ == "__main__": test_gemini_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index e24196a149172c16300845f156d429338fa7a1b8..9cc12f96bd4d9deafd170aa53d9a27865d5ff12a 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -11,14 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo # These models are not compatible with AMP -_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn`'] +_AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"] # These models have no parameters -_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch'] -# These models will get stuck -_STUCK_MODELS = [ - 'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert', - 'transformers_bert_for_pretraining', 'transformers_gpt_double_heads' -] +_LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"] def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: @@ -31,7 +26,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) @@ -48,7 +43,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: return repr(e) -@parameterize('stage', [2]) +@parameterize("stage", [2]) def check_low_level_zero_plugin(stage: int, early_stop: bool = True): """check low level zero plugin over model zoo @@ -57,16 +52,17 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. """ passed_models = [] - failed_info = {} # (model_name, error) pair - ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS + failed_info = {} # (model_name, error) pair + ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS skipped_models = [] - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): # FIXME(ver217): fix these models if name in ignore_models: skipped_models.append(name) continue err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) + torch.cuda.empty_cache() if err is None: @@ -77,46 +73,22 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): break if dist.get_rank() == 0: - print(f'Passed models({len(passed_models)}): {passed_models}\n\n') - print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') - print(f'Skipped models({len(skipped_models)}): {skipped_models}\n\n') - assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) - - -def check_dataloader_sharding(): - plugin = LowLevelZeroPlugin() - - # create a custom dasetset with 0 to 10 - dataset = torch.utils.data.TensorDataset(torch.arange(0, 10)) - train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2) - - # get the first batch of data - batch = next(iter(train_dataloader))[0].cuda() - is_rank_0 = dist.get_rank() == 0 - - if is_rank_0: - batch_to_compare = batch.clone() - else: - batch_to_compare = batch - # pass to the rank 1 value to rank 0 - dist.broadcast(batch_to_compare, src=1) - - # compare on rank 0 - if is_rank_0: - assert not torch.equal(batch, - batch_to_compare), 'Same number was found across ranks but expected it to be different' + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + print(f"Skipped models({len(skipped_models)}): {skipped_models}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_plugin(early_stop=early_stop) @rerun_if_address_is_in_use() def test_low_level_zero_plugin(early_stop: bool = True): - spawn(run_dist, 2, early_stop=early_stop) + spawn(run_dist, 4, early_stop=early_stop) -if __name__ == '__main__': +if __name__ == "__main__": test_low_level_zero_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index 5354eae01d40fbf4a0cd740c6568fb296b08f6ef..1a7ca6f2a30cdeb866c92c2cac4dbeadf4480cd5 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -1,5 +1,8 @@ +from contextlib import nullcontext + import torch import torch.distributed as dist +import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import SGD @@ -19,7 +22,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): criterion = lambda x: x.mean() data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) @@ -37,65 +40,72 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_ddp_plugin(): - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): - if name == 'dlrm_interactionarch': + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): + if name == "dlrm_interactionarch": continue run_fn(model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() -def check_dataloader_sharding(): - plugin = TorchDDPPlugin() - - # create a custom dasetset with 0 to 10 - dataset = torch.utils.data.TensorDataset(torch.arange(0, 10)) - train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2) - - # get the first batch of data - batch = next(iter(train_dataloader))[0].cuda() - is_rank_0 = dist.get_rank() == 0 - - if is_rank_0: - batch_to_compare = batch.clone() - else: - batch_to_compare = batch - # pass to the rank 1 value to rank 0 - dist.broadcast(batch_to_compare, src=1) +class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.rand(1)) - # compare on rank 0 - if is_rank_0: - assert not torch.equal(batch, - batch_to_compare), 'Same number was found across ranks but expected it to be different' + def forward(self, x): + return self.weight * x -def check_checkpoint_save_and_load(): - model_fn, data_gen_fn, output_transform_fn, _ = model_zoo['timm_resnet'] - +def check_torch_ddp_no_sync(): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) - model = model_fn() - optimizer = SGD(model.parameters(), lr=1e-3) + model = DummyModel() criterion = lambda x: x.mean() - data = data_gen_fn() - - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} - - model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - - output = model(**data) - output = output_transform_fn(output) - output_key = list(output.keys())[0] - loss = criterion(output[output_key]) - - booster.backward(loss, optimizer) + optimizer = SGD(model.parameters(), lr=1e-3) + # create a custom dataset with 0 to 10 + dataset = torch.arange(0, 10) + train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2) + model, optimizer, criterion, train_dataloader, _ = booster.boost( + model, optimizer, criterion, dataloader=train_dataloader + ) + + def fwd_bwd(): + output = model(batch.cuda()) + loss = criterion(output) + booster.backward(loss, optimizer) + + def get_grad_set_over_all_ranks(): + for p in model.parameters(): + # grad shape is (1, ) + assert p.grad.shape == (1,) + grad_list = [torch.empty_like(p.grad) for _ in range(dist.get_world_size())] + dist.all_gather(grad_list, p.grad) + # get grad set of all ranks + grad_set = set([grad.item() for grad in grad_list]) + # as the model only has one parameter, we can return here + return grad_set + + for i, batch in enumerate(train_dataloader): + if i > 1: + # only check the first two batches + break + # no_sync for the first batch, sync for the second batch + ctx = booster.no_sync(model) if i == 0 else nullcontext() + with ctx: + fwd_bwd() + grad_set = get_grad_set_over_all_ranks() + # for the first batch, all ranks should have different grads + # for the second batch, as grad is synchronized,all ranks should have the same grads + target_num_different_grad = dist.get_world_size() if i == 0 else 1 + assert len(grad_set) == target_num_different_grad def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - check_dataloader_sharding() + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_torch_ddp_plugin() + check_torch_ddp_no_sync() @rerun_if_address_is_in_use() diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..8bcbffdd06fe1658c2e6a485e4111cdfb26186c5 --- /dev/null +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -0,0 +1,70 @@ +import pytest +import torch +from packaging import version +from torch.optim import SGD + +import colossalai +from colossalai.booster import Booster + +if version.parse(torch.__version__) >= version.parse("1.12.0"): + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from colossalai.booster.plugin import TorchFSDPPlugin + +from colossalai.interface import OptimizerWrapper +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +# test basic fsdp function +def run_fn(model_fn, data_gen_fn, output_transform_fn): + plugin = TorchFSDPPlugin() + booster = Booster(plugin=plugin) + model = model_fn() + optimizer = SGD(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + assert isinstance(model.module, FSDP) + assert isinstance(optimizer, OptimizerWrapper) + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.clip_grad_by_norm(1.0) + optimizer.step() + + +def check_torch_fsdp_plugin(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): + if any( + element in name + for element in [ + "diffusers", + "deepfm_sparsearch", + "dlrm_interactionarch", + "torchvision_googlenet", + "torchvision_inception_v3", + ] + ): + continue + run_fn(model_fn, data_gen_fn, output_transform_fn) + torch.cuda.empty_cache() + + +def run_dist(rank, world_size, port): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + check_torch_fsdp_plugin() + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="requires torch1.12 or higher") +@rerun_if_address_is_in_use() +def test_torch_fsdp_plugin(): + spawn(run_dist, 2) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py new file mode 100644 index 0000000000000000000000000000000000000000..634e81bb225db1eab7792192ff042dcc9d6664ac --- /dev/null +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -0,0 +1,154 @@ +import os + +import pytest +import torch +import torch.distributed as dist +from transformers import LlamaForCausalLM +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin +from colossalai.lazy import LazyInitContext +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + +MODEL_PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half +] + +OPTIM_PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0}, # zero2-offload + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5}, # zero2-offload-half +] + + +@clear_cache_before_run() +@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS) +@parameterize("model_name", ["transformers_bert_for_sequence_classification"]) +@parameterize("use_safetensors", [False, True]) +def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool): + from transformers import BertForSequenceClassification + + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + bert_model = model_fn() + + with shared_tempdir() as tempdir: + pretrained_path = os.path.join(tempdir, "pretrained") + bert_model.config.save_pretrained(save_directory=pretrained_path) + + plugin = GeminiPlugin(**placement_config) + booster = Booster(plugin=plugin) + bert_model, _, _, _, _ = booster.boost(bert_model) + model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 + + booster.save_model( + bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors + ) + dist.barrier() + + new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) + check_state_dict_equal( + bert_model.state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False + ) + + +@clear_cache_before_run() +@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS) +@parameterize("shard", [False, True]) +@parameterize("model_name", ["transformers_gpt"]) +@parameterize("size_per_shard", [32]) +def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int): + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = lambda x: x.mean() + plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14)) + booster = Booster(plugin=plugin) + + model = model_fn() + new_model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + + booster.load_model(new_model, model_ckpt_path) + check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False) + + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal( + optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False + ) + + # Check the new model/optimizer can successfully run. + data = data_gen_fn() + data = { + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() + } + output = new_model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + booster.backward(loss, new_optimizer) + new_optimizer.step() + booster.save_model(new_model, model_ckpt_path, shard=shard) + booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard) + + +def exam_lazy_from_pretrained(): + llama_path = os.environ["LLAMA_PATH"] + plugin = GeminiPlugin() + booster = Booster(plugin=plugin) + orig_model = LlamaForCausalLM.from_pretrained(llama_path) + orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()} + with LazyInitContext(): + model = LlamaForCausalLM.from_pretrained(llama_path) + model, *_ = booster.boost(model) + with shared_tempdir() as tempdir: + save_path = os.path.join(tempdir, "model.pt") + booster.save_model(model, save_path, shard=False) + dist.barrier() + state_dict = torch.load(save_path, map_location="cpu") + check_state_dict_equal(state_dict, orig_state_dict, False) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + exam_state_dict() + exam_state_dict_with_origin() + exam_lazy_from_pretrained() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_gemini_ckpIO(world_size): + spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py new file mode 100644 index 0000000000000000000000000000000000000000..d46e5380d94427ed6f860e66fb9f4729be3f5a24 --- /dev/null +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -0,0 +1,174 @@ +import pytest +import torch +import torch.distributed as dist +from torch.optim import Adam +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, TorchDDPPlugin +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + + +@clear_cache_before_run() +@parameterize("shard", [False, True]) +@parameterize("model_name", ["transformers_gpt"]) +def exam_torch_load_from_gemini(shard: bool, model_name: str): + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = lambda x: x.mean() + plugin = GeminiPlugin(precision="fp16", initial_scale=(2**14)) + booster = Booster(plugin=plugin) + + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + + booster.save_model(model, model_ckpt_path, shard=shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) + dist.barrier() + + new_model = model_fn() + new_optimizer = Adam(new_model.parameters(), lr=0.001) + new_plugin = TorchDDPPlugin() + new_booster = Booster(plugin=new_plugin) + new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion) + + # Loading HybridAdam states to torch.Adam + new_booster.load_model(new_model, model_ckpt_path, strict=True) + + # Add prefix to get aligned with pytorch parameter names. + check_state_dict_equal( + model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32), + new_model.state_dict(), + False, + ) + + new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False) + + # Check the new model/optimizer can successfully run. + data = data_gen_fn() + data = { + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() + } + output = new_model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + new_booster.backward(loss, new_optimizer) + new_optimizer.step() + new_booster.save_model(new_model, model_ckpt_path, shard=shard) + new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard) + + +@clear_cache_before_run() +@parameterize("shard", [False, True]) +@parameterize("model_name", ["transformers_gpt"]) +def exam_gemini_load_from_torch(shard: bool, model_name: str): + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = lambda x: x.mean() + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + + model = model_fn() + optimizer = Adam(model.parameters(), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + + booster.save_model(model, model_ckpt_path, shard=shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) + dist.barrier() + + new_model = model_fn() + new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) + new_plugin = GeminiPlugin() + new_booster = Booster(plugin=new_plugin) + new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion) + + # Loading torch.Adam states to HybridAdam + new_booster.load_model(new_model, model_ckpt_path, strict=True) + + # Add prefix to get aligned with pytorch parameter names. + check_state_dict_equal( + new_model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32), + model.state_dict(), + False, + ) + + new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + old_state_dict = optimizer.state_dict() + new_state_dict = new_optimizer.state_dict(only_rank_0=False) + + # Comparison of param_groups needs special care here, + # since not all hyperparameters in Adam are used by HybridAdam + hyperparameters_to_examine = ["params", "lr", "betas", "eps", "weight_decay"] + for old_group, new_group in zip(old_state_dict["param_groups"], new_state_dict["param_groups"]): + for k in hyperparameters_to_examine: + assert ( + k in old_group and k in new_group + ), f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}" + assert old_group[k] == new_group[k] + check_state_dict_equal(old_state_dict["state"], new_state_dict["state"], False) + + # Check the new model/optimizer can successfully run. + data = data_gen_fn() + data = { + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() + } + output = new_model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + new_booster.backward(loss, new_optimizer) + new_optimizer.step() + new_booster.save_model(new_model, model_ckpt_path, shard=shard) + new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + exam_torch_load_from_gemini() + exam_gemini_load_from_torch() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_gemini_ckpIO(world_size): + spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index ca5ce10054f7c85d86a9ab82a8141c0edfb7389c..2a046a298dd7cbc411e3a0920cfea57881d00d0e 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -1,15 +1,12 @@ import tempfile + import pytest import torch -import logging from torch.optim import Adam from torchvision.models import resnet18 -from pathlib import Path -import os -import subprocess from colossalai.checkpoint_io import GeneralCheckpointIO -from colossalai.testing import clear_cache_before_run, parameterize +from colossalai.testing import check_state_dict_equal, clear_cache_before_run, parameterize # ======== # Note: @@ -20,7 +17,7 @@ from colossalai.testing import clear_cache_before_run, parameterize @clear_cache_before_run() -@parameterize('use_safetensors', [True, False]) +@parameterize("use_safetensors", [True, False]) def test_unsharded_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() @@ -56,13 +53,13 @@ def test_unsharded_checkpoint(use_safetensors: bool): ckpt_io.load_model(new_model, model_ckpt_tempfile.name) ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) - # check for model and optimizer state dict recursively - recursive_check(model.state_dict(), new_model.state_dict()) - recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) + check_state_dict_equal(model.state_dict(), new_model.state_dict()) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) + -@pytest.mark.parametrize('use_safetensors', [True, False]) -def test_sharded_checkpoint(use_safetensors: bool): +@pytest.mark.parametrize("use_safetensors", [True, False]) +def test_sharded_model_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() optimizer = Adam(model.parameters(), lr=0.001) @@ -77,13 +74,10 @@ def test_sharded_checkpoint(use_safetensors: bool): # create a temp file for checkpoint if use_safetensors: - suffix = ".safetensors" - SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" + pass else: - suffix = ".bin" - WEIGHTS_INDEX_NAME = "model.bin.index.json" - - # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix) + pass + model_ckpt_dir = tempfile.TemporaryDirectory() optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() @@ -92,7 +86,7 @@ def test_sharded_checkpoint(use_safetensors: bool): ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors) ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False) - + # create new model new_model = resnet18() new_optimizer = Adam(new_model.parameters(), lr=0.001) @@ -101,26 +95,103 @@ def test_sharded_checkpoint(use_safetensors: bool): ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) # check for model and optimizer state dict recursively - recursive_check(model.state_dict(), new_model.state_dict()) - recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) - - -# do recursive check for the optimizer state dict -# if the value is a dict, compare its values -# if the value is a list, comapre all elements one-by-one -# if the value is a torch.Tensor, use torch.equal -# otherwise use assertEqual -def recursive_check(d1, d2): - for k, v in d1.items(): - if isinstance(v, dict): - recursive_check(v, d2[k]) - elif isinstance(v, list): - for i in range(len(v)): - if isinstance(v[i], torch.Tensor): - assert torch.equal(v[i], d2[k][i]) - else: - assert v[i] == d2[k][i] - elif isinstance(v, torch.Tensor): - assert torch.equal(v, d2[k]) - else: - assert v == d2[k] + check_state_dict_equal(model.state_dict(), new_model.state_dict()) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) + + +def test_sharded_optimizer_checkpoint(): + # create a model and optimizer + model = resnet18() + optimizer = Adam(model.parameters(), lr=0.001) + + # create test data sample + x = torch.randn(1, 3, 224, 224) + + # run fwd and bwd + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # create temp directories for checkpoint + model_ckpt_dir = tempfile.TemporaryDirectory() + optimizer_ckpt_dir = tempfile.TemporaryDirectory() + + # save the model and optimizer + ckpt_io = GeneralCheckpointIO() + + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + + # create new model + new_model = resnet18() + new_optimizer = Adam(new_model.parameters(), lr=0.001) + + ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name)) + + # check for model and optimizer state dict recursively + check_state_dict_equal(model.state_dict(), new_model.state_dict()) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) + + # continue running fwd and bwd + for _ in range(5): + y = new_model(x) + loss = y.sum() + loss.backward() + new_optimizer.step() + + # save the newly got optimizer + ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) + ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + + # create another new model + new_new_model = resnet18() + new_new_optimizer = Adam(new_new_model.parameters(), lr=0.001) + + ckpt_io.load_model(new_new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_new_optimizer, str(optimizer_ckpt_dir.name)) + + # check for model and optimizer state dict recursively + check_state_dict_equal(new_model.state_dict(), new_new_model.state_dict()) + check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict()) + + +def test_sharded_optimizer_multiple_param_groups(): + # create a model and optimizer + model = resnet18() + optimizer = Adam( + [{"params": model.layer1.parameters()}, {"params": model.layer2.parameters(), "lr": 0.002}], lr=0.001 + ) + + # create test data sample + x = torch.randn(1, 3, 224, 224) + + # run fwd and bwd + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # create temp directories for checkpoint + model_ckpt_dir = tempfile.TemporaryDirectory() + optimizer_ckpt_dir = tempfile.TemporaryDirectory() + + # save the model and optimizer + ckpt_io = GeneralCheckpointIO() + + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + + # create new model + new_model = resnet18() + new_optimizer = Adam( + [{"params": new_model.layer1.parameters()}, {"params": new_model.layer2.parameters(), "lr": 0.002}], lr=0.001 + ) + + ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name)) + + # check for model and optimizer state dict recursively + check_state_dict_equal(model.state_dict(), new_model.state_dict()) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py new file mode 100644 index 0000000000000000000000000000000000000000..c0bc2d2f5d0a916e49b03c22915c8097b3cddb85 --- /dev/null +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -0,0 +1,147 @@ +import pytest +import torch +import torch.distributed as dist +from packaging.version import Version +from torch.optim import Adam +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import ( + assert_close_loose, + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + +if Version(torch.__version__) < Version("2.0.0"): + TEST_CONFIGS = [ + { + "tp_size": 4, + "pp_size": 1, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 2, "num_microbatches": 4, "precision": "fp16", "initial_scale": 1}, + {"tp_size": 2, "pp_size": 1, "zero_stage": 2, "precision": "fp16", "initial_scale": 1}, + {"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1}, + ] +else: + TEST_CONFIGS = [ + # TODO(ver217): other configs lead to hang + {"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1}, + ] + + +@clear_cache_before_run() +@parameterize("shard", [True, False]) +@parameterize("model_name", ["transformers_gpt"]) +@parameterize("size_per_shard", [32]) +@parameterize("test_config", TEST_CONFIGS) +def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): + (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) + criterion = loss_fn + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + def _preprocess_data(data): + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to("cuda").repeat(*new_shape) + return iter([data]) + else: + return {k: v.cuda() for k, v in data.items()} + + model = model_fn().cuda() + optimizer = Adam(model.parameters(), lr=1e-3) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + model.train() + if booster.plugin.stage_manager is not None: + booster.execute_pipeline( + _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False + ) + else: + output = model(**_preprocess_data(data)) + loss = criterion(output) + optimizer.backward(loss) + + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + + new_model = model_fn().cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + + booster.load_model(new_model, model_ckpt_path) + check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False) + dist.barrier() + + # Check whether the loaded model & optimizer works smoothly. + model.train() + new_model.train() + if booster.plugin.stage_manager is not None: + booster.execute_pipeline( + _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False + ) + booster.execute_pipeline( + _preprocess_data(data), new_model, _criterion, new_optimizer, return_loss=True, return_outputs=False + ) + else: + old_model_loss = criterion(model(**_preprocess_data(data))) + optimizer.backward(old_model_loss) + new_model_loss = criterion(new_model(**_preprocess_data(data))) + new_optimizer.backward(new_model_loss) + + optimizer.step() + new_optimizer.step() + + # Check updated weights. + stage_manager = booster.plugin.stage_manager + + if stage_manager is None or stage_manager.is_first_stage(): + assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3) + assert_close_loose( + model.unwrap().h[0].mlp.c_fc.weight.data, new_model.unwrap().h[0].mlp.c_fc.weight.data, atol=5e-3, rtol=5e-3 + ) + + dist.barrier() + Randomizer.reset_index() + clear_layout_converter() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + exam_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_hybrid_ckpIO(world_size): + spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py new file mode 100644 index 0000000000000000000000000000000000000000..8a4724c8a82c45137510cc791e17be1e5ec02eae --- /dev/null +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -0,0 +1,84 @@ +import torch +import torch.distributed as dist +from torchvision.models import resnet18 +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from colossalai.zero import LowLevelZeroOptimizer + + +# stage 1 and 2 process the optimizer/mode the same way +# only test 2 is fine +@clear_cache_before_run() +@parameterize("stage", [2]) +@parameterize("shard", [True, False]) +@parameterize("offload", [False, True]) +def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): + plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload) + booster = Booster(plugin=plugin) + model = resnet18() + criterion = lambda x: x.mean() + optimizer = HybridAdam((model.parameters()), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + x = torch.randn(1, 3, 224, 224, device="cuda") + output = model(x) + loss = criterion(output) + booster.backward(loss, optimizer) + optimizer.step() + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here + booster.save_model(model, model_ckpt_path, shard=shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) + + dist.barrier() + + new_model = resnet18() + new_optimizer = HybridAdam((new_model.parameters()), lr=0.001) + new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer) + + booster.load_model(new_model, model_ckpt_path) + check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) + # check master weight + assert isinstance(new_optimizer, LowLevelZeroOptimizer) + working_param_id_set = set(id(p) for p in new_model.parameters()) + for p_id, master_param in new_optimizer._param_store.working_to_master_param.items(): + assert p_id in working_param_id_set + working_param = new_optimizer._param_store.master_to_working_param[id(master_param)] + padding = new_optimizer._param_store.get_param_padding_size(working_param) + padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding)) + working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()] + assert torch.equal( + working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device) + ) + + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") + check_low_level_zero_checkpointIO() + torch.cuda.empty_cache() + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_low_level_zero_checkpointIO(): + spawn(run_dist, 2) + + +if __name__ == "__main__": + test_low_level_zero_checkpointIO() diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f67e0d7729706c5cf719e0530a0a287fcafd47 --- /dev/null +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -0,0 +1,80 @@ +import pytest +import torch +import torch.distributed as dist +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + + +@clear_cache_before_run() +@parameterize("model_name", ["transformers_gpt"]) +@parameterize("plugin_type", ["ddp", "zero", "gemini"]) +def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32): + (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) + criterion = loss_fn + + if plugin_type == "ddp": + plugin = TorchDDPPlugin() + elif plugin_type == "zero": + plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32) + elif plugin_type == "gemini": + plugin = GeminiPlugin(precision="fp16", initial_scale=32) + else: + raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.") + + booster = Booster(plugin=plugin) + + model = model_fn().cuda() + model_huggingface_cls = model.__class__ + optimizer = HybridAdam(model.parameters(), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} + output = model(**data) + loss = criterion(output) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + + new_model = model_huggingface_cls.from_pretrained(model_ckpt_path) + new_model = new_model.cuda() + new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + + if plugin_type == "gemini": + check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False) + else: + check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + dist.barrier() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + exam_from_pretrained() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_huggingface_compatibility(world_size): + spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb04df0f42d896cb8c1b239e67d32f3d142b2b6 --- /dev/null +++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py @@ -0,0 +1,70 @@ +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import SGD +from torchvision.models import resnet18 +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import TorchDDPPlugin +from colossalai.interface import OptimizerWrapper +from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("shard", [True, False]) +@parameterize("size_per_shard", [16, 128]) +def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + model = resnet18() + criterion = lambda x: x.mean() + optimizer = SGD((model.parameters()), lr=0.001) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler) + + assert isinstance(model.module, DDP) + assert isinstance(optimizer, OptimizerWrapper) + + x = torch.randn(4, 3, 224, 224) + x = x.to("cuda") + output = model(x) + loss = criterion(output) + booster.backward(loss, optimizer) + optimizer.clip_grad_by_norm(1.0) + optimizer.step() + scheduler.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path) + dist.barrier() + + new_model = resnet18() + new_optimizer = SGD((new_model.parameters()), lr=0.001) + new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1) + new_model, new_optimizer, _, _, new_scheduler = booster.boost( + new_model, new_optimizer, lr_scheduler=new_scheduler + ) + + booster.load_model(new_model, model_ckpt_path) + check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) + + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path) + check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") + check_torch_ddp_checkpointIO() + + +@rerun_if_address_is_in_use() +def test_torch_ddp_checkpointIO(): + spawn(run_dist, 2) diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py new file mode 100644 index 0000000000000000000000000000000000000000..dd41f8185c2b84d15b844a20d115287426bf919c --- /dev/null +++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py @@ -0,0 +1,112 @@ +import pytest +import torch +from packaging import version +from torch.optim import SGD +from torchvision.models import resnet18 +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster + +if version.parse(torch.__version__) >= version.parse("1.12.0"): + from colossalai.booster.plugin import TorchFSDPPlugin + +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def compare_nested_dict(dict1, dict2): + for key in dict1: + if key in dict2: + if type(dict1[key]) is dict: + assert type(dict2[key]) is dict + diff = compare_nested_dict(dict1[key], dict2[key]) + if not diff: + return diff + elif type(dict1[key]) is list: + assert type(dict2[key]) is list + for i, val in enumerate(dict1[key]): + if isinstance(val, torch.Tensor): + if not torch.equal(dict1[key][i], dict2[key][i]): + return False + elif val != dict2[key][i]: + return False + elif type(dict1[key]) is torch.Tensor: + assert type(dict2[key]) is torch.Tensor + if not torch.equal(dict1[key], dict2[key]): + return False + else: + if dict1[key] != dict2[key]: + return False + else: + return False + return True + + +def check_torch_fsdp_ckpt(): + model = resnet18() + plugin = TorchFSDPPlugin() + booster = Booster(plugin=plugin) + optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9) + criterion = lambda x: x.mean() + fsdp_model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + inputs = torch.randn(4, 3, 224, 224) + outputs = None + + def run_model(): + nonlocal outputs + outputs = fsdp_model(inputs) + optimizer.zero_grad() + criterion(outputs).backward() + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optim_ckpt_path = f"{tempdir}/optimizer" + + run_model() + + booster.save_model(fsdp_model, model_ckpt_path, shard=False) + booster.save_optimizer(optimizer, optim_ckpt_path, shard=False) + + full_msd = fsdp_model.state_dict() + # full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer) + sharded_osd = optimizer.state_dict() + import copy + + sharded_osd = copy.deepcopy(sharded_osd) + + run_model() + + full_msd_updated = fsdp_model.state_dict() + # full_osd_updated = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True) + sharded_osd_updated = optimizer.state_dict() + + assert not compare_nested_dict(sharded_osd, sharded_osd_updated) + assert not compare_nested_dict(full_msd_updated, full_msd) + outputs_first = fsdp_model(inputs) + assert criterion(outputs_first) != criterion(outputs) + + booster.load_model(fsdp_model, model_ckpt_path) + booster.load_optimizer(optimizer, optim_ckpt_path) + + full_msd_restore = fsdp_model.state_dict() + # full_osd_restore = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True) + sharded_osd_restore = optimizer.state_dict() + + assert compare_nested_dict(sharded_osd, sharded_osd_restore) + assert compare_nested_dict(full_msd_restore, full_msd) + outputs_sec = fsdp_model(inputs) + assert criterion(outputs_sec) == criterion(outputs) + + +def run_dist(rank, world_size, port): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + check_torch_fsdp_ckpt() + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="requires torch1.12 or higher") +@rerun_if_address_is_in_use() +def test_torch_fsdp_ckpt(): + spawn(run_dist, 2) diff --git a/tests/test_checkpoint_io/utils.py b/tests/test_checkpoint_io/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d14fc944267c0e90fa7bae353a54fe9e4140d79c --- /dev/null +++ b/tests/test_checkpoint_io/utils.py @@ -0,0 +1,21 @@ +import tempfile +from contextlib import contextmanager, nullcontext +from typing import Iterator + +import torch.distributed as dist + + +@contextmanager +def shared_tempdir() -> Iterator[str]: + """ + A temporary directory that is shared across all processes. + """ + ctx_fn = tempfile.TemporaryDirectory if dist.get_rank() == 0 else nullcontext + with ctx_fn() as tempdir: + try: + obj = [tempdir] + dist.broadcast_object_list(obj, src=0) + tempdir = obj[0] # use the same directory on all ranks + yield tempdir + finally: + dist.barrier() diff --git a/tests/test_cluster/test_device_mesh_manager.py b/tests/test_cluster/test_device_mesh_manager.py index b42ef1fe0062dc2dec944c321a63bf9812d86a6c..ab61cdae5bb05f9aba4dd7aad12b67e9c5de40ca 100644 --- a/tests/test_cluster/test_device_mesh_manager.py +++ b/tests/test_cluster/test_device_mesh_manager.py @@ -1,5 +1,3 @@ -import torch - from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -8,18 +6,19 @@ from colossalai.testing import spawn def check_device_mesh_manager(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") device_mesh_manager = DeviceMeshManager() - device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],) - device_mesh_auto = device_mesh_manager.create_device_mesh('0', device_mesh_info_auto) - assert device_mesh_auto.shape == (2, 2) - assert device_mesh_auto._logical_mesh_id.tolist() == [[0, 1], [2, 3]] + # TODO(ver217): this test is strictly relies on hardware, temporary skip it + # device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],) + # device_mesh_auto = device_mesh_manager.create_device_mesh('0', device_mesh_info_auto) + # assert device_mesh_auto.shape == (2, 2) + # assert device_mesh_auto._logical_mesh_id.tolist() == [[0, 1], [2, 3]] device_mesh_info_with_shape = DeviceMeshInfo( physical_ids=[0, 1, 2, 3], mesh_shape=(2, 2), ) - device_mesh_with_shape = device_mesh_manager.create_device_mesh('1', device_mesh_info_with_shape) + device_mesh_with_shape = device_mesh_manager.create_device_mesh("1", device_mesh_info_with_shape) assert device_mesh_with_shape.shape == (2, 2) assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]] @@ -29,5 +28,5 @@ def test_device_mesh_manager(): spawn(check_device_mesh_manager, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_device_mesh_manager() diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..08542d1f64fac359ff752e5d6e4a0a7bb5decb01 --- /dev/null +++ b/tests/test_cluster/test_process_group_mesh.py @@ -0,0 +1,167 @@ +import pytest +import torch.distributed as dist + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.testing import spawn + + +def check_process_group_mesh_with_gpc(): + from colossalai.legacy.context import ParallelMode + from colossalai.legacy.core import global_context as gpc + + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + pg_mesh = ProcessGroupMesh(1, 2, 2) + + # check world size + assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size( + TP_DIM + ), f"{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}" + assert gpc.get_world_size(ParallelMode.PIPELINE) == pg_mesh.size(PP_DIM) + assert gpc.get_world_size(ParallelMode.DATA) == pg_mesh.size(DP_DIM) + + # check locak rank (coordinate) + assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate( + TP_DIM + ), f"{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}" + assert gpc.get_local_rank(ParallelMode.PIPELINE) == pg_mesh.coordinate(PP_DIM) + assert gpc.get_local_rank(ParallelMode.DATA) == pg_mesh.coordinate(DP_DIM) + + # check ranks in group + tp_group = pg_mesh.get_group_along_axis(TP_DIM) + assert gpc.get_ranks_in_group(ParallelMode.TENSOR) == pg_mesh.get_ranks_in_group(tp_group) + pp_group = pg_mesh.get_group_along_axis(PP_DIM) + assert gpc.get_ranks_in_group(ParallelMode.PIPELINE) == pg_mesh.get_ranks_in_group(pp_group) + dp_group = pg_mesh.get_group_along_axis(DP_DIM) + assert gpc.get_ranks_in_group(ParallelMode.DATA) == pg_mesh.get_ranks_in_group(dp_group) + + # check prev rank + coord = pg_mesh.coordinate() + if not gpc.is_first_rank(ParallelMode.TENSOR): + assert coord[TP_DIM] != 0 + prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1 :] + assert gpc.get_prev_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(prev_coord, pg_mesh.shape) + if not gpc.is_first_rank(ParallelMode.PIPELINE): + assert coord[PP_DIM] != 0 + prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1 :] + assert gpc.get_prev_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(prev_coord, pg_mesh.shape) + + # check next rank + if not gpc.is_last_rank(ParallelMode.TENSOR): + assert coord[TP_DIM] != pg_mesh.size(TP_DIM) - 1 + next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1 :] + assert gpc.get_next_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(next_coord, pg_mesh.shape) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + assert coord[PP_DIM] != pg_mesh.size(PP_DIM) - 1 + next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1 :] + assert gpc.get_next_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(next_coord, pg_mesh.shape) + + +def check_process_group_mesh_with_cases(): + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + DP_SIZE, PP_SIZE, TP_SIZE = 1, 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0, 0), + 1: (0, 0, 1), + 2: (0, 1, 0), + 3: (0, 1, 1), + } + TP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + PP_RANKS_IN_GROUP = { + 0: [0, 2], + 1: [1, 3], + 2: [0, 2], + 3: [1, 3], + } + DP_RANKS_IN_GROUP = { + 0: [0], + 1: [1], + 2: [2], + 3: [3], + } + + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE) + + rank = dist.get_rank() + assert rank == pg_mesh.rank + + # check world size + assert pg_mesh.size(TP_DIM) == 2 + assert pg_mesh.size(PP_DIM) == 2 + assert pg_mesh.size(DP_DIM) == 1 + + # check coordinate + assert pg_mesh.coordinate(TP_DIM) == RANK_TO_COORDINATE[rank][TP_DIM] + assert pg_mesh.coordinate(PP_DIM) == RANK_TO_COORDINATE[rank][PP_DIM] + assert pg_mesh.coordinate(DP_DIM) == RANK_TO_COORDINATE[rank][DP_DIM] + + # check ranks in group + tp_group = pg_mesh.get_group_along_axis(TP_DIM) + assert pg_mesh.get_ranks_in_group(tp_group) == TP_RANKS_IN_GROUP[rank] + pp_group = pg_mesh.get_group_along_axis(PP_DIM) + assert pg_mesh.get_ranks_in_group(pp_group) == PP_RANKS_IN_GROUP[rank] + dp_group = pg_mesh.get_group_along_axis(DP_DIM) + assert pg_mesh.get_ranks_in_group(dp_group) == DP_RANKS_IN_GROUP[rank] + + # check prev rank + if RANK_TO_COORDINATE[rank][TP_DIM] != 0: + prev_coord = ( + RANK_TO_COORDINATE[rank][:TP_DIM] + + (RANK_TO_COORDINATE[rank][TP_DIM] - 1,) + + RANK_TO_COORDINATE[rank][TP_DIM + 1 :] + ) + prev_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) - 1] + assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank + if RANK_TO_COORDINATE[rank][PP_DIM] != 0: + prev_coord = ( + RANK_TO_COORDINATE[rank][:PP_DIM] + + (RANK_TO_COORDINATE[rank][PP_DIM] - 1,) + + RANK_TO_COORDINATE[rank][PP_DIM + 1 :] + ) + prev_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) - 1] + assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank + + # check next rank + if RANK_TO_COORDINATE[rank][TP_DIM] != TP_SIZE - 1: + next_coord = ( + RANK_TO_COORDINATE[rank][:TP_DIM] + + (RANK_TO_COORDINATE[rank][TP_DIM] + 1,) + + RANK_TO_COORDINATE[rank][TP_DIM + 1 :] + ) + next_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) + 1] + assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank + if RANK_TO_COORDINATE[rank][PP_DIM] != PP_SIZE - 1: + next_coord = ( + RANK_TO_COORDINATE[rank][:PP_DIM] + + (RANK_TO_COORDINATE[rank][PP_DIM] + 1,) + + RANK_TO_COORDINATE[rank][PP_DIM + 1 :] + ) + next_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) + 1] + assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank + + +def run_dist(rank, world_size, port): + colossalai.launch( + config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode="1d", size=2))), + rank=rank, + world_size=world_size, + port=port, + host="localhost", + ) + # TODO(ver217): this function should be removed when gpc is removed + # check_process_group_mesh_with_gpc() + check_process_group_mesh_with_cases() + + +@pytest.mark.dist +def test_process_group_mesh(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_process_group_mesh() diff --git a/tests/test_comm/test_boardcast_send_recv_v2.py b/tests/test_comm/test_boardcast_send_recv_v2.py deleted file mode 100644 index 253f6f21cd804c7d4d1e3d3e62b5ce2253ef271f..0000000000000000000000000000000000000000 --- a/tests/test_comm/test_boardcast_send_recv_v2.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -import torch - -from colossalai.communication.p2p_v2 import _recv_object, _send_object -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing import rerun_if_address_is_in_use, spawn - -disable_existing_loggers() -world_size = 4 -CONFIG = dict(parallel=dict(pipeline=world_size)) -torch.manual_seed(123) - - -def check_layer(rank, world_size, port): - disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl', verbose=False) - rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - if rank == 0: - obj = [torch.randn(3,)] - _send_object(obj, 1) - - if rank == 1: - _recv_object(0) - - if rank == 2: - _recv_object(3) - - if rank == 3: - obj = [torch.randn(3,)] - _send_object(obj, 2) - - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_object_list_p2p(): - spawn(check_layer, world_size) - - -if __name__ == '__main__': - test_object_list_p2p() diff --git a/tests/test_comm/test_comm.py b/tests/test_comm/test_comm.py deleted file mode 100644 index 747596bd2dedff9a20222669bb611cf96d020724..0000000000000000000000000000000000000000 --- a/tests/test_comm/test_comm.py +++ /dev/null @@ -1,71 +0,0 @@ -import pytest -import torch -import torch.distributed as dist - -from colossalai.communication import all_gather, all_reduce, reduce_scatter -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device - -CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) - -SIZE = 8 - - -def check_all_gather(): - tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) - print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) - tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True) - print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) - op.wait() - print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) - torch.cuda.synchronize() - - -def check_reduce_scatter(): - tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) - print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) - tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True) - print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) - op.wait() - print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) - torch.cuda.synchronize() - - -def check_all_reduce(): - tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) - print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) - tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True) - print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) - op.wait() - print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) - torch.cuda.synchronize() - - -def check_layer(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - assert dist.get_rank() == gpc.get_global_rank() - print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size())) - - check_all_gather() - check_reduce_scatter() - check_all_reduce() - - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_comm(): - spawn(check_layer, 4) - - -if __name__ == '__main__': - test_comm() diff --git a/tests/test_config/sample_config.py b/tests/test_config/sample_config.py index 08ca108281b9c0700fef4ecb2c14416ccbabfd9f..b9af7ab41a55b38e5dbddecfa6db760447cc5280 100644 --- a/tests/test_config/sample_config.py +++ b/tests/test_config/sample_config.py @@ -3,23 +3,23 @@ train_data = dict( dataset=dict( - type='CIFAR10Dataset', - root='/path/to/data', + type="CIFAR10Dataset", + root="/path/to/data", download=True, transform_pipeline=[ - dict(type='RandomResizedCrop', size=224), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ] + dict(type="RandomResizedCrop", size=224), + dict(type="RandomHorizontalFlip"), + dict(type="ToTensor"), + dict(type="Normalize", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + ], ), dataloader=dict( batch_size=64, pin_memory=True, num_workers=4, sampler=dict( - type='DataParallelSampler', + type="DataParallelSampler", shuffle=True, - ) - ) + ), + ), ) diff --git a/tests/test_config/test_load_config.py b/tests/test_config/test_load_config.py index 550af2a4ae81656b4da0c159ce5a04bfdbb891cc..66e473459445eaed88bd92c33dbd374c212b905a 100644 --- a/tests/test_config/test_load_config.py +++ b/tests/test_config/test_load_config.py @@ -3,17 +3,15 @@ from pathlib import Path -import pytest - from colossalai.context.config import Config -@pytest.mark.cpu def test_load_config(): - filename = Path(__file__).parent.joinpath('sample_config.py') + filename = Path(__file__).parent.joinpath("sample_config.py") config = Config.from_file(filename) - assert config.train_data, 'cannot access train data as attribute' - assert config.train_data.dataset, 'cannot access grandchild attribute' - assert isinstance(config.train_data.dataset.transform_pipeline[0], dict), \ - f'expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}' + assert config.train_data, "cannot access train data as attribute" + assert config.train_data.dataset, "cannot access grandchild attribute" + assert isinstance( + config.train_data.dataset.transform_pipeline[0], dict + ), f"expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}" diff --git a/tests/test_context/configs/parallel_2d_init.py b/tests/test_context/configs/parallel_2d_init.py deleted file mode 100644 index 6af884450ad0fee42d86fd1ad7ee950d576dd7da..0000000000000000000000000000000000000000 --- a/tests/test_context/configs/parallel_2d_init.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -parallel = dict( - pipeline=dict(size=2), - tensor=dict( - size=4, - mode='2d' - ) -) diff --git a/tests/test_context/configs/parallel_2p5d_init.py b/tests/test_context/configs/parallel_2p5d_init.py deleted file mode 100644 index c2d896d383e26d1530bd05d4127dfdafec57d826..0000000000000000000000000000000000000000 --- a/tests/test_context/configs/parallel_2p5d_init.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -parallel = dict( - pipeline=dict(size=2), - tensor=dict( - size=8, - depth=2, - mode='2.5d' - ) -) diff --git a/tests/test_context/configs/parallel_3d_init.py b/tests/test_context/configs/parallel_3d_init.py deleted file mode 100644 index 0ec724f8bb4f2513457568eaeb221727e4da2ff1..0000000000000000000000000000000000000000 --- a/tests/test_context/configs/parallel_3d_init.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -parallel = dict( - pipeline=dict(size=2), - tensor=dict( - size=8, - mode='3d' - ) -) diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_data/test_data_parallel_sampler.py deleted file mode 100644 index 2ad3fd696c39179d9eaec2a331925ea4f5ab3bf1..0000000000000000000000000000000000000000 --- a/tests/test_data/test_data_parallel_sampler.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import os -from pathlib import Path - -import pytest -import torch -import torch.distributed as dist -from torchvision import datasets, transforms - -import colossalai -from colossalai.context import Config, ParallelMode -from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_dataloader - -CONFIG = Config(dict( - parallel=dict( - pipeline=dict(size=1), - tensor=dict(size=1, mode=None), - ), - seed=1024, -)) - - -def run_data_sampler(rank, world_size, port): - dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') - colossalai.launch(**dist_args) - print('finished initialization') - - # build dataset - transform_pipeline = [transforms.ToTensor()] - transform_pipeline = transforms.Compose(transform_pipeline) - dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) - - # build dataloader - dataloader = get_dataloader(dataset, batch_size=8, add_sampler=True) - - data_iter = iter(dataloader) - img, label = data_iter.next() - img = img[0] - - if gpc.get_local_rank(ParallelMode.DATA) != 0: - img_to_compare = img.clone() - else: - img_to_compare = img - dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA)) - - if gpc.get_local_rank(ParallelMode.DATA) != 0: - assert not torch.equal( - img, img_to_compare), 'Same image was distributed across ranks but expected it to be different' - torch.cuda.empty_cache() - - -@pytest.mark.cpu -@rerun_if_address_is_in_use() -def test_data_sampler(): - spawn(run_data_sampler, 4) - - -if __name__ == '__main__': - test_data_sampler() diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py deleted file mode 100644 index 239e79dff7d85f81f71802cbbdedbae6789b869d..0000000000000000000000000000000000000000 --- a/tests/test_data/test_deterministic_dataloader.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import os -from pathlib import Path - -import pytest -import torch -import torch.distributed as dist -from torchvision import datasets, transforms - -import colossalai -from colossalai.context import Config, ParallelMode -from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_dataloader - -CONFIG = Config( - dict( - train_data=dict( - dataset=dict( - type='CIFAR10', - root=Path(os.environ['DATA']), - train=True, - download=True, - ), - dataloader=dict(num_workers=2, batch_size=2, shuffle=True), - ), - parallel=dict( - pipeline=dict(size=1), - tensor=dict(size=1, mode=None), - ), - seed=1024, - )) - - -def run_data_sampler(rank, world_size, port): - dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') - colossalai.launch(**dist_args) - - # build dataset - transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)] - transform_pipeline = transforms.Compose(transform_pipeline) - dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) - - # build dataloader - dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False) - - data_iter = iter(dataloader) - img, label = data_iter.next() - img = img[0] - - if gpc.get_local_rank(ParallelMode.DATA) != 0: - img_to_compare = img.clone() - else: - img_to_compare = img - dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA)) - - if gpc.get_local_rank(ParallelMode.DATA) != 0: - # this is without sampler - # this should be false if data parallel sampler to given to the dataloader - assert torch.equal(img, - img_to_compare), 'Same image was distributed across ranks and expected it to be the same' - torch.cuda.empty_cache() - - -@pytest.mark.cpu -@rerun_if_address_is_in_use() -def test_data_sampler(): - spawn(run_data_sampler, 4) - - -if __name__ == '__main__': - test_data_sampler() diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py deleted file mode 100644 index 4d63592f12b0321ee3f48c0f289871cb213d6a17..0000000000000000000000000000000000000000 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py +++ /dev/null @@ -1,100 +0,0 @@ -import os -from pathlib import Path - -import pytest -import torch -from torchvision import transforms -from torchvision.datasets import CIFAR10 - -import colossalai -from colossalai.amp import AMP_TYPE -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.nn import CrossEntropyLoss -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.pipeline.pipelinable import PipelinableContext -from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn -from colossalai.trainer import Trainer, hooks -from colossalai.utils import get_dataloader - -BATCH_SIZE = 4 -NUM_EPOCHS = 60 -WARMUP_EPOCHS = 5 -CONFIG = dict(NUM_MICRO_BATCHES=2, - parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')), - fp16=dict(mode=AMP_TYPE.NAIVE), - gradient_accumulation=2) - - -def run_trainer(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - logger = get_dist_logger() - - # get logger - logger = get_dist_logger() - - pipelinable = PipelinableContext() - try: - from titans.model.vit import vit_tiny_patch4_32 - except ImportError: - logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') - logger.warning('please install titan from https://github.com/hpcaitech/Titans') - return - with pipelinable: - model = vit_tiny_patch4_32() - pipelinable.to_layer_list() - pipelinable.policy = "uniform" - model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) - - # craete dataloaders - root = Path(os.environ['DATA']) - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) - - # create loss function - criterion = CrossEntropyLoss(label_smoothing=0.1) - - # create optimizer - optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) - - # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) - - # intiailize - engine, train_dataloader, *_ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - - logger = get_dist_logger() - - trainer = Trainer(engine=engine, logger=logger) - - hook_list = [ - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), - ] - - trainer.fit(train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - max_steps=2, - hooks=hook_list, - display_progress=True) - - -@pytest.mark.dist -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_if_address_is_in_use() -def test_hybrid_parallel(): - spawn(run_trainer, 8) - - -if __name__ == '__main__': - test_hybrid_parallel() diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py deleted file mode 100644 index 67d2ba5f5d987606586d5af0d17e058f9fb7c8b5..0000000000000000000000000000000000000000 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py +++ /dev/null @@ -1,104 +0,0 @@ -import os -from pathlib import Path - -import pytest -import torch -from torchvision import transforms -from torchvision.datasets import CIFAR10 - -import colossalai -from colossalai.amp import AMP_TYPE -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2 -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn import CrossEntropyLoss -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.pipeline.pipelinable import PipelinableContext -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.trainer import Trainer, hooks -from colossalai.utils import get_dataloader - -disable_existing_loggers() -BATCH_SIZE = 4 -NUM_EPOCHS = 10 -WARMUP_EPOCHS = 5 -CONFIG = dict(NUM_MICRO_BATCHES=2, - parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')), - fp16=dict(mode=AMP_TYPE.NAIVE), - gradient_accumulation=2) - - -def run_trainer(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - disable_existing_loggers() - # get logger - logger = get_dist_logger() - - pipelinable = PipelinableContext() - try: - from titans.model.vit import vit_tiny_patch4_32 - except ImportError: - logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') - logger.warning('please install titan from https://github.com/hpcaitech/Titans') - return - with pipelinable: - model = vit_tiny_patch4_32() - pipelinable.to_layer_list() - pipelinable.policy = "uniform" - model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) - - # craete dataloaders - root = Path(os.environ['DATA']) - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) - - # create loss function - criterion = CrossEntropyLoss(label_smoothing=0.1) - - # create optimizer - optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) - - # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) - - # intiailize - engine, train_dataloader, *_ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - - engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES) - - logger = get_dist_logger() - - trainer = Trainer(engine=engine, logger=logger) - - hook_list = [ - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), - ] - - trainer.fit(train_dataloader=train_dataloader, - max_steps=2, - epochs=NUM_EPOCHS, - hooks=hook_list, - display_progress=True) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_hybrid_parallel(): - spawn(run_trainer, 2) - disable_existing_loggers() - - -if __name__ == '__main__': - test_hybrid_parallel() diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py deleted file mode 100644 index 39efcd41a1d46e13238fad7e27a5743665bd10f9..0000000000000000000000000000000000000000 --- a/tests/test_ddp/test_ddp_ignore_params.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import random -from typing import Callable, Type - -import numpy as np -import pytest -import torch -import torch.distributed as dist - -import colossalai -from colossalai.nn.parallel import ColoDDP -from colossalai.tensor import ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext, ZeroDDP -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager - - -def set_seed(seed): - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - - -def init_ddp(module: torch.nn.Module) -> ColoDDP: - pg = ProcessGroup() - return ColoDDP(module, process_group=pg) - - -def init_ddpv2(module: torch.nn.Module) -> ZeroDDP: - chunk_config, *_ = search_chunk_configuration(module, 4, 1024) - chunk_manager = ChunkManager(chunk_config) - gemini_manager = GeminiManager('cuda', chunk_manager) - return ZeroDDP(module, gemini_manager) - - -class Net(torch.nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc1 = torch.nn.Linear(3, 3, bias=False) - self.fc2 = torch.nn.Linear(3, 1, bias=False) - - def forward(self, x): - return self.fc2(self.fc1(x)) - - -def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]): - with ColoInitContext(device=get_current_device()): - model = Net().cuda() - w1 = model.fc1.weight - w2 = model.fc2.weight - ddp_cls.set_params_to_ignore([w2]) - model = init_ddp_func(model) - x = torch.rand(2, 3, device=get_current_device()) - logits = model(x) - loss = torch.sum(logits) - model.backward(loss) - - if ddp_cls is ZeroDDP: - w1s_grad = w1 - else: - w1s_grad = w1.grad - - w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())] - dist.all_gather(w1_grads, w1s_grad) - assert torch.equal(w1_grads[0], w1_grads[1]) - w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())] - dist.all_gather(w2_grads, w2.grad) - assert not torch.equal(w2_grads[0], w2_grads[1]) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - set_seed(dist.get_rank()) - run_fwd_bwd(ColoDDP, init_ddp) - run_fwd_bwd(ZeroDDP, init_ddpv2) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@rerun_if_address_is_in_use() -def test_ddp_ignore_params(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_ddp_ignore_params(2) diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py deleted file mode 100644 index 54f89f972765ecab174a8370ef7024b714fae8b5..0000000000000000000000000000000000000000 --- a/tests/test_ddp/test_ddp_state_dict.py +++ /dev/null @@ -1,67 +0,0 @@ -from collections import OrderedDict - -import pytest -import torch - -import colossalai -from colossalai.nn.parallel import ColoDDP -from colossalai.tensor import ColoParameter, ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext -from tests.components_to_test.registry import non_distributed_component_funcs - - -def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict): - for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()): - assert k1 == k2 - - if t1.device != t2.device: - temp_t2 = t2.to(t1.device) - else: - temp_t2 = t2 - - assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2) - - -def init_ddp(module: torch.nn.Module) -> ColoDDP: - pg = ProcessGroup() - return ColoDDP(module, process_group=pg) - - -def run_ddp_state_dict(): - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - torch_model = model_builder().cuda() - with ColoInitContext(device=get_current_device()): - model = model_builder() - model = init_ddp(model) - torch_state_dict = torch_model.state_dict() - - for param in model.parameters(): - if isinstance(param, ColoParameter): - assert param.get_process_group() is not None - model.load_state_dict(torch_state_dict) - - for param in model.parameters(): - if isinstance(param, ColoParameter): - assert param.get_process_group() is not None - - state_dict = model.state_dict() - check_state_dict_equal(torch_state_dict, state_dict) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_ddp_state_dict() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@rerun_if_address_is_in_use() -def test_state_dict(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_state_dict(2) diff --git a/tests/test_ddp/test_reducer.py b/tests/test_ddp/test_reducer.py deleted file mode 100644 index e8d3a112c938d7c883439314cf3e735dec2e8af3..0000000000000000000000000000000000000000 --- a/tests/test_ddp/test_reducer.py +++ /dev/null @@ -1,47 +0,0 @@ -from functools import partial - -import pytest -import torch -import torch.distributed as dist -from torch.distributed.distributed_c10d import _get_default_group - -import colossalai -from colossalai.nn.parallel.reducer import Reducer -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device - -REDUCE_CNT = 0 - - -def check_eq(grad, grad_clone): - global REDUCE_CNT - print(f'Rank{dist.get_rank()} check {REDUCE_CNT}') - REDUCE_CNT += 1 - assert torch.allclose(grad, grad_clone) - - -def run_reducer(): - grads = [torch.rand(64, i + 1, device=get_current_device()) for i in range(10)] - grads_clone = [g.clone().detach() for g in grads] - for g in grads: - dist.all_reduce(g) - reducer = Reducer(bucket_size_mb=1) - for g, g_clone in zip(grads, grads_clone): - reducer.all_reduce_async(g_clone, _get_default_group(), partial(check_eq, g)) - reducer.flush() - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_reducer() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@rerun_if_address_is_in_use() -def test_reducer(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_reducer(2) diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py index ab933ed57d0d2630f5ebd7dad6b86afc1677f6a5..f4a88f79c37b2a5cf5bcf657584e442dff8ace72 100644 --- a/tests/test_device/test_alpha_beta.py +++ b/tests/test_device/test_alpha_beta.py @@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) ab_dict = profiler.profile_ab() for _, (alpha, beta) in ab_dict.items(): @@ -17,11 +17,11 @@ def check_alpha_beta(rank, world_size, port, physical_devices): @pytest.mark.skip(reason="Skip because assertion fails for CI devices") @pytest.mark.dist -@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@parameterize("physical_devices", [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): spawn(check_alpha_beta, 4, physical_devices=physical_devices) -if __name__ == '__main__': +if __name__ == "__main__": test_profile_alpha_beta() diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 3be057b3a98bba70c3e162d7148fb9c66ee84792..af44af5d90974dd1af1090682081ae83df68ee69 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -1,21 +1,89 @@ -from colossalai.device.device_mesh import DeviceMesh +import pytest import torch +import torch.distributed as dist + +import colossalai +from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import rerun_if_address_is_in_use, spawn def test_device_mesh(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], # [8, 9, 10,11], # [12,13,14,15]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - assert device_mesh.convert_map[5] == [1, 1] - assert device_mesh.convert_map[11] == [2, 3] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]] - assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3] + assert device_mesh.global_rank_to_local_rank(5) == [1, 1] + assert device_mesh.global_rank_to_local_rank(11) == [2, 3] + assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3] + + +def check_1d_device_mesh(): + # check for 1D device mesh + process_group = dist.GroupMember.WORLD + device_mesh = DeviceMesh.from_process_group(process_group) + + # checks + assert device_mesh.shape == [4] + assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, "Expected 1 axis for the process group dict" + assert device_mesh.get_process_group(axis=0) == process_group, "Expected world process group" + assert device_mesh.is_initialized + assert device_mesh.num_devices == 4 + assert device_mesh.is_initialized + assert device_mesh.logical_mesh_id is None + assert device_mesh._is_init_from_process_group + + +def check_2d_device_mesh(): + # create process group for 2D device mesh + first_row_ranks = [0, 1] + second_row_ranks = [2, 3] + first_col_ranks = [0, 2] + second_col_ranks = [1, 3] + + first_row_pg = dist.new_group(first_row_ranks, backend="nccl") + second_row_pg = dist.new_group(second_row_ranks, backend="nccl") + first_col_pg = dist.new_group(first_col_ranks, backend="nccl") + second_col_pg = dist.new_group(second_col_ranks, backend="nccl") + + # check for + current_rank = dist.get_rank() + + if current_rank in first_row_ranks: + row_pg = first_row_pg + else: + row_pg = second_row_pg + + if current_rank in first_col_ranks: + col_pg = first_col_pg + else: + col_pg = second_col_pg + + device_mesh = DeviceMesh.from_process_group([col_pg, row_pg]) + + # checks + assert device_mesh.shape == [2, 2] + assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, "Expected 2 axes for the process group dict" + assert device_mesh.get_process_group(axis=0) == col_pg, "Expected column process group" + assert device_mesh.get_process_group(axis=1) == row_pg, "Expected row process group" + assert device_mesh.num_devices == 4 + assert device_mesh.is_initialized + assert device_mesh.logical_mesh_id is None + assert device_mesh._is_init_from_process_group + + +def check_init_from_process_group(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_device_mesh_from_process_group(): + spawn(check_init_from_process_group, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_device_mesh() + test_device_mesh_from_process_group() diff --git a/tests/test_device/test_extract_alpha_beta.py b/tests/test_device/test_extract_alpha_beta.py index 52604b9c6a4909d437ef55f446619d49681ccebe..34f2aacc18b26fee73fdcfe0160819a8125d56a9 100644 --- a/tests/test_device/test_extract_alpha_beta.py +++ b/tests/test_device/test_extract_alpha_beta.py @@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn def check_extract_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh() @@ -20,11 +20,11 @@ def check_extract_alpha_beta(rank, world_size, port, physical_devices): @pytest.mark.skip(reason="Skip because assertion may fail for CI devices") @pytest.mark.dist -@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@parameterize("physical_devices", [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): spawn(check_extract_alpha_beta, 4, physical_devices=physical_devices) -if __name__ == '__main__': +if __name__ == "__main__": test_profile_alpha_beta() diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 2b7060c4846aef61d7680e4884fe0f87526f8795..3b398a9171824907ef0a8293788d84e472b2ed13 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -3,35 +3,28 @@ import torch import torch.distributed as dist from torch.distributed import ReduceOp -from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.testing import rerun_if_address_is_in_use, spawn def check_layer(rank, world_size, port): - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) - assert rank == gpc.get_global_rank() + assert rank == dist.get_rank() tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda() mesh_shape = (2, 2) # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]} - logical_process_groups = device_mesh.process_groups_dict - for mesh_dim, pgs in logical_pg_dict.items(): - for index, pg in enumerate(pgs): - if rank in pg: - tensor = torch.ones(4).cuda() - group = logical_process_groups[mesh_dim][index][1] - dist.all_reduce(tensor, op=ReduceOp.SUM, group=group) - assert tensor.equal(tensor_to_check) - - gpc.destroy() + for axis in range(len(mesh_shape)): + tensor = torch.ones(4).cuda() + pg = device_mesh.get_process_group(axis=axis) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg) + assert tensor.equal(tensor_to_check) @pytest.mark.dist @@ -40,5 +33,5 @@ def test_logical_pg(): spawn(check_layer, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_logical_pg() diff --git a/tests/test_device/test_search_logical_device_mesh.py b/tests/test_device/test_search_logical_device_mesh.py index b22a76eabc2fec34ec7d9922a9cd988c58c576aa..d9d4e79c1f57d32fb18d24ce91bb045536083434 100644 --- a/tests/test_device/test_search_logical_device_mesh.py +++ b/tests/test_device/test_search_logical_device_mesh.py @@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) best_logical_mesh = profiler.search_best_logical_mesh() @@ -20,11 +20,11 @@ def check_alpha_beta(rank, world_size, port, physical_devices): @pytest.mark.skip(reason="Skip because assertion may fail for CI devices") @pytest.mark.dist -@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@parameterize("physical_devices", [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): spawn(check_alpha_beta, 4, physical_devices=physical_devices) -if __name__ == '__main__': +if __name__ == "__main__": test_profile_alpha_beta() diff --git a/tests/test_engine/test_engine.py b/tests/test_engine/test_engine.py deleted file mode 100644 index 62493cf3712dd46eba2bde0dc3d34531725230e3..0000000000000000000000000000000000000000 --- a/tests/test_engine/test_engine.py +++ /dev/null @@ -1,62 +0,0 @@ -import pytest - -import colossalai -from colossalai.amp import AMP_TYPE -from colossalai.core import global_context as gpc -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from tests.components_to_test.registry import non_distributed_component_funcs - -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), - fp16=dict(mode=None), - clip_grad_norm=1.0) - - -@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'repeated_computed_layers']) -@parameterize('amp_mode', [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None]) -def run_train(model_name, amp_mode): - # FIXME: test bert - get_components_func = non_distributed_component_funcs.get_callable(model_name) - gpc.config.fp16['mode'] = amp_mode - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - - model = model_builder(checkpoint=False) - engine, train_dataloader, *args = colossalai.initialize(model=model, - optimizer=optimizer_class(model.parameters(), lr=1e-3), - criterion=criterion, - train_dataloader=train_dataloader) - - try: - engine.train() - for data, label in train_dataloader: - engine.zero_grad() - data = data.cuda() - label = label.cuda() - if criterion: - output = engine(data) - loss = engine.criterion(output, label) - else: - loss = engine(data, label) - engine.backward(loss) - engine.step() - break - except IndexError: - # if using apex amp, NetWithRepeatedlyComputedLayers will raise an index out of range issue - # the following check fails in apex - # if cached_x.grad_fn.next_functions[1][0].variable is not x: - pass - - -def run_engine(rank, world_size, port): - # init dist env - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_train() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_engine(): - spawn(run_engine, 2) - - -if __name__ == '__main__': - test_engine() diff --git a/tests/test_engine/test_gradient_accumluation.py b/tests/test_engine/test_gradient_accumluation.py deleted file mode 100644 index 7783827c7c44c929a4a66667ffdf24ffc2d1aa8f..0000000000000000000000000000000000000000 --- a/tests/test_engine/test_gradient_accumluation.py +++ /dev/null @@ -1,95 +0,0 @@ -import os -from pathlib import Path - -import pytest -import torch -import torch.nn as nn -from torch.optim import Adam -from torchvision import transforms -from torchvision.datasets import CIFAR10 -from torchvision.models import resnet18 - -import colossalai -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_dataloader - -# Config -BATCH_SIZE = 2 -NUM_CLASSES = 10 - -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), - clip_grad_norm=1.0, - gradient_accumulation=4) - - -def run_no_pipeline(rank, world_size, port): - - # init dist env - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - # build model - model = resnet18(num_classes=10) - - # build dataloaders - train_dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ])) - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=BATCH_SIZE, - pin_memory=True, - drop_last=True) - - # build optimizer - optimizer = Adam(model.parameters(), lr=0.001) - criterion = nn.CrossEntropyLoss() - - engine, train_dataloader, *args = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - logger = get_dist_logger() - rank = torch.distributed.get_rank() - param_track = [] - grad_track = [] - next(model.parameters()).retain_grad() - - engine.train() - step = 0 - for img, label in train_dataloader: - engine.zero_grad() - img = img.cuda() - label = label.cuda() - output = engine(img) - loss = engine.criterion(output, label) - engine.backward(loss) - engine.step() - - # check - param_track.append(next(model.parameters())[0].clone()) - grad_track.append(next(model.parameters()).grad[0].clone()) - step += 1 - if step == CONFIG['gradient_accumulation']: - break - - assert not torch.all(grad_track[0] == grad_track[-1]), 'grad should be different in different iterations' - assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \ - 'param should be the same in the first few iterations and only changed in the last iteration' - - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_engine(): - spawn(run_no_pipeline, 4) - - -if __name__ == '__main__': - test_engine() diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index ab483f7e47a3e0821b36f06e6ecfad6f56e90a83..10fe9815541ce6ab409556da6bad188d65870038 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -4,22 +4,23 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint import colossalai -from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule +from colossalai.legacy.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use, spawn try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False class MLP(torch.nn.Module): - def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(4, 4) @@ -30,7 +31,6 @@ class MLP(torch.nn.Module): class relu(torch.nn.Module): - def __init__(self) -> None: super().__init__() self.relu = torch.nn.ReLU(inplace=True) @@ -40,7 +40,6 @@ class relu(torch.nn.Module): class MyModule(torch.nn.Module): - def __init__(self): super().__init__() self.mlp1 = MLP() @@ -64,8 +63,8 @@ class MyModule(torch.nn.Module): def _run_act_ckpt_codegen(rank, world_size, port): - # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -87,26 +86,31 @@ def _run_act_ckpt_codegen(rank, world_size, port): # check ops are annotated with ckpt # also annotate the selected node for offloading - ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] - offload_starts = ['mlp1_linear1'] + ckpt_nodes = ["mlp1_linear1", "mlp1_linear2", "relu_relu", "relu"] + offload_starts = ["mlp1_linear1"] for node in graph.nodes: if node.name in ckpt_nodes: - assert 'activation_checkpoint' in node.meta + assert "activation_checkpoint" in node.meta # annotate the selected node for offload if node.name in offload_starts: - node.meta['activation_offload'] = True + node.meta["activation_offload"] = True gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and # the offload option is correct - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1, data2) @@ -115,15 +119,15 @@ def _run_act_ckpt_codegen(rank, world_size, port): gpc.destroy() -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @rerun_if_address_is_in_use() def test_act_ckpt_codegen(): spawn(_run_act_ckpt_codegen, 1) def _run_act_ckpt_python_code_torch11(rank, world_size, port): - # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -144,25 +148,30 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): graph._python_code = python_code_with_activation_checkpoint.__get__(graph) # check ops are annotated with ckpt - ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] - offload_starts = ['mlp1_linear1'] + ckpt_nodes = ["mlp1_linear1", "mlp1_linear2", "relu_relu", "relu"] + offload_starts = ["mlp1_linear1"] for node in graph.nodes: if node.name in ckpt_nodes: - assert 'activation_checkpoint' in node.meta + assert "activation_checkpoint" in node.meta # annotate the selected node for offload if node.name in offload_starts: - node.meta['activation_offload'] = True + node.meta["activation_offload"] = True gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and # the offload option is correct - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1, data2) @@ -171,12 +180,12 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): gpc.destroy() -@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") @rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): spawn(_run_act_ckpt_python_code_torch11, 1) -if __name__ == '__main__': +if __name__ == "__main__": _run_act_ckpt_codegen(rank=0) diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py index 9064023d4f68299e0520ba1f0321a0d91b239527..f1e87e5ed1403fb949f498a5b5a6ef22ed6adb4c 100644 --- a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py @@ -2,22 +2,21 @@ import pytest import torch import colossalai -from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule +from colossalai.legacy.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use, spawn try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version - from colossalai.fx.codegen import python_code_with_activation_checkpoint with_codegen = False class MyModule(torch.nn.Module): - def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(4, 4) @@ -32,8 +31,8 @@ class MyModule(torch.nn.Module): def _run_act_ckpt_codegen(rank, world_size, port): - # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -54,27 +53,34 @@ def _run_act_ckpt_codegen(rank, world_size, port): # annotate nested checkpoint for node in graph.nodes: if node.name == "linear1": - node.meta['activation_checkpoint'] = [0, 0, 0] + node.meta["activation_checkpoint"] = [0, 0, 0] continue if node.name == "linear2": - node.meta['activation_checkpoint'] = [0, 0, None] + node.meta["activation_checkpoint"] = [0, 0, None] if node.name == "linear3": - node.meta['activation_checkpoint'] = [0, 0, 1] + node.meta["activation_checkpoint"] = [0, 0, 1] if node.name == "linear4": - node.meta['activation_checkpoint'] = [0, 1, None] + node.meta["activation_checkpoint"] = [0, 1, None] if node.name == "linear5": - node.meta['activation_checkpoint'] = 1 + node.meta["activation_checkpoint"] = 1 gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1) @@ -83,14 +89,14 @@ def _run_act_ckpt_codegen(rank, world_size, port): gpc.destroy() -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") def test_act_ckpt_codegen(): spawn(_run_act_ckpt_codegen, 1) def _run_act_ckpt_python_code_torch11(rank, world_size, port): - # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -111,27 +117,34 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): # annotate nested checkpoint for node in graph.nodes: if node.name == "linear1": - node.meta['activation_checkpoint'] = [0, 0, 0] + node.meta["activation_checkpoint"] = [0, 0, 0] continue if node.name == "linear2": - node.meta['activation_checkpoint'] = [0, 0, None] + node.meta["activation_checkpoint"] = [0, 0, None] if node.name == "linear3": - node.meta['activation_checkpoint'] = [0, 0, 1] + node.meta["activation_checkpoint"] = [0, 0, 1] if node.name == "linear4": - node.meta['activation_checkpoint'] = [0, 1, None] + node.meta["activation_checkpoint"] = [0, 1, None] if node.name == "linear5": - node.meta['activation_checkpoint'] = 1 + node.meta["activation_checkpoint"] = 1 gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1) @@ -140,12 +153,12 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): gpc.destroy() -@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") @rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): spawn(_run_act_ckpt_python_code_torch11, 1) -if __name__ == '__main__': +if __name__ == "__main__": _run_act_ckpt_codegen(rank=0) diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py index 96e88eb92b3319bf8cd54c9a0175d6e9bf5f5617..da1e73ec3dfe1a462eeb91147b8ab06e7190e40a 100644 --- a/tests/test_fx/test_codegen/test_offload_codegen.py +++ b/tests/test_fx/test_codegen/test_offload_codegen.py @@ -5,22 +5,23 @@ import torch from torch.fx import GraphModule import colossalai -from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule +from colossalai.legacy.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use, spawn try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False class MyNet(torch.nn.Module): - def __init__(self) -> None: super().__init__() self.linear0 = torch.nn.Linear(4, 4) @@ -50,13 +51,12 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool: def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor): - # test forward non_fx_out = model(data) fx_out = gm(data) assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output" - # test barckward + # test backward loss0 = non_fx_out.sum() loss0.backward() loss1 = fx_out.sum() @@ -65,8 +65,8 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T def _run_offload_codegen(rank, world_size, port): - # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and input model = MyNet().cuda() @@ -83,45 +83,48 @@ def _run_offload_codegen(rank, world_size, port): # of input offload for node in graph.nodes: if node.name == "linear0": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear1": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear2": - node.meta['activation_offload'] = [1, True, True] + node.meta["activation_offload"] = [1, True, True] if node.name == "linear4": - node.meta['activation_offload'] = [2, False, True] + node.meta["activation_offload"] = [2, False, True] if node.name == "linear5": - node.meta['activation_checkpoint'] = [0] - node.meta['activation_offload'] = True + node.meta["activation_checkpoint"] = [0] + node.meta["activation_offload"] = True gm = ColoGraphModule(copy.deepcopy(model), graph) gm.recompile() # assert we have all the components code = graph.python_code("self").src - assert "def pack_hook_input(self, x):" in code and \ - "def unpack_hook(self, packed):" in code and \ - "def pack_hook_no_input(self, x):" in code and \ - "setattr(x, 'offload', True)" in code and \ - "setattr(linear3, 'offload', False)" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ - "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ - "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + assert ( + "def pack_hook_input(self, x):" in code + and "def unpack_hook(self, packed):" in code + and "def pack_hook_no_input(self, x):" in code + and "setattr(x, 'offload', True)" in code + and "setattr(linear3, 'offload', False)" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code + and "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" + in code + ) _test_fwd_and_bwd(model, gm, data) gpc.destroy() -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @rerun_if_address_is_in_use() def test_act_ckpt_codegen(): spawn(_run_offload_codegen, 1) def _run_offload_codegen_torch11(rank, world_size, port): - # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and input model = MyNet().cuda() @@ -139,31 +142,34 @@ def _run_offload_codegen_torch11(rank, world_size, port): # of input offload for node in graph.nodes: if node.name == "linear0": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear1": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear2": - node.meta['activation_offload'] = [1, True, True] + node.meta["activation_offload"] = [1, True, True] if node.name == "linear4": - node.meta['activation_offload'] = [2, False, True] + node.meta["activation_offload"] = [2, False, True] if node.name == "linear5": - node.meta['activation_checkpoint'] = [0] - node.meta['activation_offload'] = True + node.meta["activation_checkpoint"] = [0] + node.meta["activation_offload"] = True gm = ColoGraphModule(copy.deepcopy(model), graph) gm.recompile() # assert we have all the components code = graph.python_code("self").src - assert "def pack_hook_input(self, x):" in code and \ - "def unpack_hook(self, packed):" in code and \ - "def pack_hook_no_input(self, x):" in code and \ - "setattr(x, 'offload', True)" in code and \ - "setattr(linear3, 'offload', False)" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ - "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ - "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + assert ( + "def pack_hook_input(self, x):" in code + and "def unpack_hook(self, packed):" in code + and "def pack_hook_no_input(self, x):" in code + and "setattr(x, 'offload', True)" in code + and "setattr(linear3, 'offload', False)" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code + and "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" + in code + ) _test_fwd_and_bwd(model, gm, data) gpc.destroy() diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py index 96cf5198da10a5eb3d0d595e73d9b796547afe86..efef368bdd45b275b04efb2269ed6884740f87ad 100644 --- a/tests/test_fx/test_coloproxy.py +++ b/tests/test_fx/test_coloproxy.py @@ -1,4 +1,3 @@ -import pytest import torch import torch.nn as nn from torch.fx import GraphModule @@ -9,7 +8,6 @@ from colossalai.testing import clear_cache_before_run class Conv1D(nn.Module): - def __init__(self, nf, nx): super().__init__() self.nf = nf @@ -27,10 +25,9 @@ class Conv1D(nn.Module): @clear_cache_before_run() def test_coloproxy(): - tracer = ColoTracer() model = Conv1D(3, 3) - input_sample = {'x': torch.rand(3, 3).to('meta')} + input_sample = {"x": torch.rand(3, 3).to("meta")} graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) @@ -38,7 +35,7 @@ def test_coloproxy(): node = list(gm.graph.nodes)[0] proxy = ColoProxy(node=node, tracer=tracer) - proxy.meta_data = torch.empty(4, 2, device='meta') + proxy.meta_data = torch.empty(4, 2, device="meta") assert len(proxy) == 4 assert proxy.shape[0] == 4 and proxy.shape[1] == 2 @@ -47,5 +44,5 @@ def test_coloproxy(): assert proxy.size(0) == 4 -if __name__ == '__main__': +if __name__ == "__main__": test_coloproxy() diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py index d3daadd714064b38112d5af976e9a32725faddf7..00721ca86ade2194269554c70b335092f6583adb 100644 --- a/tests/test_fx/test_comm_size_compute.py +++ b/tests/test_fx/test_comm_size_compute.py @@ -17,7 +17,6 @@ PIPELINE_SIZE = 2 class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -36,7 +35,7 @@ class MLP(torch.nn.Module): @clear_cache_before_run() def test_comm_size_compute(): model = MLP(MODEL_DIM) - input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta') + input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device="meta") gm = symbolic_trace(model) if is_compatible: input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device) @@ -49,5 +48,5 @@ def test_comm_size_compute(): assert comm_size == 128 -if __name__ == '__main__': +if __name__ == "__main__": test_comm_size_compute() diff --git a/tests/test_fx/test_graph_manipulation.py b/tests/test_fx/test_graph_manipulation.py index 175b69dd96fe53204ed1263daab40b923ed54206..eece451a706f7c940de82b5108a90aba44a72c1b 100644 --- a/tests/test_fx/test_graph_manipulation.py +++ b/tests/test_fx/test_graph_manipulation.py @@ -1,15 +1,11 @@ import torch -from torch.fx import GraphModule -import colossalai from colossalai.fx import ColoTracer -from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata from colossalai.fx.passes.utils import assign_bfs_level_to_nodes, get_leaf, get_top from colossalai.testing import clear_cache_before_run class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -43,11 +39,11 @@ def test_graph_manipulation(): assert leaf_nodes == set([l4, l5]) assert top_nodes == set([l1, l2]) for node in graph.nodes: - if node.op in ('placeholder', 'output'): - assert not hasattr(node, 'bfs_level') + if node.op in ("placeholder", "output"): + assert not hasattr(node, "bfs_level") else: assert node.bfs_level == compare_dict[node] -if __name__ == '__main__': +if __name__ == "__main__": test_graph_manipulation() diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py index e490522dbf15677f5af4fa5b8c04520338101585..7fc7eb4df64bc6e424b8978d7455f3b2eabb93d3 100644 --- a/tests/test_fx/test_meta/test_aten.py +++ b/tests/test_fx/test_meta/test_aten.py @@ -13,35 +13,41 @@ if is_compatible_with_meta(): aten = torch.ops.aten registered_meta = { - ('aten.convolution.default', True): [ # (aten ops, requires_backward) + ("aten.convolution.default", True): [ # (aten ops, requires_backward) (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)), (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)), (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), - (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), torch.rand(2, 3, 4, 4)), - (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), torch.rand(2, 3, 4, 4, 4)), + ( + nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4), + ), + ( + nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4, 4), + ), ], - ('aten.native_batch_norm.default', True): [ + ("aten.native_batch_norm.default", True): [ (nn.BatchNorm1d(4), torch.rand(2, 4)), (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)), (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)), ], - ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),], - ('aten.avg_pool1d.default', True): [ + ("aten.native_layer_norm.default", True): [ + (nn.LayerNorm(4), torch.rand(1, 2, 3, 4)), + ], + ("aten.avg_pool1d.default", True): [ (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)), (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)), ], - ('aten.avg_pool2d.default', True): [ + ("aten.avg_pool2d.default", True): [ (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)), ], - ('aten.relu.default', True): [ + ("aten.relu.default", True): [ (nn.ReLU(), torch.rand(4, 3, 1, 2)), (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)), (nn.SiLU(), torch.rand(4, 3, 1, 2)), @@ -50,15 +56,20 @@ registered_meta = { (nn.Sigmoid(), torch.rand(4, 3, 1, 2)), (nn.Tanh(), torch.rand(4, 3, 1, 2)), (nn.Hardswish(), torch.rand(4, 3, 1, 2)), - ] + ], } def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any: - assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' - assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' - assert tensor.stride() == meta_tensor.stride( - ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + assert ( + tensor.shape == meta_tensor.shape + ), f"the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match." + assert ( + tensor.dtype == meta_tensor.dtype + ), f"the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match." + assert ( + tensor.stride() == meta_tensor.stride() + ), f"the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match." def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any: @@ -72,7 +83,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac compare_all(x.grad, meta_x.grad) -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): @@ -80,5 +91,5 @@ def test_meta_aten(): run_and_compare(f, x, requires_backward) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_aten() diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py index 7aed6fd4597b2286bbcda9627c7dd1ba0b482593..6091c4b6be2f85ca0aefeed6c6814a7b9da8f7c2 100644 --- a/tests/test_fx/test_meta/test_backward.py +++ b/tests/test_fx/test_meta/test_backward.py @@ -23,31 +23,40 @@ tm_models = [ ] tmm_models = [ - tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m, - tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224, - tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100, - tmm.swin_transformer.swin_base_patch4_window7_224 + tmm.resnest.resnest50d, + tmm.beit.beit_base_patch16_224, + tmm.cait.cait_s24_224, + tmm.efficientnet.efficientnetv2_m, + tmm.resmlp_12_224, + tmm.vision_transformer.vit_base_patch16_224, + tmm.deit_base_distilled_patch16_224, + tmm.convnext.convnext_base, + tmm.vgg.vgg11, + tmm.dpn.dpn68, + tmm.densenet.densenet121, + tmm.rexnet.rexnet_100, + tmm.swin_transformer.swin_base_patch4_window7_224, ] -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_torchvision_models(): for m in tm_models: model = m() - data = torch.rand(100000, 3, 224, 224, device='meta') - model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() + data = torch.rand(100000, 3, 224, 224, device="meta") + model(MetaTensor(data, fake_device=torch.device("cpu"))).sum().backward() -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_timm_models(): for m in tmm_models: model = m() - data = torch.rand(100000, 3, 224, 224, device='meta') - model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() + data = torch.rand(100000, 3, 224, 224, device="meta") + model(MetaTensor(data, fake_device=torch.device("cpu"))).sum().backward() -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models() test_timm_models() diff --git a/tests/test_fx/test_meta/test_meta_trace.py b/tests/test_fx/test_meta/test_meta_trace.py index 61614f8a66234260f54f8286bcfe628d55452eab..ba9617a383802feeaa167b048b2ad7b3ebf89a51 100644 --- a/tests/test_fx/test_meta/test_meta_trace.py +++ b/tests/test_fx/test_meta/test_meta_trace.py @@ -23,31 +23,40 @@ tm_models = [ ] tmm_models = [ - tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m, - tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224, - tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100, - tmm.swin_transformer.swin_base_patch4_window7_224 + tmm.resnest.resnest50d, + tmm.beit.beit_base_patch16_224, + tmm.cait.cait_s24_224, + tmm.efficientnet.efficientnetv2_m, + tmm.resmlp_12_224, + tmm.vision_transformer.vit_base_patch16_224, + tmm.deit_base_distilled_patch16_224, + tmm.convnext.convnext_base, + tmm.vgg.vgg11, + tmm.dpn.dpn68, + tmm.densenet.densenet121, + tmm.rexnet.rexnet_100, + tmm.swin_transformer.swin_base_patch4_window7_224, ] -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_torchvision_models_trace(): for m in tm_models: model = m() - data = torch.rand(1000, 3, 224, 224, device='meta') - graph = meta_trace(model, torch.device('cpu'), data) + data = torch.rand(1000, 3, 224, 224, device="meta") + meta_trace(model, torch.device("cpu"), data) -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_timm_models_trace(): for m in tmm_models: model = m() - data = torch.rand(1000, 3, 224, 224, device='meta') - graph = meta_trace(model, torch.device('cpu'), data) + data = torch.rand(1000, 3, 224, 224, device="meta") + meta_trace(model, torch.device("cpu"), data) -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models_trace() test_timm_models_trace() diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index a12512696a730459b80a9debc4cfa43695be8515..659949e87002ff075ff7c4695d06ac169cb3682e 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -23,18 +23,18 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): @clear_cache_before_run() def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) - input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') + input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="meta") if is_compatible_with_meta(): - input_sample = MetaTensor(input_sample, fake_device='cpu') + input_sample = MetaTensor(input_sample, fake_device="cpu") orig_output = model(input_sample) gm = symbolic_trace(model) MetaInfoProp(gm).run(input_sample) for node in gm.graph.nodes: - if node.op == 'placeholder': - meta_check(node.meta['tensor_meta'], input_sample) - if node.op == 'output': - meta_check(node.meta['tensor_meta'], orig_output) + if node.op == "placeholder": + meta_check(node.meta["tensor_meta"], input_sample) + if node.op == "output": + meta_check(node.meta["tensor_meta"], orig_output) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_info_prop() diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py index 1044be7db1f4cbab519d8a90a70f8f3deb3016a2..6d890f59d5c537ab1e31b6c0f22990a2081070a7 100644 --- a/tests/test_fx/test_parallel_1d.py +++ b/tests/test_fx/test_parallel_1d.py @@ -5,15 +5,14 @@ import pytest import torch from torch.fx import symbolic_trace -from colossalai.core import global_context as gpc from colossalai.fx.passes import column_shard_linear_pass from colossalai.initialize import launch +from colossalai.legacy.core import global_context as gpc from colossalai.logging import disable_existing_loggers from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -29,12 +28,12 @@ class MLP(torch.nn.Module): return x -CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=2))) +CONFIG = dict(parallel=dict(tensor=dict(mode="1d", size=2))) def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") input_tensor = torch.rand(2, 16).cuda() model = MLP(16).cuda() symbolic_traced = symbolic_trace(model) @@ -55,5 +54,5 @@ def test_1d(): spawn(check_layer, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_1d() diff --git a/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py index 3afc6c97e2bb69778b5fce5668861cba3d3582eb..b86c71db85c267da0dd4fe2003281eb09acf01b6 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py +++ b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py @@ -1,11 +1,12 @@ -import torch -from torch.fx import symbolic_trace -from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer import inspect import random + import numpy as np +import torch +from torch.fx import GraphModule + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass MANUAL_SEED = 0 random.seed(MANUAL_SEED) @@ -26,7 +27,7 @@ def split_model_and_compare_output(model, data_gen): # tracing model tracer = ColoTracer() try: - meta_args = {k: v.to('meta') for k, v in kwargs.items()} + meta_args = {k: v.to("meta") for k, v in kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") @@ -49,16 +50,16 @@ def split_model_and_compare_output(model, data_gen): output_part1 = model_part1(output_part0) else: if len(output_part0) > len(sig.parameters): - output_part0 = output_part0[:len(sig.parameters)] + output_part0 = output_part0[: len(sig.parameters)] output_part1 = model_part1(*output_part0) # get output tensor from HFOutput datastructure - if 'logits' in output: - output_to_compare = output['logits'] - elif 'prediction_logits' in output: - output_to_compare = output['prediction_logits'] + if "logits" in output: + output_to_compare = output["logits"] + elif "prediction_logits" in output: + output_to_compare = output["prediction_logits"] else: - output_to_compare = output['last_hidden_state'] + output_to_compare = output["last_hidden_state"] # compare output if isinstance(output_part1, torch.Tensor): diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py index 6ef861bdefbe105299e31e8f06f8d1eb464fb3c2..d15081b0b3ad301b6286ecb58c1412b1fbf9a383 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py @@ -7,7 +7,7 @@ BATCH_SIZE = 2 SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_single_sentence_albert(): MODEL_LIST = [ transformers.AlbertModel, @@ -17,12 +17,14 @@ def test_single_sentence_albert(): transformers.AlbertForTokenClassification, ] - config = transformers.AlbertConfig(vocab_size=100, - embedding_size=128, - hidden_size=128, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=256) + config = transformers.AlbertConfig( + vocab_size=100, + embedding_size=128, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256, + ) def data_gen(): input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) @@ -36,5 +38,5 @@ def test_single_sentence_albert(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_single_sentence_albert() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py index a7550413fac8f6514e1428b8d7b3999608d84bbc..3588033d1ecdd68a3e5aed53c37a760f50feab01 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py @@ -7,7 +7,7 @@ BATCH_SIZE = 2 SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_single_sentence_bert(): MODEL_LIST = [ transformers.BertModel, @@ -18,11 +18,9 @@ def test_single_sentence_bert(): transformers.BertForTokenClassification, ] - config = transformers.BertConfig(vocab_size=100, - hidden_size=128, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=256) + config = transformers.BertConfig( + vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4, intermediate_size=256 + ) def data_gen(): input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) @@ -36,5 +34,5 @@ def test_single_sentence_bert(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_single_sentence_bert() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py index 6181c5c0706a97d895fb4d51d012809a2a5336a4..d2533aea4003f5890142ec6451ab547f4ce50818 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py @@ -9,14 +9,14 @@ NUM_EPOCHS = 2 NUM_CHUNKS = 1 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_gpt(): MODEL_LIST = [ transformers.GPT2Model, transformers.GPT2LMHeadModel, transformers.GPT2DoubleHeadsModel, transformers.GPT2ForTokenClassification, - # transformers.GPT2ForSequenceClassification, # not supported yet + # transformers.GPT2ForSequenceClassification, # not supported yet ] config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=8) @@ -32,5 +32,5 @@ def test_gpt(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_gpt() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py index 1a9b36be82bd9e9fb29580b6fe6807dbb1eb032a..e67628d103649968792ad0fec2243d35185cf5b7 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py @@ -7,7 +7,7 @@ BATCH_SIZE = 1 SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_opt(): MODEL_LIST = [ transformers.OPTModel, @@ -27,5 +27,5 @@ def test_opt(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_opt() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py index 16d0163746b318dec6be47529b9f9c55d23d8a7e..dc36fdb131523790ddcb51f5a3d7162fb06a4718 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py @@ -7,7 +7,7 @@ BATCH_SIZE = 1 SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_t5(): MODEL_LIST = [ transformers.T5Model, @@ -39,5 +39,5 @@ def test_t5(): split_model_and_compare_output(model, data_gen_func) -if __name__ == '__main__': +if __name__ == "__main__": test_t5() diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py index 6fb1f6f4bb237db355e54d7b57730b3e9a6990a5..c4fe5547ed8db6b82ef95768db474cdc9a92aaee 100644 --- a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py +++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py @@ -4,9 +4,8 @@ import torch from timm_utils import split_model_and_compare_output -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_timm_models_without_control_flow(): - MODEL_LIST = [ tm.resnest.resnest50d, tm.beit.beit_base_patch16_224, @@ -25,24 +24,28 @@ def test_timm_models_without_control_flow(): split_model_and_compare_output(model, data) -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_timm_models_with_control_flow(): torch.backends.cudnn.deterministic = True MODEL_LIST_WITH_CONTROL_FLOW = [ - tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100, - tm.swin_transformer.swin_base_patch4_window7_224 + tm.convnext.convnext_base, + tm.vgg.vgg11, + tm.dpn.dpn68, + tm.densenet.densenet121, + tm.rexnet.rexnet_100, + tm.swin_transformer.swin_base_patch4_window7_224, ] data = torch.rand(2, 3, 224, 224) - meta_args = {'x': data.to('meta')} + meta_args = {"x": data.to("meta")} for model_cls in MODEL_LIST_WITH_CONTROL_FLOW: model = model_cls() split_model_and_compare_output(model, data, meta_args) -if __name__ == '__main__': +if __name__ == "__main__": test_timm_models_without_control_flow() test_timm_models_with_control_flow() diff --git a/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py index aa870e5c7a659c5a9c09eae5a9a5d5ec3b77220f..e1182c8d4978876ee756af13c5db4fa6f50f9351 100644 --- a/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py +++ b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py @@ -1,11 +1,12 @@ -import torch -from torch.fx import symbolic_trace -from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer import inspect import random + import numpy as np +import torch +from torch.fx import GraphModule + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass MANUAL_SEED = 0 random.seed(MANUAL_SEED) @@ -46,6 +47,6 @@ def split_model_and_compare_output(model, data, meta_args=None): output_part1 = model_part1(output_part0) else: if len(output_part0) > len(sig.parameters): - output_part0 = output_part0[:len(sig.parameters)] + output_part0 = output_part0[: len(sig.parameters)] output_part1 = model_part1(*output_part0) assert output.equal(output_part1) diff --git a/tests/test_fx/test_pipeline/test_topo/test_topo.py b/tests/test_fx/test_pipeline/test_topo/test_topo.py index 16da56250dc3a43fffec4534119f6a1de586a0e7..7c420ef2385ae3b9b9c466a68277ef8121cd89ff 100644 --- a/tests/test_fx/test_pipeline/test_topo/test_topo.py +++ b/tests/test_fx/test_pipeline/test_topo/test_topo.py @@ -7,7 +7,7 @@ BATCH_SIZE = 1 SEQ_LENGHT = 16 -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") def test_opt(): MODEL_LIST = [ MLP, @@ -15,10 +15,7 @@ def test_opt(): ] CONFIGS = [ - { - 'dim': 10, - 'layers': 12 - }, + {"dim": 10, "layers": 12}, transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4), ] @@ -45,5 +42,5 @@ def test_opt(): check_topo(top_mod, topo) -if __name__ == '__main__': +if __name__ == "__main__": test_opt() diff --git a/tests/test_fx/test_pipeline/test_topo/topo_utils.py b/tests/test_fx/test_pipeline/test_topo/topo_utils.py index 55dd65201acd9ddd06d3becc84326ba97b1a7943..6a69181a6d26c35961bf4e8221dc765789ce1213 100644 --- a/tests/test_fx/test_pipeline/test_topo/topo_utils.py +++ b/tests/test_fx/test_pipeline/test_topo/topo_utils.py @@ -1,22 +1,25 @@ +import random + +import numpy as np import torch from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass + from colossalai.fx import ColoTracer -from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo -from colossalai.pipeline.middleware.adaptor import get_fx_topology -import random -import numpy as np +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass +from colossalai.legacy.pipeline.middleware import Partition, Topo +from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology MANUAL_SEED = 0 random.seed(MANUAL_SEED) np.random.seed(MANUAL_SEED) torch.manual_seed(MANUAL_SEED) + class MLP(torch.nn.Module): def __init__(self, config={}): super().__init__() - dim = config['dim'] - layers = config['layers'] + dim = config["dim"] + layers = config["layers"] self.layers = torch.nn.ModuleList() for _ in range(layers): @@ -27,6 +30,7 @@ class MLP(torch.nn.Module): x = layer(x) return x + def split_model_and_get_DAG(model, data_gen): model.eval() @@ -36,7 +40,7 @@ def split_model_and_get_DAG(model, data_gen): # tracing model tracer = ColoTracer() try: - meta_args = {k: v.to('meta') for k, v in kwargs.items()} + meta_args = {k: v.to("meta") for k, v in kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") @@ -46,47 +50,49 @@ def split_model_and_get_DAG(model, data_gen): # apply transform passes annotated_model = balanced_split_pass(gm, 2) top_module, split_submodules = split_with_split_nodes_pass(annotated_model) - + topo = get_fx_topology(top_module) for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): - setattr(submodule, '_topo', topo) + setattr(submodule, "_topo", topo) return top_module, split_submodules[0]._topo + def check_input(top_module, input_partition: Partition): partition_output = input_partition.get_output_vals() arg_pos = 0 for node in top_module.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": cur_checkee = partition_output[arg_pos] to_partition_and_offset = cur_checkee.get() assert len(to_partition_and_offset) == len(node.users.keys()) arg_pos += 1 - + assert arg_pos == len(partition_output) - + + def check_submod(top_module, part_id, mid_partition: Partition): partition_input = mid_partition.get_input_vals() partition_output = mid_partition.get_output_vals() - + cnt = 1 cur_node = None for node in top_module.graph.nodes: - if node.name.startswith('submod'): + if node.name.startswith("submod"): cnt += 1 if cnt == part_id: cur_node = node break - + assert len(partition_input) == len(cur_node.args) assert len(partition_output) == len(cur_node.users) -def check_topo(top_module, topo: Topo): + +def check_topo(top_module, topo: Topo): input_partition = topo.get_input_partition() mid_partitions = topo.get_mid_partitions() - + check_input(top_module, input_partition) for part_id, submod in mid_partitions.items(): check_submod(top_module, part_id, submod) - \ No newline at end of file diff --git a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py index 5d47be2c7bea01323cf1938cfa98ffe1175271de..063e51309503aa131ce18e7d346445b5bf2b21ec 100644 --- a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py +++ b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py @@ -19,14 +19,21 @@ torch.manual_seed(MANUAL_SEED) torch.backends.cudnn.deterministic = True -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_torchvision_models(): MODEL_LIST = [ - tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, - tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5 + tm.vgg11, + tm.resnet18, + tm.densenet121, + tm.mobilenet_v3_small, + tm.resnext50_32x4d, + tm.wide_resnet50_2, + tm.regnet_x_16gf, + tm.efficientnet_b0, + tm.mnasnet0_5, ] - if version.parse(torchvision.__version__) >= version.parse('0.12.0'): + if version.parse(torchvision.__version__) >= version.parse("0.12.0"): MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small]) tracer = ColoTracer() @@ -57,10 +64,10 @@ def test_torchvision_models(): output_part1 = model_part1(output_part0) else: if len(output_part0) > len(sig.parameters): - output_part0 = output_part0[:len(sig.parameters)] + output_part0 = output_part0[: len(sig.parameters)] output_part1 = model_part1(*output_part0) assert output.equal(output_part1) -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models() diff --git a/tests/test_fx/test_pipeline_passes.py b/tests/test_fx/test_pipeline_passes.py index 1078dac9db7cb4c2e00191a6fc33121f397f89e6..7a5a397500bb0e090858cfee20a074d6260cadf5 100644 --- a/tests/test_fx/test_pipeline_passes.py +++ b/tests/test_fx/test_pipeline_passes.py @@ -1,10 +1,6 @@ -import pytest import torch -import torch.nn as nn from torch.fx import symbolic_trace -import colossalai -import colossalai.nn as col_nn from colossalai.fx.passes.adding_split_node_pass import ( balanced_split_pass, balanced_split_pass_v2, @@ -19,7 +15,6 @@ PIPELINE_SIZE = 2 class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -53,5 +48,5 @@ def test_pipeline_passes(): pipeline_pass_test_helper(model, data, uniform_split_pass) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_passes() diff --git a/tests/test_fx/test_profiler/gpt_utils.py b/tests/test_fx/test_profiler/gpt_utils.py index aec32268484fbcc3b362207a02b8a022bf660642..9e4214876ba7985b35096708f6e7d0c5328c1cf3 100644 --- a/tests/test_fx/test_profiler/gpt_utils.py +++ b/tests/test_fx/test_profiler/gpt_utils.py @@ -1,26 +1,29 @@ -import torch import torch.nn as nn from transformers import GPT2Config, GPT2LMHeadModel class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) if checkpoint: self.model.gradient_checkpointing_enable() @@ -30,7 +33,6 @@ class GPTLMModel(nn.Module): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() diff --git a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py index b5a6bbe8bf181c92f5fb8bf18779925889bc35fb..28409696ca553d8d51732de1ec6775fe75511f1c 100644 --- a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py +++ b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py @@ -1,9 +1,9 @@ -from typing import Optional, Tuple, Union +from typing import Tuple import torch import torch.fx import torchvision.models as tm -from gpt_utils import gpt2_medium, gpt2_xl +from gpt_utils import gpt2_medium from torch.fx import symbolic_trace from colossalai.fx.passes.meta_info_prop import MetaInfoProp @@ -33,18 +33,18 @@ def extract_forward_flops(gm: torch.fx.GraphModule): fwd_flop = 0 bwd_flop = 0 for node in gm.graph.nodes: - fwd_flop += node.meta.get('fwd_flop', 0) - bwd_flop += node.meta.get('bwd_flop', 0) + fwd_flop += node.meta.get("fwd_flop", 0) + bwd_flop += node.meta.get("bwd_flop", 0) return fwd_flop, bwd_flop -def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device='cuda'): +def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device="cuda"): data = torch.rand(batch_size, *shape, device=device) label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000) return data, label -def gen_gpt_data(batch_size, seq_len, vocab_size, device='cpu'): +def gen_gpt_data(batch_size, seq_len, vocab_size, device="cpu"): input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) attention_mask = torch.ones_like(input_ids, device=device) return input_ids, attention_mask @@ -96,7 +96,7 @@ def run_gpt_forward(gm: torch.fx.GraphModule): param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 for n in range(NUM_STEPS): torch.cuda.reset_peak_memory_stats() - data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='cuda:0') + data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device="cuda:0") # If we need to dive deep into the memory usage by # inspecting `saved_tensor_hooks` @@ -125,21 +125,56 @@ def run_gpt_forward(gm: torch.fx.GraphModule): return forward_mem, param_mem -@run_on_environment_flag(name='FX_PROFILER') +@run_on_environment_flag(name="FX_PROFILER") @clear_cache_before_run() def test_meta_info_prop(): for m in [ - tm.alexnet, tm.resnet18, tm.resnet34, tm.resnet50, tm.resnet101, tm.resnet152, tm.densenet121, - tm.densenet161, tm.densenet169, tm.densenet201, tm.convnext_tiny, tm.convnext_small, tm.convnext_base, - tm.convnext_large, tm.wide_resnet50_2, tm.wide_resnet101_2, tm.regnet_x_16gf, tm.mnasnet0_5, - tm.efficientnet_b0, tm.shufflenet_v2_x0_5, tm.shufflenet_v2_x1_0, tm.shufflenet_v2_x1_5, - tm.shufflenet_v2_x2_0, tm.mobilenet_v2, tm.mobilenet_v3_small, tm.mobilenet_v3_large, tm.resnext50_32x4d, - tm.resnext101_32x8d, tm.resnext101_64x4d, tm.vit_b_16, tm.vit_b_32, tm.vit_h_14, tm.vit_l_16, tm.vit_l_32, - tm.vgg11, tm.vgg11_bn, tm.vgg13, tm.vgg13_bn, tm.vgg16, tm.vgg16_bn, tm.vgg19, tm.vgg19_bn + tm.alexnet, + tm.resnet18, + tm.resnet34, + tm.resnet50, + tm.resnet101, + tm.resnet152, + tm.densenet121, + tm.densenet161, + tm.densenet169, + tm.densenet201, + tm.convnext_tiny, + tm.convnext_small, + tm.convnext_base, + tm.convnext_large, + tm.wide_resnet50_2, + tm.wide_resnet101_2, + tm.regnet_x_16gf, + tm.mnasnet0_5, + tm.efficientnet_b0, + tm.shufflenet_v2_x0_5, + tm.shufflenet_v2_x1_0, + tm.shufflenet_v2_x1_5, + tm.shufflenet_v2_x2_0, + tm.mobilenet_v2, + tm.mobilenet_v3_small, + tm.mobilenet_v3_large, + tm.resnext50_32x4d, + tm.resnext101_32x8d, + tm.resnext101_64x4d, + tm.vit_b_16, + tm.vit_b_32, + tm.vit_h_14, + tm.vit_l_16, + tm.vit_l_32, + tm.vgg11, + tm.vgg11_bn, + tm.vgg13, + tm.vgg13_bn, + tm.vgg16, + tm.vgg16_bn, + tm.vgg19, + tm.vgg19_bn, ]: model = m().cuda() model.train() - data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device='meta'), fake_device='cuda:0') + data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device="meta"), fake_device="cuda:0") gm = symbolic_trace(model) interp = MetaInfoProp(gm) interp.propagate(data) @@ -150,22 +185,22 @@ def test_meta_info_prop(): concrete_forward_mem, concrete_param_mem = run_tm_forward(gm) print( - f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|' + f"|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|" ) del model, gm -@run_on_environment_flag(name='FX_PROFILER') +@run_on_environment_flag(name="FX_PROFILER") @clear_cache_before_run() def test_gpt_meta_info_prop(): for m in [gpt2_medium]: model = m().cuda() model.train() - data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='meta') - graph = ColoTracer().trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device="meta") + graph = ColoTracer().trace(model, meta_args={"input_ids": data, "attention_mask": mask}) gm = torch.fx.GraphModule(model, graph) interp = MetaInfoProp(gm) - interp.propagate(MetaTensor(data, fake_device='cuda:0'), MetaTensor(mask, fake_device='cuda:0')) + interp.propagate(MetaTensor(data, fake_device="cuda:0"), MetaTensor(mask, fake_device="cuda:0")) model.cpu() fwd_flop, bwd_flop = extract_forward_flops(gm) @@ -174,11 +209,11 @@ def test_gpt_meta_info_prop(): meta_forward_mem, meta_param_mem = extract_forward_mem(gm) print( - f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|' + f"|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|" ) del model, gm -if __name__ == '__main__': +if __name__ == "__main__": test_meta_info_prop() test_gpt_meta_info_prop() diff --git a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py index 632ab8c097505d6e805219527c7d51ecc6ba69d6..e7dcf07aafb4b7bd3640361d03a56b61464d9687 100644 --- a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py +++ b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py @@ -1,5 +1,4 @@ import torch -import torch.nn as nn from torch.fx import GraphModule from torch.utils.checkpoint import checkpoint @@ -8,7 +7,6 @@ from colossalai.testing import clear_cache_before_run class MLP(torch.nn.Module): - def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(4, 4) @@ -22,7 +20,6 @@ class MLP(torch.nn.Module): # Simple module for demonstration class MyModule(torch.nn.Module): - def __init__(self): super().__init__() self.mlp_1 = MLP() @@ -46,20 +43,20 @@ def test_activation_checkpoint_annotation(): gm = GraphModule(module, graph) for node in gm.graph.nodes: - if node.name in ['mlp_1_linear1', 'mlp_1_linear2']: - assert node.meta.get('activation_checkpoint', -1) == 0 + if node.name in ["mlp_1_linear1", "mlp_1_linear2"]: + assert node.meta.get("activation_checkpoint", -1) == 0 for node in gm.graph.nodes: - if node.name in ['mlp_2_linear1', 'mlp_2_linear2']: - assert node.meta.get('activation_checkpoint', -1) == 1 + if node.name in ["mlp_2_linear1", "mlp_2_linear2"]: + assert node.meta.get("activation_checkpoint", -1) == 1 tracer = ColoTracer(trace_act_ckpt=False) graph = tracer.trace(module) gm = GraphModule(module, graph) for node in gm.graph.nodes: - assert not hasattr(node, 'activation_checkpoint') + assert not hasattr(node, "activation_checkpoint") -if __name__ == '__main__': +if __name__ == "__main__": test_activation_checkpoint_annotation() diff --git a/tests/test_fx/test_tracer/test_bias_addition_module.py b/tests/test_fx/test_tracer/test_bias_addition_module.py index 2f88d8c784e8144e557747bc35e86b74c96e7478..e53894bdfd71fc086923f7d9e0bef51ea3e9625b 100644 --- a/tests/test_fx/test_tracer/test_bias_addition_module.py +++ b/tests/test_fx/test_tracer/test_bias_addition_module.py @@ -5,7 +5,6 @@ from colossalai.testing import clear_cache_before_run class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features): super().__init__() self.linear = torch.nn.Linear(in_features, out_features) @@ -18,13 +17,11 @@ class LinearModel(torch.nn.Module): class ConvModel(torch.nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True): super().__init__() - self.conv = torch.nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - bias=bias) + self.conv = torch.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias + ) def forward(self, x): x = self.conv(x) @@ -45,7 +42,7 @@ def test_linear_module(): # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) # return mul - graph = tracer.trace(root=model, meta_args={'x': torch.rand(3, 3).to('meta')}) + graph = tracer.trace(root=model, meta_args={"x": torch.rand(3, 3).to("meta")}) # def forward(self, x : torch.Tensor): # linear_weight = self.linear.weight # linear_bias = self.linear.bias @@ -57,9 +54,9 @@ def test_linear_module(): gm.recompile() node_list = list(graph.nodes) for node in node_list: - if node.op == 'output': + if node.op == "output": continue - assert hasattr(node, '_meta_data') + assert hasattr(node, "_meta_data") weight_node = node_list[1] bias_node = node_list[2] linear_node = node_list[3] @@ -83,7 +80,7 @@ def test_conv_module(): # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) # return mul - graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) + graph = tracer.trace(root=model, meta_args={"x": torch.rand(4, 3, 64, 64).to("meta")}) # def forward(self, x : torch.Tensor): # conv_weight = self.conv.weight # conv_bias = self.conv.bias @@ -97,9 +94,9 @@ def test_conv_module(): gm.recompile() node_list = list(graph.nodes) for node in node_list: - if node.op == 'output': + if node.op == "output": continue - assert hasattr(node, '_meta_data') + assert hasattr(node, "_meta_data") weight_node = node_list[1] bias_node = node_list[2] conv_node = node_list[3] @@ -112,6 +109,6 @@ def test_conv_module(): assert add_node._meta_data.shape == (4, 6, 63, 63) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_module() test_conv_module() diff --git a/tests/test_fx/test_tracer/test_control_flow.py b/tests/test_fx/test_tracer/test_control_flow.py index 820729dadb3ed21aad6f0230e3094f839a14d50f..f0c261c39db582c0f0b5ff4aaf7e81c7b2d73c74 100644 --- a/tests/test_fx/test_tracer/test_control_flow.py +++ b/tests/test_fx/test_tracer/test_control_flow.py @@ -7,7 +7,6 @@ from colossalai.testing import clear_cache_before_run class ControlFlowModel(nn.Module): - def __init__(self): super().__init__() self.linear1 = nn.Linear(10, 10) @@ -27,16 +26,12 @@ class ControlFlowModel(nn.Module): def test_control_flow(): model = ControlFlowModel() tracer = Tracer() - graph_branch_true = tracer.trace(model, - meta_args={ - 'x': torch.rand(4, 10, device='meta'), - 'y': torch.rand(4, 10, device='meta') - }) - graph_branch_false = tracer.trace(model, - meta_args={ - 'x': torch.rand(10, device='meta'), - 'y': torch.rand(4, 10, device='meta') - }) + graph_branch_true = tracer.trace( + model, meta_args={"x": torch.rand(4, 10, device="meta"), "y": torch.rand(4, 10, device="meta")} + ) + graph_branch_false = tracer.trace( + model, meta_args={"x": torch.rand(10, device="meta"), "y": torch.rand(4, 10, device="meta")} + ) gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__) gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__) @@ -56,5 +51,5 @@ def test_control_flow(): assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y)) -if __name__ == '__main__': +if __name__ == "__main__": test_control_flow() diff --git a/tests/test_fx/test_tracer/test_functional_conv.py b/tests/test_fx/test_tracer/test_functional_conv.py index a552e905223d5ce2b1c9e7208242edb565b7e882..63f9721e2a65d71af2241106a550cea46e9d6342 100644 --- a/tests/test_fx/test_tracer/test_functional_conv.py +++ b/tests/test_fx/test_tracer/test_functional_conv.py @@ -47,5 +47,5 @@ def test_conv(): assert out_transpose_3d.shape == patched_out_transpose_3d.shape -if __name__ == '__main__': +if __name__ == "__main__": test_conv() diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index 7a4bf131ae36875aa9ce69ba00a08e75741bf1d2..4828bb0302c892e11a88451f2d42a4ddc293697a 100644 --- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -1,25 +1,29 @@ +from typing import List + import torch -from numpy import isin -from torch.fx import GraphModule -from torch.utils._pytree import tree_flatten # from colossalai.fx import symbolic_trace from colossalai._analyzer.fx import symbolic_trace -def trace_model_and_compare_output(model, data_gen): +def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = None): # must turn on eval mode to ensure the output is consistent model.eval() + inputs = data_gen() + + if ignore_data is not None: + # drop the ignore_data key + inputs = {k: v for k, v in inputs.items() if k not in ignore_data} + try: - kwargs = data_gen() - meta_args = {k: v.to('meta') for k, v in kwargs.items()} + meta_args = {k: v.to("meta") for k, v in inputs.items()} gm = symbolic_trace(model, meta_args=meta_args) + except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") # run forward - inputs = data_gen() non_fx_out = model(**inputs) fx_out = gm(**inputs) @@ -28,4 +32,4 @@ def trace_model_and_compare_output(model, data_gen): if torch.is_tensor(fx_out[k]): assert torch.equal( fx_out[k], non_fx_out[k] - ), f'{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}' + ), f"{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}" diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index f4d681221191ca2084f184ec0bacea0749356c74..fb093821e48808ec91c0faabf9d7a3ea1615133c 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -10,15 +10,15 @@ BATCH_SIZE = 2 SEQ_LENGTH = 16 -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_albert(): - sub_registry = model_zoo.get_sub_registry('transformers_albert') + sub_registry = model_zoo.get_sub_registry("transformers_albert") - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() trace_model_and_compare_output(model, data_gen_fn) -if __name__ == '__main__': +if __name__ == "__main__": test_albert() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index a833bb30c056b772dce2b4cbb158be37d254ed0e..91f7b9764e6eddcfe8a534e0a41c848d1942a21f 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -7,15 +7,17 @@ from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_bert(): - sub_registry = model_zoo.get_sub_registry('transformers_bert') + sub_registry = model_zoo.get_sub_registry("transformers_bert") - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + if model.__class__.__name__ == "BertForQuestionAnswering": + continue + trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels", "next_sentence_label"]) -if __name__ == '__main__': +if __name__ == "__main__": test_bert() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index 0cbea82e083a80159bd4603b6995e98ca86dc20a..95a464fa0534208973be9206c5b201aa9f6b17e5 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -22,7 +22,7 @@ def trace_and_compare(model_cls, data, output_fn): model.eval() concrete_args = {k: v for k, v in data.items() if not torch.is_tensor(v)} - meta_args = {k: v.to('meta') for k, v in data.items() if torch.is_tensor(v)} + meta_args = {k: v.to("meta") for k, v in data.items() if torch.is_tensor(v)} gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args) # run forward @@ -40,14 +40,14 @@ def trace_and_compare(model_cls, data, output_fn): assert_dict(transformed_fx_out, transformed_non_fx_out, assert_fn) -@pytest.mark.skip(reason='cannot pass this test yet') +@pytest.mark.skip(reason="cannot pass this test yet") @clear_cache_before_run() def test_diffusers(): seed_all(9091, cuda_deterministic=True) - sub_model_zoo = model_zoo.get_sub_registry('diffusers') + sub_model_zoo = model_zoo.get_sub_registry("diffusers") - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() trace_and_compare(model_fn, data, output_transform_fn) torch.cuda.synchronize() @@ -58,12 +58,12 @@ def test_diffusers(): def test_torch_diffusers(): seed_all(65535, cuda_deterministic=True) - sub_model_zoo = model_zoo.get_sub_registry('diffusers') + sub_model_zoo = model_zoo.get_sub_registry("diffusers") - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() model = model_fn() - output = model(**data) + model(**data) torch.cuda.synchronize() print(f"{name:40s} √") diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index 67107469d8bb898edfb487cc63bdac0cbfe9b024..7bd8a726f1acfee22b2ff5eec09beb274074cb6e 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -7,22 +7,22 @@ from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_gpt(): - sub_registry = model_zoo.get_sub_registry('transformers_gpt') + sub_registry = model_zoo.get_sub_registry("transformers_gpt") - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - # TODO: support the following models + # TODO(ver217): support the following models # 1. GPT2DoubleHeadsModel # as they are not supported, let's skip them - if model.__class__.__name__ in ['GPT2DoubleHeadsModel']: + if model.__class__.__name__ in ["GPT2DoubleHeadsModel", "GPT2ForQuestionAnswering"]: continue - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels"]) -if __name__ == '__main__': +if __name__ == "__main__": test_gpt() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index 369545b03de1b4bf0746d46b9928ed68be9688d4..5f7525d5707bdcd9097bd6a6f52166c2295acb4c 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -7,15 +7,14 @@ from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_opt(): - sub_registry = model_zoo.get_sub_registry('transformers_opt') - - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + sub_registry = model_zoo.get_sub_registry("transformers_opt") + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels", "start_positions", "end_positions"]) -if __name__ == '__main__': +if __name__ == "__main__": test_opt() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 811cf3b214303ae62569158d1321b1a99ff6d76e..6ccbb14e3d96acb540cc46a7fcbab399f9d6b1e1 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -7,15 +7,20 @@ from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_t5(): - sub_registry = model_zoo.get_sub_registry('transformers_t5') + sub_registry = model_zoo.get_sub_registry("transformers_t5") + + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): + if name == "transformers_t5_for_conditional_generation": + # cannot trace for loss function yet + # so we use a data gen which does not produce labels + data_gen_fn = sub_registry.get("transformers_t5")[1] - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels"]) -if __name__ == '__main__': +if __name__ == "__main__": test_t5() diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py index ef778e21801a84bd4e55ca08937db67bffaca5c1..fe66cbd0ffcc6495490fb8528ddac6f4dffd4eda 100644 --- a/tests/test_fx/test_tracer/test_patched_module.py +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -36,12 +36,12 @@ def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape) @clear_cache_before_run() def test_linear(): # test linear patch can produce the meta output with correct shape - data = torch.rand(2, 4, device='meta') + data = torch.rand(2, 4, device="meta") module = torch.nn.Linear(4, 2) _assert_output_shape(data, module, patched_module.torch_nn_linear, False, torch.Size([2, 2])) # test if the linear patch can catch exception when dimension does not match - data = torch.rand(2, 2, device='meta') + data = torch.rand(2, 2, device="meta") _assert_output_shape(data, module, patched_module.torch_nn_linear, True, None) @@ -51,20 +51,20 @@ def test_rnn(): data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) module = torch.nn.RNN(10, 20, 2) output, hn = module(*data) - meta_data = (torch.randn(5, 3, 10).to('meta'), torch.randn(2, 3, 20).to('meta')) + meta_data = (torch.randn(5, 3, 10).to("meta"), torch.randn(2, 3, 20).to("meta")) _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, False, (output.shape, hn.shape)) # test if the rnn patch can catch exception when dimension does not match data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) module = torch.nn.RNN(10, 20, 2) output, hn = module(*data) - meta_data = (torch.randn(5, 3, 1).to('meta'), torch.randn(2, 3, 20).to('meta')) + meta_data = (torch.randn(5, 3, 1).to("meta"), torch.randn(2, 3, 20).to("meta")) _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, True, None) @clear_cache_before_run() def test_embedding(): - data = torch.rand(2, 4, device='meta') + data = torch.rand(2, 4, device="meta") # test layernorm ln = torch.nn.LayerNorm(4) @@ -76,67 +76,71 @@ def test_embedding(): # test batch norm 1d bn1d = torch.nn.BatchNorm1d(4) - data = torch.rand(2, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) - - data = torch.rand(2, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) - - data = torch.rand(2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) - - data = torch.rand(1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=True, - output_shape=None) + data = torch.rand(2, 4, device="meta") + _assert_output_shape( + data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) + + data = torch.rand(2, 4, device="meta") + _assert_output_shape( + data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) + + data = torch.rand(2, 3, 4, device="meta") + _assert_output_shape( + data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) + + data = torch.rand(1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, module=bn1d, patch_fn=patched_module.torch_nn_normalize, expect_exception=True, output_shape=None + ) # test batch norm 2d bn2d = torch.nn.BatchNorm2d(4) - data = torch.rand(1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn2d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) + data = torch.rand(1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, + module=bn2d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) - data = torch.rand(2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn2d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=True, - output_shape=None) + data = torch.rand(2, 3, 4, device="meta") + _assert_output_shape( + data=data, module=bn2d, patch_fn=patched_module.torch_nn_normalize, expect_exception=True, output_shape=None + ) # # test batch size 3d bn3d = torch.nn.BatchNorm3d(4) - data = torch.rand(1, 1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn3d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) + data = torch.rand(1, 1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, + module=bn3d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) - data = torch.rand(1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn3d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=True, - output_shape=None) + data = torch.rand(1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, module=bn3d, patch_fn=patched_module.torch_nn_normalize, expect_exception=True, output_shape=None + ) @clear_cache_before_run() @@ -146,35 +150,38 @@ def test_conv1d(): conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = conv1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=conv1d, - patch_fn=patched_module.torch_nn_conv1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = conv1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=conv1d, - patch_fn=patched_module.torch_nn_conv1d, - expect_exception=False, - output_shape=materialized_output.shape) - - conv1d = torch.nn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=2, - padding=1, - dilation=2, - padding_mode='reflect') + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) + + conv1d = torch.nn.Conv1d( + in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect" + ) materialized_output = conv1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=conv1d, - patch_fn=patched_module.torch_nn_conv1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) def test_conv2d(): @@ -182,40 +189,45 @@ def test_conv2d(): data = torch.rand(2, 3, 4, 4) conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) - - conv2d = torch.nn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=2, - padding=1, - dilation=2, - padding_mode='reflect') + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) + + conv2d = torch.nn.Conv2d( + in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect" + ) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -224,40 +236,45 @@ def test_conv3d(): data = torch.rand(2, 3, 4, 4, 4) conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) - - conv3d = torch.nn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=2, - padding=1, - dilation=2, - padding_mode='reflect') + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) + + conv3d = torch.nn.Conv3d( + in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect" + ) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -267,21 +284,25 @@ def test_conv_transpose1d(): convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = convtrans1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans1d, - patch_fn=patched_module.torch_nn_convtranspose1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = convtrans1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans1d, - patch_fn=patched_module.torch_nn_convtranspose1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -291,21 +312,25 @@ def test_conv_transpose2d(): convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = convtrans2d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans2d, - patch_fn=patched_module.torch_nn_convtranspose2d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = convtrans2d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans2d, - patch_fn=patched_module.torch_nn_convtranspose2d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -315,46 +340,56 @@ def test_conv_transpose3d(): convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = convtrans3d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans3d, - patch_fn=patched_module.torch_nn_convtranspose3d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = convtrans3d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans3d, - patch_fn=patched_module.torch_nn_convtranspose3d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() def test_pool1d(): - combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d], - [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]] + combinations = [ + [torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d], + [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d], + ] - for (layer_cls, patch_func) in combinations: + for layer_cls, patch_func in combinations: pooler = layer_cls(kernel_size=3) data = torch.rand(2, 3, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) data = torch.rand(2, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) data = torch.rand(2, 3, 4, 4) _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) @@ -362,29 +397,35 @@ def test_pool1d(): @clear_cache_before_run() def test_pool2d(): - combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d], - [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]] + combinations = [ + [torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d], + [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d], + ] - for (layer_cls, patch_func) in combinations: + for layer_cls, patch_func in combinations: pooler = layer_cls(kernel_size=3) # test max pool 3d data = torch.rand(2, 3, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 3, 4, 4, 4) @@ -393,29 +434,35 @@ def test_pool2d(): @clear_cache_before_run() def test_pool3d(): - combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d], - [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]] + combinations = [ + [torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d], + [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d], + ] - for (layer_cls, patch_func) in combinations: + for layer_cls, patch_func in combinations: pooler = layer_cls(kernel_size=3) # test max pool 3d data = torch.rand(2, 3, 4, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 4, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 3, 4) @@ -430,19 +477,15 @@ def test_adaptive_pooling_1d(): data = torch.rand(3, 4) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4, 5) _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) @@ -458,19 +501,15 @@ def test_adaptive_pooling_2d(): data = torch.rand(2, 3, 4) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4, 5) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) @clear_cache_before_run() @@ -483,16 +522,12 @@ def test_adaptive_pooling_3d(): data = torch.rand(2, 3, 4, 5) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4, 5, 6) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) diff --git a/tests/test_fx/test_tracer/test_patched_op.py b/tests/test_fx/test_tracer/test_patched_op.py index e0c5f560c49e0d4d0807fb0b42e9ede960704b7d..37c2333c09825376c4d516ec2e6c8ab6d22c6e81 100644 --- a/tests/test_fx/test_tracer/test_patched_op.py +++ b/tests/test_fx/test_tracer/test_patched_op.py @@ -33,38 +33,34 @@ def test_repeat_interleave(): data = torch.tensor([1, 2, 3]) materialized_output = torch.repeat_interleave(data, repeats=2) repeat_interleave = partial(patch_fn, repeats=2) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape + ) data = torch.tensor([[1, 2], [3, 4]]) materialized_output = torch.repeat_interleave(data, repeats=3, dim=1) repeat_interleave = partial(patch_fn, repeats=3, dim=1) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape + ) data = torch.tensor([[1, 2], [3, 4]]) materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=-1) repeat_interleave = partial(patch_fn, repeats=torch.tensor([1, 2]), dim=-1) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape + ) data = torch.tensor([[1, 2], [3, 4]]) materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=0) repeat_interleave = partial(patch_fn, repeats=[1, 2], dim=0) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=True, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=True, output_shape=materialized_output.shape + ) @clear_cache_before_run() diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index aa14f514c7d6d86bb8c384200777d685fa0df93a..2b3f3e039baf333ea5a2c520879f26cc54d96ecb 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -20,7 +20,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): # 1. ConViT # 2. NormFreeNet # as they are not supported, let's skip them - if model.__class__.__name__ in ['ConViT', 'NormFreeNet']: + if model.__class__.__name__ in ["ConViT", "NormFreeNet"]: return gm = symbolic_trace(model, meta_args=meta_args) @@ -39,26 +39,33 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): for key in transformed_fx_out.keys(): fx_output_val = transformed_fx_out[key] non_fx_output_val = transformed_non_fx_out[key] - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' - - -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" + + +# FIXME(ver217): timm/models/convit.py:71: in forward +# if self.rel_indices is None or self.rel_indices.shape[1] != N: +# torch/fx/proxy.py:284: in __bool__ +# return self.tracer.to_bool(self) +# torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow +@pytest.mark.skip("convit is not supported yet") +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_timm_models(): torch.backends.cudnn.deterministic = True - sub_model_zoo = model_zoo.get_sub_registry('timm') + sub_model_zoo = model_zoo.get_sub_registry("timm") - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: - meta_args = {k: v.to('meta') for k, v in data.items()} + meta_args = {k: v.to("meta") for k, v in data.items()} else: meta_args = None trace_and_compare(model_fn, data, output_transform_fn, meta_args) -if __name__ == '__main__': +if __name__ == "__main__": test_timm_models() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py index eafcaca10b1d72c37750a8b2f87a9c0ec4487a7f..dd94a254695516640bf1a938640c563b4320604b 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py @@ -1,6 +1,5 @@ import pytest import torch -from packaging import version from torchaudio_utils import trace_and_compare from colossalai.testing import clear_cache_before_run @@ -14,11 +13,10 @@ from tests.kit.model_zoo import model_zoo def test_torchaudio_models(): torch.backends.cudnn.deterministic = True - sub_model_zoo = model_zoo.get_sub_registry('torchaudio') + sub_model_zoo = model_zoo.get_sub_registry("torchaudio") - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): model = model_fn() - trace_and_compare(model, - data_gen_fn, - output_transform_fn, - need_meta=(attribute is not None and attribute.has_control_flow)) + trace_and_compare( + model, data_gen_fn, output_transform_fn, need_meta=(attribute is not None and attribute.has_control_flow) + ) diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py index 239f38680cec5f0659985416260992935645cb10..2379372bc3f9d8f5d8e498881337fee3b3a19d55 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py @@ -6,7 +6,7 @@ from colossalai._analyzer.fx import symbolic_trace def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False): data = data_gen() concrete_args = data if need_concrete else {} - meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {} + meta_args = {k: v.to("meta") for k, v in data.items()} if need_meta else {} model.eval() @@ -24,5 +24,6 @@ def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, nee for key, fx_output_val in transformed_fx_out.items(): non_fx_output_val = transformed_non_fx_out[key] - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index df02568c00496a0b76c686ac7e3a551960914584..30c1910855e60b9cfd2f82b7ae95b6d1c5f51564 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -1,4 +1,3 @@ -import pytest import torch from colossalai._analyzer.fx import symbolic_trace @@ -32,31 +31,34 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): assert len(transformed_fx_out) == len(transformed_non_fx_out) if torch.is_tensor(fx_out): assert torch.allclose( - fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out, non_fx_out + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" else: assert torch.allclose( - fx_out.values(), - non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out.values(), non_fx_out.values() + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" for key in transformed_fx_out.keys(): fx_output_val = transformed_fx_out[key] non_fx_output_val = transformed_non_fx_out[key] if torch.is_tensor(fx_output_val): - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" else: - assert torch.allclose(fx_output_val.values(), non_fx_output_val.values() - ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + assert torch.allclose( + fx_output_val.values(), non_fx_output_val.values() + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" @clear_cache_before_run() def test_torchrec_deepfm_models(): - deepfm_models = model_zoo.get_sub_registry('deepfm') + deepfm_models = model_zoo.get_sub_registry("deepfm") torch.backends.cudnn.deterministic = True - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in deepfm_models.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: - meta_args = {k: v.to('meta') for k, v in data.items()} + meta_args = {k: v.to("meta") for k, v in data.items()} else: meta_args = None diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 9776452be9c8ea220a3fc2651f5aa4a5a3f8b2ad..71b73236474f38c10409a55d28a4d9911bd88425 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -1,4 +1,3 @@ -import pytest import torch from colossalai._analyzer.fx import symbolic_trace @@ -32,37 +31,40 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): assert len(transformed_fx_out) == len(transformed_non_fx_out) if torch.is_tensor(fx_out): assert torch.allclose( - fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out, non_fx_out + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" else: assert torch.allclose( - fx_out.values(), - non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out.values(), non_fx_out.values() + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" for key in transformed_fx_out.keys(): fx_output_val = transformed_fx_out[key] non_fx_output_val = transformed_non_fx_out[key] if torch.is_tensor(fx_output_val): - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" else: - assert torch.allclose(fx_output_val.values(), non_fx_output_val.values() - ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + assert torch.allclose( + fx_output_val.values(), non_fx_output_val.values() + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" @clear_cache_before_run() def test_torchrec_dlrm_models(): torch.backends.cudnn.deterministic = True - dlrm_models = model_zoo.get_sub_registry('dlrm') + dlrm_models = model_zoo.get_sub_registry("dlrm") - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in dlrm_models.items(): data = data_gen_fn() # dlrm_interactionarch is not supported # TODO(FrankLeeeee): support this model - if name == 'dlrm_interactionarch': + if name == "dlrm_interactionarch": continue if attribute is not None and attribute.has_control_flow: - meta_args = {k: v.to('meta') for k, v in data.items()} + meta_args = {k: v.to("meta") for k, v in data.items()} else: meta_args = None diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index bd259475ae5a51311932f6bb61e89d1c55f538cf..47c6b1186c8e98705760c3c671cd3c2d0760fd97 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -8,9 +8,9 @@ from tests.kit.model_zoo import model_zoo @clear_cache_before_run() def test_torchvision_models(): torch.backends.cudnn.deterministic = True - tv_sub_registry = model_zoo.get_sub_registry('torchvision') + tv_sub_registry = model_zoo.get_sub_registry("torchvision") - for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, model_attribute) in tv_sub_registry.items(): data = data_gen_fn() if model_attribute is not None and model_attribute.has_stochastic_depth_prob: @@ -36,11 +36,11 @@ def test_torchvision_models(): fx_val = transformed_out[key] non_fx_val = transformed_non_fx_out[key] assert torch.allclose( - fx_val, - non_fx_val), f'{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}' + fx_val, non_fx_val + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}" except Exception as e: print(name, e) -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models() diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..9b650aa781123890effeef7dcb72396689abcbf6 --- /dev/null +++ b/tests/test_gptq/test_gptq_linear.py @@ -0,0 +1,150 @@ +import math +import time + +import numpy as np +import pytest +import torch +import torch.nn as nn +import transformers +from packaging import version + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +try: + from auto_gptq.modeling._utils import autogptq_post_init + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + from exllama_kernels import prepare_buffers, set_tuning_params + + from colossalai.inference.quant.gptq import CaiQuantLinear + HAS_AUTO_GPTQ = True +except: + HAS_AUTO_GPTQ = False + print("please install AutoGPTQ from https://github.com/PanQiWei/AutoGPTQ") + +import warnings + +HAS_GPTQ_CUDA = False +try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True +except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +max_inner_outer_dim = 1 +max_input_len = 1 +max_dq_buffer_size = 1 +gptq_temp_dq_buffer = None +gptq_temp_state_buffer = None + + +def init_buffer(cai_linear, use_act_order=False): + global max_dq_buffer_size + global max_input_len + global max_dq_buffer_size + global max_inner_outer_dim + global gptq_temp_dq_buffer + global gptq_temp_state_buffer + + max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8) + + if use_act_order: + max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures) + + if use_act_order: + max_input_len = 4096 + # The temp_state buffer is required to reorder X in the act-order case. + # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. + gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim), + dtype=torch.float16, + device=torch.cuda.current_device()) + gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()) + + gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer) + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, + reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq") +def test_gptq_linear(): + + infeature = 1024 + outfeature = 1024 + group_size = 128 + wbits = 4 + + inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) + batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(torch.cuda.current_device()) + + device = torch.device("cuda:0") + + linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits) + + linear = linear_class( + bits=4, + group_size=group_size, + infeatures=infeature, + outfeatures=outfeature, + bias=False, + ) + + torch.manual_seed(42) + + linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32) + linear.scales = linear.scales + 0.002 + + linear = linear.to(device) + + cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True) + cai_linear.qweight.data.copy_(linear.qweight) + cai_linear.scales = cai_linear.scales + 0.002 + cai_linear = cai_linear.to(device) + + linear = autogptq_post_init(linear, use_act_order=False) + + max_inner_outer_dim = max(infeature, outfeature) + max_dq_buffer_size = linear.infeatures * linear.outfeatures + max_input_len = 2048 + buffers = { + "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device), + "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device) + } + + prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"]) + + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + with torch.no_grad(): + gptq_out = linear(inps) + batch_gptq_out = linear(batch_inps) + torch.cuda.synchronize() + cai_out = cai_linear(inps) + torch.cuda.synchronize() + + batch_cai_out = cai_linear(batch_inps) + torch.cuda.synchronize() + + assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01) + assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01) + + +if __name__ == "__main__": + + test_gptq_linear() diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2ddc8b6e68e437591b5500d3879fb7345e06d3c2 --- /dev/null +++ b/tests/test_infer/_utils.py @@ -0,0 +1,41 @@ +import copy + +from colossalai.shardformer import ShardConfig, ShardFormer + + +def build_model( + model_fn, + enable_fused_normalization=False, + enable_tensor_parallelism=False, + enable_flash_attention=False, + enable_jit_fused=False, +): + # create new model + org_model = model_fn() + + # shard model + shard_config = ShardConfig( + enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused, + inference_only=True, + ) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + sharded_model, shared_params = shard_former.optimize(model_copy) + return org_model.cuda(), sharded_model.cuda() + + +def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn): + # prepare input + data = data_gen_fn() + data = {k: v.cuda() for k, v in data.items()} + # run forward + org_output = original_model(**data) + org_output = output_transform_fn(org_output) + + shard_output = sharded_model(**data) + shard_output = output_transform_fn(shard_output) + + return org_output, shard_output diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..ba978ad9bf0d0b06ed012cc97f0f0e00a7ab0d29 --- /dev/null +++ b/tests/test_infer/test_bloom_infer.py @@ -0,0 +1,64 @@ +import pytest +import torch +from packaging import version +from transformers import BloomForCausalLM +from transformers.models.bloom.configuration_bloom import BloomConfig + +import colossalai +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +TP_SIZE = 2 +MAX_BATCH_SIZE = 4 +MAX_INPUT_LEN = 16 +MAX_OUTPUT_LEN = 32 + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + + +@parameterize( + "test_config", + [ + { + "tp_size": TP_SIZE, + } + ], +) +def run(test_config): + bloom_config = BloomConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = BloomForCausalLM(bloom_config) + model = model.half() + + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) + infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + "attention_mask": torch.ones((MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + } + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + + assert outputs is not None + + +def check_bloom(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom_infer(): + spawn(check_bloom, TP_SIZE) + + +if __name__ == "__main__": + test_bloom_infer() diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..399b70e1460e47ad75065b543b88dd1469140873 --- /dev/null +++ b/tests/test_infer/test_chatglm2_infer.py @@ -0,0 +1,74 @@ +import os + +import pytest +import torch +from packaging import version + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" +TPSIZE = 1 +BATCH_SIZE = 8 +MAX_INPUT_LEN = 12 +MAX_OUTPUT_LEN = 100 +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + + +@parameterize( + "test_config", + [ + { + "tp_size": TPSIZE, + } + ], +) +def run_chatglm2_test(test_config): + chatglm_config = ChatGLMConfig( + num_layers=2, + vocab_size=1200, + use_cache=True, + multi_query_attention=True, + multi_query_group_num=2, + num_attention_heads=8, + hidden_size=1024, + ) + model = ChatGLMForConditionalGeneration(chatglm_config) + model = model.half() + + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + } + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + + assert outputs is not None + + +def check_chatglm2(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_chatglm2_test() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm2(): + spawn(check_chatglm2, TPSIZE) + + +if __name__ == "__main__": + test_chatglm2() diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..f24160820e71f1bbcd954c49556f927a87fba655 --- /dev/null +++ b/tests/test_infer/test_infer_engine.py @@ -0,0 +1,102 @@ +from itertools import accumulate + +import pytest +import torch +from packaging import version +from transformers import BloomConfig, BloomForCausalLM +from transformers.tokenization_utils_base import BatchEncoding + +import colossalai +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +TP_SIZE = 2 +MAX_BATCH_SIZE = 4 +MAX_INPUT_LEN = 16 +MAX_OUTPUT_LEN = 8 + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + + +@parameterize( + "test_config", + [ + { + "tp_size": TP_SIZE, + } + ], +) +def run(test_config): + model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4) + model = BloomForCausalLM(model_config) + model = model.half() + model.to(torch.cuda.current_device()) + + # 1. check TPInferEngine init and model optimization + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) + infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + assert infer_engine.cache_manager is not None + assert infer_engine.tp_size == TP_SIZE + assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE + + # 2. check data preparation + input_ids_list = [ + [80540, 15473, 3331, 11970, 90472, 361, 61335], + [80540, 15473, 3331, 11970], + [80540, 15473, 3331, 11970], + [80540, 15473], + ] + batch_size = len(input_ids_list) + max_seq_len = max(len(li) for li in input_ids_list) + attention_mask = [[0] * max_seq_len for _ in range(batch_size)] + for i, li in enumerate(input_ids_list): + attention_mask[i][max_seq_len - len(li) :] = [1 for _ in range(len(li))] + data = dict(input_ids=input_ids_list, attention_mask=attention_mask) + inputs_batch_encoding = BatchEncoding(data=data) + seq_lengths = [len(li) for li in input_ids_list] + start_loc = list(accumulate([0] + seq_lengths[:-1])) + seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32) + start_loc = torch.tensor(start_loc, dtype=torch.int32) + # input token id list as inputs + batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding) + # BatchEncoding as inputs + batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list) + + assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size + assert torch.equal(batch_state_out1.seq_len, batch_state_out2.seq_len) + + # The following tests are discarded for now, and will be reused after all features are added + # assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths) + # assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths) + # assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc) + # assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc) + + # 3. check optimized model generate + input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN)) + generate_kwargs = dict(do_sample=False) + infer_engine.generate(input_ids, **generate_kwargs) + + torch.cuda.empty_cache() + + +def check_engine(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_engine(): + spawn(check_engine, TP_SIZE) + + +if __name__ == "__main__": + test_engine() diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..f3e2cdf1e18fa55e86e8b77b486fed8858c9b91f --- /dev/null +++ b/tests/test_infer/test_kvcache_manager.py @@ -0,0 +1,66 @@ +import os + +import pytest +import torch +from packaging import version + +from colossalai.inference.tensor_parallel import MemoryManager +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn + +BATCH_SIZE = 4 +INPUT_LEN = 16 +OUTPUT_LEN = 8 +LAYER_NUM = 4 +HEAD_NUM = 32 +HEAD_DIM = 128 + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + + +def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim): + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + disable_existing_loggers() + + size = batch_size * (input_len + output_len) + kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank) + key_buffers = kvcache_manager.key_buffer + value_buffers = kvcache_manager.value_buffer + assert len(key_buffers) == len(value_buffers) == layer_num + assert key_buffers[0].shape == value_buffers[0].shape + # required size exceeds the maximum allocated size + invalid_locs = kvcache_manager.alloc_contiguous(size + 1) + assert invalid_locs is None + # for prefill stage, allocation via alloc and alloc_contiguous should be the same + total_token_prefill = batch_size * input_len + prefill_locs = kvcache_manager.alloc(total_token_prefill) + kvcache_manager.free_all() + prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0] + assert torch.equal(prefill_locs, prefill_locs_contiguous) + assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill + kvcache_manager.alloc_contiguous(batch_size) + assert torch.all(kvcache_manager.mem_state[: total_token_prefill + batch_size] == False) + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_cache_manager_dist(): + spawn( + create_cache_manager, + 4, + batch_size=BATCH_SIZE, + input_len=INPUT_LEN, + output_len=OUTPUT_LEN, + layer_num=LAYER_NUM, + head_num=HEAD_NUM, + head_dim=HEAD_DIM, + ) + + +if __name__ == "__main__": + test_cache_manager_dist() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..13bdf03996b97ebfe1370024875dd16a3ffa2133 --- /dev/null +++ b/tests/test_infer/test_llama_infer.py @@ -0,0 +1,68 @@ +import os + +import pytest +import torch +from packaging import version +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" +TPSIZE = 2 +BATCH_SIZE = 8 +MAX_INPUT_LEN = 12 +MAX_OUTPUT_LEN = 100 + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + + +@parameterize( + "test_config", + [ + { + "tp_size": TPSIZE, + } + ], +) +def run_llama_test(test_config): + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + init_to_get_rotary(model.model, base=10000) + generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + } + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + + assert outputs is not None + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_llama_test() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, TPSIZE) + + +if __name__ == "__main__": + test_llama() diff --git a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d893f8e830548f85e5f562d7fcbd15859d4a8c --- /dev/null +++ b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import pytest +import torch +from torch import nn + +try: + from vllm import layernorm_ops + + rms_norm = layernorm_ops.rms_norm + HAS_VLLM_KERNERL = True +except: + print("please install vllm kernels to install rmsnorm") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + HAS_VLLM_KERNERL = False + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + weight, + variance_epsilon, + ) + return out + + +@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") +def test_rmsnorm(): + data = torch.randn((1024, 64), dtype=torch.float16, device="cuda") + hg_rms = LlamaRMSNorm(64) + hg_rms = hg_rms.half().cuda() + out_torch = hg_rms(data) + out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon) + + check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5) + assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward" + + +if __name__ == "__main__": + test_rmsnorm() diff --git a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..40451ef6636dac18c72fafb1de6c36acd0252a0e --- /dev/null +++ b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +from typing import Tuple + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half + +try: + from vllm import pos_encoding_ops + + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True +except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + HAS_VLLM_KERNERL = False + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class RefRotaryEmbeddingNeox(nn.Module): + """Reference implementation of the GPT-NeoX style rotary embedding.""" + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + ) -> None: + super().__init__() + self.rotary_dim = dim + self.max_position_embeddings = max_position_embeddings + + # Create cos and sin embeddings. + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) + t = torch.arange(max_position_embeddings).float() + freqs = torch.einsum("i,j->ij", t, inv_freq.float()) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype=inv_freq.dtype) + sin = emb.sin().to(dtype=inv_freq.dtype) + self.register_buffer("cos_cached", cos, persistent=False) + self.register_buffer("sin_cached", sin, persistent=False) + + def forward( + self, + positions: torch.Tensor, # [num_tokens] + query: torch.Tensor, # [num_tokens, num_heads, head_size] + key: torch.Tensor, # [num_tokens, num_heads, head_size] + ) -> Tuple[torch.Tensor, torch.Tensor]: + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + + query_rot = query_rot.transpose(0, 1) + key_rot = key_rot.transpose(0, 1) + cos = F.embedding(positions, self.cos_cached) + sin = F.embedding(positions, self.sin_cached) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_rot = query_rot.transpose(0, 1).contiguous() + key_rot = key_rot.transpose(0, 1).contiguous() + + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + + # Output query/key shape: [num_tokens, num_tokens, head_size] + return query, key + + +def run_rotary_embedding_neox( + num_tokens: int, + num_heads: int, + head_size: int, + max_position: int, + rotary_dim: int, + dtype: torch.dtype, + base: int = 10000, +) -> None: + positions = torch.randint(0, max_position, (num_tokens,), device="cuda") + query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") + key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") + + # Create the rotary embedding. + inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim)) + t = torch.arange(max_position).float() + freqs = torch.einsum("i,j -> ij", t, inv_freq.float()) + cos = freqs.cos() + sin = freqs.sin() + cos_sin_cache = torch.cat((cos, sin), dim=-1) + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") + + # Run the kernel. The kernel is in-place, so we need to clone the inputs. + out_query = query.clone() + out_key = key.clone() + rotary_embedding_neox( + positions, + out_query, + out_key, + head_size, + cos_sin_cache, + ) + + # Run the reference implementation. + ref_rotary_embedding = RefRotaryEmbeddingNeox( + dim=rotary_dim, + max_position_embeddings=max_position, + base=base, + ).to(dtype=dtype, device="cuda") + ref_query, ref_key = ref_rotary_embedding( + positions, + query.view(num_tokens, num_heads, head_size), + key.view(num_tokens, num_heads, head_size), + ) + ref_query = ref_query.view(num_tokens, num_heads * head_size) + ref_key = ref_key.view(num_tokens, num_heads * head_size) + + # Compare the results. + assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) + assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) + + +@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") +def test_rotary_embedding(): + run_rotary_embedding_neox( + num_tokens=1024, + num_heads=8, + head_size=64, + max_position=8192, + rotary_dim=64, + dtype=torch.float16, + ) + + +if __name__ == "__main__": + test_rotary_embedding() diff --git a/tests/test_infer_ops/triton/kernel_utils.py b/tests/test_infer_ops/triton/kernel_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0732ace1e04b7cbca71cebe678a184238fc5db0b --- /dev/null +++ b/tests/test_infer_ops/triton/kernel_utils.py @@ -0,0 +1,27 @@ +import math + +import torch +from torch.nn import functional as F + + +def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): + """ + adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + """ + xq = xq.view(bs, seqlen, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() + mask[mask == 0.0] = -100000000.0 + mask = mask.repeat(bs, num_head, 1, 1) + keys = xk + values = xv + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + sm_scale = 1 / math.sqrt(head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale + scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16) + + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + return output diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6c218a66910725d2fb77998473b6754379f5ce --- /dev/null +++ b/tests/test_infer_ops/triton/test_bloom_context_attention.py @@ -0,0 +1,52 @@ +import pytest +import torch +from packaging import version + +try: + pass + + from colossalai.kernel.triton import bloom_context_attn_fwd + from tests.test_infer_ops.triton.kernel_utils import torch_context_attention + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_bloom_context_attention(): + bs = 4 + head_num = 8 + seq_len = 1024 + head_dim = 64 + + query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + + max_input_len = seq_len + b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32) + b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32) + + for i in range(bs): + b_start[i] = i * seq_len + b_len[i] = seq_len + + o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda") + bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi) + + torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) + + assert torch.allclose( + torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2 + ), "outputs from triton and torch are not matched" + + +if __name__ == "__main__": + test_bloom_context_attention() diff --git a/tests/test_infer_ops/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py new file mode 100644 index 0000000000000000000000000000000000000000..34e453f7840e0d2e96d6bab535a97613932894fa --- /dev/null +++ b/tests/test_infer_ops/triton/test_copy_kv_dest.py @@ -0,0 +1,39 @@ +import pytest +import torch +from packaging import version + +try: + pass + + from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_kv_cache_copy_op(): + B_NTX = 32 * 2048 + head_num = 8 + head_dim = 64 + + cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16) + dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32) + + dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16) + + copy_kv_cache_to_dest(cache, dest_index, dest_data) + + assert torch.allclose( + cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3 + ), "copy_kv_cache_to_dest outputs from triton and torch are not matched" + + +if __name__ == "__main__": + test_kv_cache_copy_op() diff --git a/tests/test_infer_ops/triton/test_layernorm_triton.py b/tests/test_infer_ops/triton/test_layernorm_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..7f814e8c9a9fd6c521e67954b3df8d2163403741 --- /dev/null +++ b/tests/test_infer_ops/triton/test_layernorm_triton.py @@ -0,0 +1,43 @@ +import pytest +import torch +from packaging import version + +from colossalai.kernel.triton import layer_norm +from colossalai.testing.utils import parameterize + +try: + pass + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +@parameterize("M", [2, 4, 8, 16]) +@parameterize("N", [64, 128]) +def test_layer_norm(M, N): + dtype = torch.float16 + eps = 1e-5 + x_shape = (M, N) + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device="cuda") + bias = torch.rand(w_shape, dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + + y_triton = layer_norm(x, weight, bias, eps) + y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + + assert y_triton.shape == y_torch.shape + assert y_triton.dtype == y_torch.dtype + print("max delta: ", torch.max(torch.abs(y_triton - y_torch))) + assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0) + + +if __name__ == "__main__": + test_layer_norm() diff --git a/tests/test_infer_ops/triton/test_llama2_token_attn.py b/tests/test_infer_ops/triton/test_llama2_token_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..0537a3d76129114ac6fcac279a70a086aa0a72e1 --- /dev/null +++ b/tests/test_infer_ops/triton/test_llama2_token_attn.py @@ -0,0 +1,63 @@ +import pytest +import torch +from packaging import version + +try: + pass + + from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): + xq = xq.view(bs, 1, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + + logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5) + prob = torch.softmax(logics, dim=1) + prob = prob.view(bs, seqlen, num_head, 1) + + return torch.sum(prob * xv, dim=1, keepdim=False) + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test(): + Z, head_num, seq_len, head_dim = 2, 32, 2048, 128 + dtype = torch.float16 + + # attn out: 2,4096 + q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) + o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda") + max_kv_cache_len = seq_len + kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") + kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda") + kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") + other_kv_index = 2048 + + kv_cache_seq_len[:] = seq_len + kv_cache_start_loc[0] = 0 + kv_cache_start_loc[1] = seq_len + + for i in range(Z): + kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") + + Llama2TokenAttentionForwards.token_attn( + q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index + ) + torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) + assert torch.allclose(torch_out, o, atol=1e-3, rtol=0) + + +if __name__ == "__main__": + test() diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..be6de6db2471c3d7d30d287a49b3c03fd0381061 --- /dev/null +++ b/tests/test_infer_ops/triton/test_llama_context_attention.py @@ -0,0 +1,51 @@ +import pytest +import torch +from packaging import version + +try: + pass + + from colossalai.kernel.triton import llama_context_attn_fwd + from tests.test_infer_ops.triton.kernel_utils import torch_context_attention + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_llama_context_attention(): + bs = 4 + head_num = 8 + seq_len = 1024 + head_dim = 64 + + query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + + max_input_len = seq_len + b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32) + b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32) + + for i in range(bs): + b_start[i] = i * seq_len + b_len[i] = seq_len + + o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len) + + torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) + + assert torch.allclose( + torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3 + ), "outputs from triton and torch are not matched" + + +if __name__ == "__main__": + test_llama_context_attention() diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7e05ccafbfc437d70d7404a99aed0e76ec79416a --- /dev/null +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -0,0 +1,55 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm + + +import pytest +import torch +from packaging import version + +try: + pass + + from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_rotary_emb(x, cos, sin): + seq_len, h, dim = x.shape + x0 = x[:, :, 0 : dim // 2] + x1 = x[:, :, dim // 2 : dim] + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_rotary_emb(): + SEQ_LEN = 1 + HEAD_NUM = 32 + HEAD_DIM = 128 + dtype = torch.half + # create data + x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + cos_shape = (SEQ_LEN, HEAD_DIM // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + # forward pass + y_torch = torch_rotary_emb(x, cos, sin) + rotary_embedding_fwd(x, cos, sin) + y_triton = x + # compare + assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0) + + +if __name__ == "__main__": + test_rotary_emb() diff --git a/tests/test_infer_ops/triton/test_self_attention_nonfusion.py b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py new file mode 100644 index 0000000000000000000000000000000000000000..9bdec86645b20e18417f92a883ef78e535004293 --- /dev/null +++ b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py @@ -0,0 +1,143 @@ +import pytest +import torch +import torch.nn.functional as F +from packaging import version + +try: + import triton + + from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel + from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_qkv_matmul(): + qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16) + scale = 1.2 + head_size = 32 + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model : d_model * 2] + + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + q_copy = q.clone() + k_copy = k.clone() + q = torch.transpose(q, 1, 2).contiguous() + k = torch.transpose(k, 1, 2).contiguous() + k = torch.transpose(k, 2, 3).contiguous() + + torch_ouput = torch.einsum("bnij,bnjk->bnik", q, k) + torch_ouput *= 1.2 + + q, k = q_copy, k_copy + batches, M, H, K = q.shape + N = k.shape[1] + score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + K = q.shape[3] + qkv_gemm_4d_kernel[grid]( + q, + k, + score_output, + M, + N, + K, + q.stride(0), + q.stride(2), + q.stride(1), + q.stride(3), + k.stride(0), + k.stride(2), + k.stride(3), + k.stride(1), + score_output.stride(0), + score_output.stride(1), + score_output.stride(2), + score_output.stride(3), + scale=scale, + # currently manually setting, later on we can use auto-tune config to match best setting + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + ) + + check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5) + assert check is True, "the outputs of triton and torch are not matched" + + +def self_attention_compute_using_torch(qkv, input_mask, scale, head_size): + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model : d_model * 2] + v = qkv[:, :, d_model * 2 :] + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + v = v.view(batches, -1, num_of_heads, head_size) + + q = torch.transpose(q, 1, 2).contiguous() + k = torch.transpose(k, 1, 2).contiguous() + v = torch.transpose(v, 1, 2).contiguous() + + k = torch.transpose(k, -1, -2).contiguous() + + score_output = torch.einsum("bnij,bnjk->bnik", q, k) + score_output *= scale + + softmax_output = F.softmax(score_output, dim=-1) + res = torch.einsum("bnij,bnjk->bnik", softmax_output, v) + res = torch.transpose(res, 1, 2) + res = res.contiguous() + + return res.view(batches, -1, d_model), score_output, softmax_output + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_self_atttention_test(): + qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16) + data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch( + qkv.clone(), input_mask=None, scale=1.2, head_size=32 + ) + + data_output_triton = self_attention_compute_using_triton( + qkv.clone(), + alibi=None, + head_size=32, + scale=1.2, + input_mask=None, + layer_past=None, + use_flash=False, + triangular=True, + ) + + check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2) + assert check is True, "the triton output is not matched with torch output" + + +if __name__ == "__main__": + test_qkv_matmul() + test_self_atttention_test() diff --git a/tests/test_infer_ops/triton/test_softmax.py b/tests/test_infer_ops/triton/test_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..43b9c0929c4acce5d9de0a6a2bfd8929c6bb15c2 --- /dev/null +++ b/tests/test_infer_ops/triton/test_softmax.py @@ -0,0 +1,36 @@ +import pytest +import torch +from packaging import version +from torch import nn + +try: + from colossalai.kernel.triton.softmax import softmax + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_softmax_op(): + data_samples = [ + torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32), + torch.randn((320, 320, 78), device="cuda", dtype=torch.float32), + torch.randn((2345, 4, 5, 64), device="cuda", dtype=torch.float16), + ] + + for data in data_samples: + module = nn.Softmax(dim=-1) + data_torch_out = module(data) + data_triton_out = softmax(data) + check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3) + assert check is True, "softmax outputs from triton and torch are not matched" + + +if __name__ == "__main__": + test_softmax_op() diff --git a/tests/test_infer_ops/triton/test_token_attn_1.py b/tests/test_infer_ops/triton/test_token_attn_1.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5f8cd6c9dcf087bda743ca78625366b9bac0b4 --- /dev/null +++ b/tests/test_infer_ops/triton/test_token_attn_1.py @@ -0,0 +1,74 @@ +import math + +import pytest +import torch +from packaging import version + +try: + pass + + from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1 + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): + xq = xq.view(bs, 1, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + keys = xk + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + scores = ( + (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) + ) + return scores + + +def torch_attn_1(xq, xk, seqlen, num_head, head_dim): + xq = xq.view(1, num_head, head_dim) + xk = xk.view(seqlen, num_head, head_dim) + logics = torch.sum(xq * xk, dim=-1, keepdim=False) + + logics = logics.transpose(0, 1) / math.sqrt(head_dim) + return logics + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_attn_1(): + pass + + batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 + + dtype = torch.float16 + + q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") + + b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") + kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + + for i in range(batch_size): + kv_cache_start_loc[i] = i * seq_len + kv_cache_seq_len[i] = seq_len + b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") + + token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) + + torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze() + o = attn_out.squeeze() + print("max ", torch.max(torch.abs(torch_out - o))) + print("mean ", torch.mean(torch.abs(torch_out - o))) + assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) + + +if __name__ == "__main__": + test_attn_1() diff --git a/tests/test_infer_ops/triton/test_token_attn_2.py b/tests/test_infer_ops/triton/test_token_attn_2.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd756f2ba914b52a9896edbba37e2e4cc78d625 --- /dev/null +++ b/tests/test_infer_ops/triton/test_token_attn_2.py @@ -0,0 +1,63 @@ +import pytest +import torch +from packaging import version + +try: + pass + + from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2 + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_attn(V, P, bs, seqlen, num_head, head_dim): + V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) + P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) + attn_out = torch.matmul(P, V) + + return attn_out + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_token_attn_2(): + pass + + batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 + dtype = torch.float16 + + V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) + Prob = ( + torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") + .normal_(mean=0.4, std=0.2) + .reshape(head_num, batch_size, seq_len) + .softmax(-1) + .reshape(head_num, batch_size * seq_len) + ) + attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda") + + kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + kv_cache_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") + for i in range(batch_size): + kv_cache_start_loc[i] = i * seq_len + kv_cache_seq_len[i] = seq_len + kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") + + token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) + + torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze() + o = attn_out + print("max ", torch.max(torch.abs(torch_out - o))) + print("mean ", torch.mean(torch.abs(torch_out - o))) + assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) + + +if __name__ == "__main__": + test_token_attn_2() diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..9c7a53798317f9440ca5d427c40f83ff8b429a2f --- /dev/null +++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py @@ -0,0 +1,65 @@ +import pytest +import torch +from packaging import version + +try: + pass + + from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): + xq = xq.view(bs, 1, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + + logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5) + prob = torch.softmax(logics, dim=1) + prob = prob.view(bs, seqlen, num_head, 1) + + return torch.sum(prob * xv, dim=1, keepdim=False) + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test(): + Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128 + dtype = torch.float16 + q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) + o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) + alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda") + + max_kv_cache_len = seq_len + kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") + kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda") + kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") + + kv_cache_seq_len[:] = seq_len + kv_cache_start_loc[0] = 0 + kv_cache_start_loc[1] = seq_len + kv_cache_start_loc[2] = 2 * seq_len + kv_cache_start_loc[3] = 3 * seq_len + + for i in range(Z): + kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") + + token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi) + torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) + + print("max ", torch.max(torch.abs(torch_out - o))) + print("mean ", torch.mean(torch.abs(torch_out - o))) + assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) + + +if __name__ == "__main__": + test() diff --git a/tests/test_infer_ops/triton/test_token_softmax.py b/tests/test_infer_ops/triton/test_token_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..1f97f16748189686c403a02b741f3a4ab3682a74 --- /dev/null +++ b/tests/test_infer_ops/triton/test_token_softmax.py @@ -0,0 +1,48 @@ +import pytest +import torch +from packaging import version + +try: + pass + + from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_softmax(): + import torch + + batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128 + + dtype = torch.float16 + + Logics = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) + ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + + kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + + for i in range(batch_size): + kv_cache_start_loc[i] = i * seq_len + kv_cache_seq_len[i] = seq_len + + token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len) + + torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len) + o = ProbOut + print("max ", torch.max(torch.abs(torch_out - o))) + print("mean ", torch.mean(torch.abs(torch_out - o))) + assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) + + +if __name__ == "__main__": + test_softmax() diff --git a/tests/test_layers/test_1d/checks_1d/common.py b/tests/test_layers/test_1d/checks_1d/common.py deleted file mode 100644 index 8b7b28613d223bfb0ab249dee01869446a08e406..0000000000000000000000000000000000000000 --- a/tests/test_layers/test_1d/checks_1d/common.py +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch - -DEPTH = 4 -BATCH_SIZE = 8 -SEQ_LENGTH = 8 -IMG_SIZE = 16 -HIDDEN_SIZE = 8 -NUM_CLASSES = 8 -VOCAB_SIZE = 16 - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True diff --git a/tests/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_layers/test_2d/checks_2d/check_operation_2d.py deleted file mode 100644 index a5e37b1ec3097b29b55a4744111131ef3bfdea44..0000000000000000000000000000000000000000 --- a/tests/test_layers/test_2d/checks_2d/check_operation_2d.py +++ /dev/null @@ -1,213 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch - -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D -from colossalai.utils import get_current_device -from colossalai.utils import print_rank_0 -from .common import check_equal, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE, DEPTH - - -def check_AB(): - data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) - tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - - dtype = torch.float - j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) - i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) - torch.distributed.broadcast(A_master, src=0) - A = torch.chunk(A_master, DEPTH, dim=0)[i] - A = torch.chunk(A, DEPTH, dim=-1)[j] - A = A.clone() - A.requires_grad = True - - B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) - torch.distributed.broadcast(B_master, src=0) - B = torch.chunk(B_master, DEPTH, dim=0)[i] - B = torch.chunk(B, DEPTH, dim=-1)[j] - B = B.clone() - B.requires_grad = True - - out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH) - - out = Matmul_AB_2D.apply(A, B, DEPTH, out_shape, i, j, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, - data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) - A_master = A_master.clone() - A_master.requires_grad = True - B_master = B_master.clone() - B_master.requires_grad = True - C_master = torch.matmul(A_master, B_master) - C = torch.chunk(C_master, DEPTH, dim=0)[i] - C = torch.chunk(C, DEPTH, dim=-1)[j] - # check forward correctness - check_equal(out, C) - print_rank_0('AB forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) - torch.distributed.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, DEPTH, dim=0)[i] - grad = torch.chunk(grad, DEPTH, dim=-1)[j] - - out.backward(grad) - - C_master.backward(grad_master) - A_grad = A_master.grad - A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] - A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] - # check backward correctness - check_equal(A_grad, A.grad) - - B_grad = B_master.grad - B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] - B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] - # check backward correctness - check_equal(B_grad, B.grad) - print_rank_0('AB backward: pass') - - -def check_ABT(): - data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) - tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - - dtype = torch.float - device = get_current_device() - - j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) - i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) - - C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) - C_master = torch.randn(C_shape, dtype=dtype, device=device) - torch.distributed.broadcast(C_master, src=0) - C = torch.chunk(C_master, DEPTH, dim=0)[i] - C = torch.chunk(C, DEPTH, dim=-1)[j] - C = C.clone() - C.requires_grad = True - - B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=device) - torch.distributed.broadcast(B_master, src=0) - B = torch.chunk(B_master, DEPTH, dim=0)[i] - B = torch.chunk(B, DEPTH, dim=-1)[j] - B = B.clone() - B.requires_grad = True - - out = Matmul_ABT_2D.apply(C, B, DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), i, j, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank, - pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - C_master = C_master.clone() - C_master.requires_grad = True - B_master = B_master.clone() - B_master.requires_grad = True - A_master = torch.matmul(C_master, B_master.transpose(0, 1)) - A = torch.chunk(A_master, DEPTH, dim=0)[i] - A = torch.chunk(A, DEPTH, dim=-1)[j] - check_equal(out, A) - print_rank_0('ABT forward: pass') - - grad_shape = A_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - torch.distributed.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, DEPTH, dim=0)[i] - grad = torch.chunk(grad, DEPTH, dim=-1)[j] - - # backward - out.backward(grad) - - A_master.backward(grad_master) - C_grad = C_master.grad - C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i] - C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j] - check_equal(C_grad, C.grad) - - B_grad = B_master.grad - B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] - B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] - check_equal(B_grad, B.grad) - print_rank_0('ABT backward: pass') - - -def check_ATB(): - data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) - tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - - device = get_current_device() - dtype = torch.float - - j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) - i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) - torch.distributed.broadcast(A_master, src=0) - A = torch.chunk(A_master, DEPTH, dim=0)[i] - A = torch.chunk(A, DEPTH, dim=-1)[j] - A = A.clone() - A.requires_grad = True - - C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) - C_master = torch.randn(C_shape, dtype=dtype, device=device) - torch.distributed.broadcast(C_master, src=0) - C = torch.chunk(C_master, DEPTH, dim=0)[i] - C = torch.chunk(C, DEPTH, dim=-1)[j] - C = C.clone() - C.requires_grad = True - - out = Matmul_ATB_2D.apply(A, C, DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), i, j, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank, - pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - A_master = A_master.clone() - A_master.requires_grad = True - C_master = C_master.clone() - C_master.requires_grad = True - B_master = torch.matmul( - A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])) - B = torch.chunk(B_master, DEPTH, dim=0)[i] - B = torch.chunk(B, DEPTH, dim=-1)[j] - check_equal(out, B) - print_rank_0('ATB forward: pass') - - grad_shape = B_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - torch.distributed.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, DEPTH, dim=0)[i] - grad = torch.chunk(grad, DEPTH, dim=-1)[j] - - out.backward(grad) - - B_master.backward(grad_master) - A_grad = A_master.grad - A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] - A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] - check_equal(A_grad, A.grad) - - C_grad = C_master.grad - C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i] - C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j] - check_equal(C_grad, C.grad) - print_rank_0('ATB backward: pass') diff --git a/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py deleted file mode 100644 index d0c3b02fccba589507bf8e2af25846767636c734..0000000000000000000000000000000000000000 --- a/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py +++ /dev/null @@ -1,216 +0,0 @@ -import torch - -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, \ - Matmul_ATB_2p5D -from colossalai.utils import get_current_device -from colossalai.utils import print_rank_0 -from .common import * - - -def check_AB(): - data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) - tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - - dtype = torch.float - i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) - torch.distributed.broadcast(A_master, src=0) - A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] - A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] - A = A.clone() - A.requires_grad = True - - B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) - torch.distributed.broadcast(B_master, src=0) - B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] - B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] - B = B.clone() - B.requires_grad = True - - out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM) - out = Matmul_AB_2p5D.apply(A, B, TESSERACT_DIM, out_shape, i, j, k, ParallelMode.PARALLEL_2P5D_ROW, - ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, pipeline_parallel_rank, - pipeline_parallel_size, tensor_parallel_size) - - C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) - A_master = A_master.clone() - A_master.requires_grad = True - B_master = B_master.clone() - B_master.requires_grad = True - C_master = torch.matmul(A_master, B_master) - C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] - C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] - # check forward correctness - check_equal(out, C) - print_rank_0('AB forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) - torch.distributed.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] - grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] - - out.backward(grad) - - C_master.backward(grad_master) - A_grad = A_master.grad - A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] - A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] - # check backward correctness - check_equal(A_grad, A.grad) - - B_grad = B_master.grad - B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] - B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] - # check backward correctness - check_equal(B_grad, B.grad) - print_rank_0('AB backward: pass') - - -def check_ABT(): - data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) - tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - - dtype = torch.float - device = get_current_device() - - i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) - - C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) - C_master = torch.randn(C_shape, dtype=dtype, device=device) - torch.distributed.broadcast(C_master, src=0) - C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] - C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] - C = C.clone() - C.requires_grad = True - - B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=device) - torch.distributed.broadcast(B_master, src=0) - B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] - B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] - B = B.clone() - B.requires_grad = True - - out = Matmul_ABT_2p5D.apply(C, B, TESSERACT_DIM, - (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), i, j, k, - ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, - pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - C_master = C_master.clone() - C_master.requires_grad = True - B_master = B_master.clone() - B_master.requires_grad = True - A_master = torch.matmul(C_master, B_master.transpose(0, 1)) - A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] - A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] - check_equal(out, A) - print_rank_0('ABT forward: pass') - - grad_shape = A_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - torch.distributed.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] - grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] - - # backward - out.backward(grad) - - A_master.backward(grad_master) - C_grad = C_master.grad - C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i] - C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j] - check_equal(C_grad, C.grad) - - B_grad = B_master.grad - B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] - B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] - check_equal(B_grad, B.grad) - print_rank_0('ABT backward: pass') - - -def check_ATB(): - data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) - tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - - device = get_current_device() - dtype = torch.float - - i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) - torch.distributed.broadcast(A_master, src=0) - A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] - A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] - A = A.clone() - A.requires_grad = True - - C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) - C_master = torch.randn(C_shape, dtype=dtype, device=device) - torch.distributed.broadcast(C_master, src=0) - C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] - C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] - C = C.clone() - C.requires_grad = True - - out = Matmul_ATB_2p5D.apply(A, C, TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM), - i, j, k, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, - data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, - tensor_parallel_size) - - B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - A_master = A_master.clone() - A_master.requires_grad = True - C_master = C_master.clone() - C_master.requires_grad = True - B_master = torch.matmul( - A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])) - B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] - B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] - check_equal(out, B) - print_rank_0('ATB forward: pass') - - grad_shape = B_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - torch.distributed.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] - grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] - - out.backward(grad) - - B_master.backward(grad_master) - A_grad = A_master.grad - A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] - A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] - check_equal(A_grad, A.grad) - - C_grad = C_master.grad - C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i] - C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j] - check_equal(C_grad, C.grad) - print_rank_0('ATB backward: pass') diff --git a/tests/test_layers/test_2p5d/checks_2p5d/common.py b/tests/test_layers/test_2p5d/checks_2p5d/common.py deleted file mode 100644 index aff85f109666d7cdf9e65173eda851368c39694c..0000000000000000000000000000000000000000 --- a/tests/test_layers/test_2p5d/checks_2p5d/common.py +++ /dev/null @@ -1,14 +0,0 @@ -import torch - -TESSERACT_DIM = 2 -TESSERACT_DEP = 2 -BATCH_SIZE = 8 -SEQ_LENGTH = 8 -HIDDEN_SIZE = 8 -NUM_CLASSES = 8 -VOCAB_SIZE = 16 -IMG_SIZE = 16 - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) \ No newline at end of file diff --git a/tests/test_layers/test_3d/checks_3d/common.py b/tests/test_layers/test_3d/checks_3d/common.py deleted file mode 100644 index afb19c4745cc72cd66a4d5a11239ed6f10f68d14..0000000000000000000000000000000000000000 --- a/tests/test_layers/test_3d/checks_3d/common.py +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch - -DEPTH = 2 -BATCH_SIZE = 8 -SEQ_LENGTH = 8 -HIDDEN_SIZE = 8 -NUM_CLASSES = 8 -NUM_BLOCKS = 2 -IMG_SIZE = 16 -VOCAB_SIZE = 16 - - -def check_equal(A, B): - eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2) - assert eq, f"\nA = {A}\nB = {B}" - return eq \ No newline at end of file diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py deleted file mode 100644 index 22d4f02a48d726d8f1892f5cb99f57eca863490d..0000000000000000000000000000000000000000 --- a/tests/test_layers/test_cache_embedding.py +++ /dev/null @@ -1,377 +0,0 @@ -import random -from typing import List - -import numpy as np -import pytest -import torch - -import colossalai -from colossalai.nn.parallel.layers import ( - CachedEmbeddingBag, - CachedParamMgr, - EvictionStrategy, - ParallelCachedEmbeddingBag, - ParallelCachedEmbeddingBagTablewise, - TablewiseEmbeddingBagConfig, -) -from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn - -NUM_EMBED, EMBED_DIM = 10, 8 -BATCH_SIZE = 8 - - -def set_seed(seed): - """ - To achieve reproducible results, it's necessary to fix random seeds - """ - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - - -def synthesize_1d_sparse_feature( - batch_size, - num_embed, - device, -): - indices_in_batch = batch_size * 2 - indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long) - offsets = torch.from_numpy( - np.array([ - 0, *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), indices_in_batch - ])).to(device).long() - return indices, offsets - - -@pytest.mark.skip -@clear_cache_before_run() -def test_cachemgr(): - model = torch.nn.EmbeddingBag(10000, 128) - # 10 chunks, 5 in cuda - mgr = CachedParamMgr(model.weight.detach(), 5) - assert mgr.cuda_row_num == 5 - - mgr._admit(1) - assert not mgr._chunk_in_cuda(2) - assert mgr._chunk_in_cuda(1) - - # print(mgr.cached_chunk_table) - mgr._admit(8) - - # now 3 chunk is available - assert mgr.cuda_available_chunk_num == 3 - - mgr._evict() - assert mgr.cuda_available_chunk_num == 4 - - mgr._prepare_rows_on_cuda(torch.tensor([9, 6, 5], dtype=torch.long, device=0)) - mgr._prepare_rows_on_cuda(torch.tensor([3, 4, 5], dtype=torch.long, device=0)) - # print(mgr.cached_chunk_table) - # mgr.print_comm_stats() - - mgr.flush() - assert mgr.cuda_available_chunk_num == 5 - - -@clear_cache_before_run() -def test_reorder_with_freq(): - num_embed = 100 - chunk_size = 1 - num_chunk = 5 - - idx_map = torch.randint(10000, size=(num_embed,)) - sorted_idx = torch.argsort(idx_map, descending=True).tolist() - chunkid, offset_in_chunk = [], [] - for i in range(num_embed): - idx = sorted_idx.index(i) - chunkid.append(idx // chunk_size) - offset_in_chunk.append(idx % chunk_size) - - dev = torch.device('cuda') - chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev) - offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev) - - weight = torch.rand(num_embed, 2) - mgr = CachedParamMgr(weight, num_chunk) - - mgr.reorder(idx_map) - - indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev)) - mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor') - mgr_offsets = torch.remainder(indices, chunk_size) - assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}" - assert torch.allclose(offset_in_chunk, mgr_offsets), \ - f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" - - -@clear_cache_before_run() -@parameterize('use_LFU', [True, False]) -def test_freq_aware_embed(use_LFU: bool): - device = torch.device('cuda', 0) - evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET - model = CachedEmbeddingBag(NUM_EMBED, - EMBED_DIM, - mode='mean', - include_last_offset=True, - cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0), - ids_freq_mapping=None, - evict_strategy=evict_strategy).to(device) - - assert model.weight.shape[0] == NUM_EMBED - ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device), - mode='mean', - include_last_offset=True, - freeze=False) - - assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device)) - - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3) - - for i in range(5): - indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device) - res = model(indices, offsets) - ref_res = ref_model(indices, offsets) - assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}" - - grad = torch.rand_like(res) - # comparing gradient here is nontrivial - res.backward(grad) - ref_res.backward(grad) - optimizer.step() - optimizer.zero_grad() - - ref_optimizer.step() - ref_optimizer.zero_grad() - - model.cache_weight_mgr.flush() - model_weight = model.weight.detach().to(device) - ref_weight = ref_model.weight.detach() - assert torch.allclose(model_weight, ref_weight), \ - f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}" - - -@clear_cache_before_run() -@parameterize('init_freq', [True, False]) -def test_lfu_strategy(init_freq: bool): - # minimal test to check behavior - Bag = CachedEmbeddingBag(5, - 5, - cache_ratio=3 / 5, - buffer_size=0, - pin_weight=True, - ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None, - warmup_ratio=1.0, - evict_strategy=EvictionStrategy.LFU) - - # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map) - offsets = torch.tensor([0], device="cuda:0") - - # prepare frequency learning info: - Bag.forward(torch.tensor([2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([1, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0], device="cuda:0"), offsets) - - # check strategy - Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([3], device="cuda:0"), offsets) # miss, evict 1 - Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit - Bag.forward(torch.tensor([4], device="cuda:0"), offsets) # miss, evict 3 - Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit - Bag.forward(torch.tensor([0], device="cuda:0"), offsets) # hit - - assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \ - "LFU strategy behavior failed" - - -def gather_tensor(tensor, rank, world_size): - gather_list = [] - if rank == 0: - gather_list = [torch.empty_like(tensor) for _ in range(world_size)] - - torch.distributed.gather(tensor, gather_list, dst=0) - return gather_list - - -def run_parallel_freq_aware_embed_tablewise(rank, world_size): - if world_size != 2: - return - device = torch.device('cuda', torch.cuda.current_device()) - - # initialize weight - # 3 feature tables. idx: 0~5, 6~10, 11~17 - weight_tables = torch.rand(18, 5) - weight_table1 = weight_tables[0:6] - weight_table2 = weight_tables[6:11] - weight_table3 = weight_tables[11:18] - embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] - embedding_bag_config_list.append( - TablewiseEmbeddingBagConfig(num_embeddings=6, - cuda_row_num=4, - assigned_rank=0, - initial_weight=weight_table1.clone().detach().cpu())) - embedding_bag_config_list.append( - TablewiseEmbeddingBagConfig(num_embeddings=5, - cuda_row_num=4, - assigned_rank=0, - initial_weight=weight_table2.clone().detach().cpu())) - embedding_bag_config_list.append( - TablewiseEmbeddingBagConfig(num_embeddings=7, - cuda_row_num=4, - assigned_rank=1, - initial_weight=weight_table3.clone().detach().cpu())) - if rank == 0: - _weight = torch.cat([weight_table1, weight_table2], 0) - else: - _weight = weight_table3 - model = ParallelCachedEmbeddingBagTablewise( - embedding_bag_config_list, - embedding_dim=5, - _weight=_weight, - include_last_offset=True, - cache_ratio=0.5, - buffer_size=0, - evict_strategy=EvictionStrategy.LFU, - ) - # explain - ''' - batch feature 1 feature 2 feature 3 - input0 [1,2,3] [6,7] [] - input1 [] [9] [13,15] - input2 [1,5] [6,8] [11] - ↑ ↑ ↑ - rank 0 rank 0 rank 1 - in KJT format - ''' - res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), - torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device), - already_split_along_rank=False) - optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) - rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device) - if rank == 0: - fake_grad = rand_grad[0:2] - else: - fake_grad = rand_grad[2:] - res.backward(fake_grad) - optimizer.step() - optimizer.zero_grad() - - # check correctness - if rank == 0: - ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(), - include_last_offset=True, - freeze=False).to(device) - ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2) - ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0) - ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), - torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device)) - ref_res.backward(ref_fake_grad) - ref_optimizer.step() - ref_optimizer.zero_grad() - - model.cache_weight_mgr.flush() - recover_weight = model.cache_weight_mgr.weight.to(device) - ref_weight = ref_model.weight.detach()[:11] - assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}" - - -def run_parallel_freq_aware_embed_columnwise(rank, world_size): - device = torch.device('cuda', torch.cuda.current_device()) - - num_embed = 100 - embed_dim = 16 - batch_size = 4 - - set_seed(4321) - weight = torch.rand(num_embed, embed_dim) - coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None) - - # initialize the tensor spec for the embedding weight parameter, - # which is an ColoParameter. - coloweight.set_process_group(ProcessGroup(tp_degree=world_size)) - coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D)) - - model = ParallelCachedEmbeddingBag.from_pretrained( - coloweight, - include_last_offset=True, - freeze=False, - cache_ratio=batch_size * 2 / num_embed, - ) - - assert model.cache_weight_mgr.weight.device.type == 'cpu' - assert model.cache_weight_mgr.cuda_cached_weight.requires_grad - weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank] - print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}") - assert torch.allclose(weight_in_rank, - model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}" - - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - - if rank == 0: - ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(), - include_last_offset=True, - freeze=False).to(device) - ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3) - - set_seed(4321) - for i in range(5): - indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device) - res = model(indices, offsets) - - grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device) - grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank] - res.backward(grad_in_rank) - - optimizer.step() - optimizer.zero_grad() - - res_list = gather_tensor(res.detach(), rank, world_size) - - if rank == 0: - ref_res = ref_model(indices, offsets) - recover_res = torch.cat(res_list, dim=0) - - assert torch.allclose(ref_res, recover_res) - - ref_res.backward(grad) - ref_optimizer.step() - ref_optimizer.zero_grad() - - model.cache_weight_mgr.flush() - weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size) - if rank == 0: - recover_weight = torch.cat(weight_list, dim=1) - assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}" - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - # run_parallel_freq_aware_embed_columnwise(rank, world_size) - run_parallel_freq_aware_embed_tablewise(rank, world_size) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_parallel_freq_aware_embed(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - # test_freq_aware_embed(True) - test_parallel_freq_aware_embed(2) - # test_lfu_strategy(False) diff --git a/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py deleted file mode 100644 index 2b7b999d43731ae5b5cd3f7cb87eecc6c1585fc4..0000000000000000000000000000000000000000 --- a/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch - -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn import TransformerSelfAttentionRing -from colossalai.utils import get_current_device - - -def check_selfattention(): - WORLD_SIZE = gpc.get_world_size(ParallelMode.SEQUENCE) - SUB_SEQ_LENGTH = 8 - BATCH = 4 - HIDDEN_SIZE = 16 - - layer = TransformerSelfAttentionRing(16, 8, 8, 0.1) - layer = layer.to(get_current_device()) - - hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device()) - attention_mask = torch.randint(low=0, high=2, - size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device()) - out = layer(hidden_states, attention_mask) diff --git a/tests/test_layers/test_sequence/test_sequence.py b/tests/test_layers/test_sequence/test_sequence.py deleted file mode 100644 index aac192d7eff0b2df668a937c4eafdf9e11404f51..0000000000000000000000000000000000000000 --- a/tests/test_layers/test_sequence/test_sequence.py +++ /dev/null @@ -1,138 +0,0 @@ -import pytest -import torch -import torch.distributed as dist - -import colossalai -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use, spawn - -CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence'))) - - -def check_ring_qk(rank, world_size): - # params - batch_size = 4 - num_heads = 4 - seq_length = 32 - attention_head_size = 32 - sub_seq_length = seq_length // world_size - - # create master tensors - q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() - k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() - dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) - dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) - - # create distributed tensors - sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() - sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() - - # set autograd attributes - q.requires_grad = True - k.requires_grad = True - q.retain_grad() - k.retain_grad() - sub_q.requires_grad = True - sub_k.requires_grad = True - sub_q.retain_grad() - sub_k.retain_grad() - - # compute master attention scores - a = torch.matmul(q, k.transpose(2, 1)) - - # compute distributed attention scores - ring_qk = colossalai.nn.layer.parallel_sequence.RingQK.apply - sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length) - - # check master and distributed attetion scores - sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] - assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2) - - # run master backward - a.retain_grad() - a.mean().backward() - - # run distributed backward - partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] - torch.autograd.backward(sub_a, partial_master_a_grad) - - # check master and distributed grads - partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] - assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \ - 'attention score cannot match' - - -def check_ring_av(rank, world_size): - # params - batch_size = 4 - num_heads = 4 - seq_length = 16 - attention_head_size = 32 - sub_seq_length = seq_length // world_size - - # create master tensors - a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda() - v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() - dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) - dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) - - # create distributed tensors - sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() - sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() - - # set autograd attributes - a.requires_grad = True - v.requires_grad = True - a.retain_grad() - v.retain_grad() - sub_a.requires_grad = True - sub_v.requires_grad = True - sub_a.retain_grad() - sub_v.retain_grad() - - # compute master attention scores - out = torch.matmul(a, v) - - # compute distributed attention scores - ring_av = colossalai.nn.layer.parallel_sequence.RingAV.apply - sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length) - - # print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}') - - # check master and distributed output - sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] - assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2) - - # # run master backward - out.retain_grad() - out.mean().backward() - - # # run distributed backward - partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] - torch.autograd.backward(sub_out, partial_master_out_grad) - - # # check master and distributed grads - partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] - assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \ - 'attention output cannot match' - - -def run_test(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port) - - # check_ring_qk(rank, world_size) - check_ring_av(rank, world_size) - - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_sequence(): - spawn(run_test, 4) - - -if __name__ == '__main__': - test_sequence() diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ea6b16b947856f2de5ead0a31e22d8e9d70e2c24 --- /dev/null +++ b/tests/test_lazy/lazy_init_utils.py @@ -0,0 +1,109 @@ +import random +from copy import deepcopy +from typing import Any, Callable, Optional, Tuple + +import numpy as np +import torch +from packaging import version + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor +from colossalai.tensor.d_tensor import to_global +from colossalai.tensor.d_tensor.layout import Layout +from tests.kit.model_zoo.registry import ModelAttribute + +SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse("1.12.0") + +# model_fn, data_gen_fn, output_transform_fn, model_attr +TestingEntry = Tuple[Callable[[], torch.nn.Module], Callable[[], dict], Callable[[], dict], Optional[ModelAttribute]] + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def assert_model_equal(m1: torch.nn.Module, m2: torch.nn.Module) -> None: + s1 = m1.state_dict() + s2 = m2.state_dict() + + assert len(s1) == len(s2), f"len {len(s1)} vs {len(s2)}" + + for (n1, t1), (n2, t2) in zip(s1.items(), s2.items()): + assert n1 == n2 + assert torch.equal(t1, t2), f"{n1} {t1} vs {t2}" + + for p1, p2 in zip(m1.parameters(), m2.parameters()): + assert p1.requires_grad == p2.requires_grad + + +def assert_forward_equal( + m1: torch.nn.Module, + m2: torch.nn.Module, + data_gen_fn: Callable[[], dict], + output_transform_fn: Callable[[Any], dict], +) -> None: + data = data_gen_fn() + + m1.eval() + m2.eval() + # run forward + with torch.no_grad(): + outputs1 = m1(**data) + outputs2 = m2(**data) + + # compare output + transformed_out1 = output_transform_fn(outputs1) + transformed_out2 = output_transform_fn(outputs2) + + assert len(transformed_out1) == len(transformed_out2) + + for key, out1 in transformed_out1.items(): + out2 = transformed_out2[key] + assert torch.allclose( + out1, out2, atol=1e-5 + ), f"{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}" + + +def check_lazy_init( + entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False, default_device: str = "cpu" +) -> None: + model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry + _MyTensor._pre_op_fn = lambda *args: set_seed(seed) + LazyTensor._pre_op_fn = lambda *args: set_seed(seed) + ctx = LazyInitContext(tensor_cls=_MyTensor, default_device=default_device) + with ctx: + model = model_fn() + ctx = LazyInitContext(default_device=default_device) + with ctx: + deferred_model = model_fn() + copied_deferred_model = deepcopy(deferred_model) + deferred_model = ctx.materialize(deferred_model, verbose=verbose) + copied_deferred_model = ctx.materialize(copied_deferred_model, verbose=verbose) + assert_model_equal(model, deferred_model) + assert_model_equal(deferred_model, copied_deferred_model) + if check_forward: + assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn) + assert_forward_equal(deferred_model, copied_deferred_model, data_gen_fn, output_transform_fn) + if verbose: + print(f"{model.__class__.__name__} pass") + + +def assert_dist_model_equal( + model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: dict +) -> None: + state = model.state_dict() + distributed_state = distributed_model.state_dict() + + assert len(state) == len(distributed_state), f"len {len(state)} vs {len(distributed_state)}" + + for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()): + assert n1 == n2 + t1 = t1.cuda() + t2 = t2.cuda() + if n2 in sharding_spec_dict: + layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape) + t2.dist_layout = layout + t2 = to_global(t2) + assert torch.equal(t1, t2), f"{n1} {t1} vs {t2}" diff --git a/tests/test_lazy/test_from_pretrained.py b/tests/test_lazy/test_from_pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..623dd82c5ad9621cb12999d74a152c887f7fc1d0 --- /dev/null +++ b/tests/test_lazy/test_from_pretrained.py @@ -0,0 +1,31 @@ +import os + +from transformers import BertForPreTraining, LlamaForCausalLM + +import colossalai.interface.pretrained as pretrained_utils +from colossalai.lazy import LazyInitContext + + +def test_lazy_from_pretrained(): + # test from cached file, unsharded + model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny") + with LazyInitContext(): + deffered_model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny") + pretrained_path = pretrained_utils.get_pretrained_path(deffered_model) + assert os.path.isfile(pretrained_path) + for p, lazy_p in zip(model.parameters(), deffered_model.parameters()): + assert p.shape == lazy_p.shape + + # test from local file, sharded + llama_path = os.environ["LLAMA_PATH"] + model = LlamaForCausalLM.from_pretrained(llama_path) + with LazyInitContext(): + deffered_model = LlamaForCausalLM.from_pretrained(llama_path) + pretrained_path = pretrained_utils.get_pretrained_path(deffered_model) + assert os.path.isfile(pretrained_path) + for p, lazy_p in zip(model.parameters(), deffered_model.parameters()): + assert p.shape == lazy_p.shape + + +if __name__ == "__main__": + test_lazy_from_pretrained() diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py new file mode 100644 index 0000000000000000000000000000000000000000..a1b5763d4cd817de7c505b74a26f0497444b2638 --- /dev/null +++ b/tests/test_lazy/test_models.py @@ -0,0 +1,22 @@ +import pytest +from lazy_init_utils import SUPPORT_LAZY, check_lazy_init + +from tests.kit.model_zoo import model_zoo + + +@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0") +@pytest.mark.parametrize("subset", ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"]) +@pytest.mark.parametrize("default_device", ["cpu", "cuda"]) +def test_torchvision_models_lazy_init(subset, default_device): + sub_model_zoo = model_zoo.get_sub_registry(subset) + for name, entry in sub_model_zoo.items(): + # TODO(ver217): lazy init does not support weight norm, skip these models + if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith( + ("transformers_vit", "transformers_blip2") + ): + continue + check_lazy_init(entry, verbose=True, default_device=default_device) + + +if __name__ == "__main__": + test_torchvision_models_lazy_init("transformers", "cpu") diff --git a/tests/test_lazy/test_ops.py b/tests/test_lazy/test_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b936198547df3aa303627982cd8a2852ff4abb --- /dev/null +++ b/tests/test_lazy/test_ops.py @@ -0,0 +1,64 @@ +import copy + +import pytest +import torch +import torch.nn as nn +from lazy_init_utils import SUPPORT_LAZY +from torch.nn import Parameter + +from colossalai.lazy import LazyInitContext + + +@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0") +def test_lazy_ops(): + with LazyInitContext(): + x = torch.rand(2, 3) + assert tuple(x.shape) == (2, 3) + assert x.device.type == "cpu" + x.requires_grad is False + y = x.cuda() + assert tuple(y.shape) == (2, 3) + assert y.device.type == "cuda" + assert y.requires_grad is False + assert x.cpu() is x + p = Parameter(torch.empty(2, 3)) + assert tuple(p.shape) == (2, 3) + assert p.device.type == "cpu" + assert p.requires_grad is True + assert isinstance(p, Parameter) + x.materialize() + assert tuple(x.shape) == (2, 3) + assert x.device.type == "cpu" + assert x.requires_grad is False + y.materialize() + assert tuple(y.shape) == (2, 3) + assert y.device.type == "cuda" + assert y.requires_grad is False + p.materialize() + assert tuple(p.shape) == (2, 3) + assert p.device.type == "cpu" + assert p.requires_grad is True + assert isinstance(p, Parameter) + + with LazyInitContext(): + x = torch.empty(2, 3) + x.uniform_() + x.materialize() + assert tuple(x.shape) == (2, 3) + + with LazyInitContext(): + model = nn.Linear(3, 4) + model = model.cuda() + model_copied = copy.deepcopy(model) + LazyInitContext.materialize(model) + assert model.weight.device.type == "cuda" + assert model.bias.device.type == "cuda" + LazyInitContext.materialize(model_copied) + assert model_copied.weight.device.type == "cuda" + assert model_copied.bias.device.type == "cuda" + assert torch.equal(model.weight, model_copied.weight) + assert torch.equal(model.bias, model_copied.bias) + + +if __name__ == "__main__": + test_lazy_ops() diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_legacy/test_amp/test_naive_fp16.py similarity index 85% rename from tests/test_amp/test_naive_fp16.py rename to tests/test_legacy/test_amp/test_naive_fp16.py index 6ce4c7f497254cc45a62c640fd2463159b97018f..76f9ff07407f92cd7acf37f4ea9a4739e65f2f91 100644 --- a/tests/test_amp/test_naive_fp16.py +++ b/tests/test_legacy/test_amp/test_naive_fp16.py @@ -4,7 +4,7 @@ import pytest import torch import colossalai -from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp +from colossalai.legacy.amp import convert_to_apex_amp, convert_to_naive_amp from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs @@ -13,7 +13,7 @@ def check_equal(a, b): """ This function checks if two tensors are equal within tolerance """ - assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}' + assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f"a = {a}, b = {b}" def run_naive_amp(): @@ -25,7 +25,7 @@ def run_naive_amp(): torch.backends.cudnn.deterministic = True # create layer - test_models = ['repeated_computed_layers', 'nested_model', 'resnet18'] + test_models = ["repeated_computed_layers", "nested_model", "resnet18"] for test_name in test_models: get_component_func = non_distributed_component_funcs.get_callable(test_name) model_builder, train_dataloader, _, optim_class, _ = get_component_func() @@ -41,9 +41,10 @@ def run_naive_amp(): # inject naive and apex amp naive_amp_config = dict(initial_scale=128, clip_grad_norm=1.0) - naive_amp_model, naive_amp_optimizer = convert_to_naive_amp(naive_amp_model, naive_amp_optimizer, - naive_amp_config) - apex_amp_config = dict(opt_level='O2', loss_scale=128, keep_batchnorm_fp32=False) + naive_amp_model, naive_amp_optimizer = convert_to_naive_amp( + naive_amp_model, naive_amp_optimizer, naive_amp_config + ) + apex_amp_config = dict(opt_level="O2", loss_scale=128, keep_batchnorm_fp32=False) apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) # create data @@ -78,7 +79,7 @@ def run_naive_amp(): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") run_naive_amp() @@ -89,5 +90,5 @@ def test_naive_amp(): spawn(run_dist, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_naive_amp() diff --git a/tests/test_amp/test_torch_fp16.py b/tests/test_legacy/test_amp/test_torch_fp16.py similarity index 84% rename from tests/test_amp/test_torch_fp16.py rename to tests/test_legacy/test_amp/test_torch_fp16.py index 6451aa6264a37c363e5458d1a0b859d67bddc122..47b303745e4eb345a631bcfde957371238df3a4d 100644 --- a/tests/test_amp/test_torch_fp16.py +++ b/tests/test_legacy/test_amp/test_torch_fp16.py @@ -4,7 +4,7 @@ import pytest import torch import colossalai -from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp +from colossalai.legacy.amp import convert_to_apex_amp, convert_to_torch_amp from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs @@ -18,7 +18,7 @@ def run_torch_amp(): torch.backends.cudnn.deterministic = True # create layer - test_models = ['resnet18', 'simple_net'] + test_models = ["resnet18", "simple_net"] for test_name in test_models: get_component_func = non_distributed_component_funcs.get_callable(test_name) model_builder, train_dataloader, _, optim_class, _ = get_component_func() @@ -34,10 +34,10 @@ def run_torch_amp(): # inject torch and apex amp torch_amp_config = dict(init_scale=128, enabled=True) - torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp(torch_amp_model, - torch_amp_optimizer, - amp_config=torch_amp_config) - apex_amp_config = dict(opt_level='O1', loss_scale=128) + torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp( + torch_amp_model, torch_amp_optimizer, amp_config=torch_amp_config + ) + apex_amp_config = dict(opt_level="O1", loss_scale=128) apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) # create data @@ -61,7 +61,7 @@ def run_torch_amp(): # check grad # In apex amp, grad is not scaled before backward, but torch amp does for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): - assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config['loss_scale']) + assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config["loss_scale"]) # clip gradient apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0) @@ -78,7 +78,7 @@ def run_torch_amp(): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") run_torch_amp() @@ -89,5 +89,5 @@ def test_torch_amp(): spawn(run_dist, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_torch_amp() diff --git a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..bc243631a6c564f7f4e01bf56cce160016444546 --- /dev/null +++ b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py @@ -0,0 +1,55 @@ +import pytest +import torch + +from colossalai.legacy.communication.p2p_v2 import _recv_object, _send_object +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn + +disable_existing_loggers() +world_size = 4 +CONFIG = dict(parallel=dict(pipeline=world_size)) +torch.manual_seed(123) + + +def check_layer(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl", verbose=False) + rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + if rank == 0: + obj = [ + torch.randn( + 3, + ) + ] + _send_object(obj, 1) + + if rank == 1: + _recv_object(0) + + if rank == 2: + _recv_object(3) + + if rank == 3: + obj = [ + torch.randn( + 3, + ) + ] + _send_object(obj, 2) + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_object_list_p2p(): + spawn(check_layer, world_size) + + +if __name__ == "__main__": + test_object_list_p2p() diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py new file mode 100644 index 0000000000000000000000000000000000000000..7d2c81972e5ab630cbd4b6682450b1959e2960ce --- /dev/null +++ b/tests/test_legacy/test_comm/test_comm.py @@ -0,0 +1,71 @@ +import pytest +import torch +import torch.distributed as dist + +from colossalai.legacy.communication import all_gather, all_reduce, reduce_scatter +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + +CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) + +SIZE = 8 + + +def check_all_gather(): + tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) + tensor = tensor.to(get_current_device()) + print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) + tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True) + print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) + op.wait() + print("Complete: Rank {0} - {1}".format(dist.get_rank(), tensor)) + torch.cuda.synchronize() + + +def check_reduce_scatter(): + tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) + tensor = tensor.to(get_current_device()) + print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) + tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True) + print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) + op.wait() + print("Complete: Rank {0} - {1}".format(dist.get_rank(), tensor)) + torch.cuda.synchronize() + + +def check_all_reduce(): + tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) + tensor = tensor.to(get_current_device()) + print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) + tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True) + print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) + op.wait() + print("Complete: Rank {0} - {1}".format(dist.get_rank(), tensor)) + torch.cuda.synchronize() + + +def check_layer(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + assert dist.get_rank() == gpc.get_global_rank() + print("Rank {} / {}".format(dist.get_rank(), dist.get_world_size())) + + check_all_gather() + check_reduce_scatter() + check_all_reduce() + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_comm(): + spawn(check_layer, 4) + + +if __name__ == "__main__": + test_comm() diff --git a/tests/test_comm/test_object_list_p2p.py b/tests/test_legacy/test_comm/test_object_list_p2p.py similarity index 86% rename from tests/test_comm/test_object_list_p2p.py rename to tests/test_legacy/test_comm/test_object_list_p2p.py index e9d7630c154307bb680a4fd85babc012bbaade47..69c68c7159e485cf0051e80b405f2ee878edfaff 100644 --- a/tests/test_comm/test_object_list_p2p.py +++ b/tests/test_legacy/test_comm/test_object_list_p2p.py @@ -1,7 +1,7 @@ import pytest import torch -from colossalai.communication.p2p import ( +from colossalai.legacy.communication.p2p import ( recv_backward, recv_forward, send_backward, @@ -9,9 +9,9 @@ from colossalai.communication.p2p import ( send_forward, send_forward_recv_backward, ) -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(pipeline=2)) @@ -27,7 +27,7 @@ grad_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)] def check_send_recv_forward(): if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") data_to_send = data.to(device) data_list_to_send = [] for data_in_list in data_list: @@ -35,7 +35,7 @@ def check_send_recv_forward(): send_forward(data_to_send) send_forward(data_list_to_send) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") data_recv = recv_forward(TENSOR_SIZE) data_list_recv = recv_forward(TENSOR_SIZE_LIST) data_to_check = data.to(device) @@ -47,7 +47,7 @@ def check_send_recv_forward(): def check_send_recv_backward(): if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") grad_recv = recv_backward(TENSOR_SIZE) grad_list_recv = recv_backward(TENSOR_SIZE_LIST) grad_to_check = grad.to(device) @@ -56,7 +56,7 @@ def check_send_recv_backward(): grad_to_check = grad_send.to(device) assert grad_recv.equal(grad_to_check) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") grad_to_send = grad.to(device) grad_list_to_send = [] for grad_in_list in grad_list: @@ -67,7 +67,7 @@ def check_send_recv_backward(): def check_send_recv_forward_backward(): if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") data_list_to_send = [] for data_in_list in data_list: data_list_to_send.append(data_in_list.to(device)) @@ -77,7 +77,7 @@ def check_send_recv_forward_backward(): grad_to_check = grad_send.to(device) assert grad_recv.equal(grad_to_check) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") grad_list_to_send = [] for grad_in_list in grad_list: grad_list_to_send.append(grad_in_list.to(device)) @@ -88,7 +88,7 @@ def check_send_recv_forward_backward(): def check_layer(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_send_recv_forward() check_send_recv_backward() check_send_recv_forward_backward() @@ -102,5 +102,5 @@ def test_object_list_p2p(): spawn(check_layer, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_object_list_p2p() diff --git a/tests/test_comm/test_object_list_p2p_v2.py b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py similarity index 87% rename from tests/test_comm/test_object_list_p2p_v2.py rename to tests/test_legacy/test_comm/test_object_list_p2p_v2.py index cae38385b6e17ddd8ebbd7e8ad7fbb9d3041ef24..eb05ea4839c6fa85d778dfca49458d8f7eaecc80 100644 --- a/tests/test_comm/test_object_list_p2p_v2.py +++ b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py @@ -1,10 +1,10 @@ import pytest import torch -from colossalai.communication.p2p_v2 import recv_backward, recv_forward, send_backward, send_forward -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.communication.p2p_v2 import recv_backward, recv_forward, send_backward, send_forward +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -32,7 +32,7 @@ def check_send_recv_forward(): local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) if local_rank == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") data_to_send = data.to(device) data_list_to_send = [] for data_in_list in data_list: @@ -42,7 +42,7 @@ def check_send_recv_forward(): send_forward(data_list_to_send, scatter_gather_tensors=use_scatter_gather_tensors) elif local_rank == 1: - device = torch.device('cuda:1') + device = torch.device("cuda:1") data_recv = recv_forward(TENSOR_SIZE, scatter_gather_tensors=use_scatter_gather_tensors) data_list_recv = recv_forward(TENSOR_SIZE_LIST, scatter_gather_tensors=use_scatter_gather_tensors) @@ -60,7 +60,7 @@ def check_send_recv_forward(): def check_send_recv_backward(): disable_existing_loggers() if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") grad_recv = recv_backward(TENSOR_SIZE) grad_list_recv = recv_backward(TENSOR_SIZE_LIST) @@ -73,7 +73,7 @@ def check_send_recv_backward(): grad_to_check = grad_send.to(device) assert grad_recv.equal(grad_to_check) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") grad_to_send = grad.to(device) grad_list_to_send = [] for grad_in_list in grad_list: @@ -104,7 +104,7 @@ def check_small_pipeline(): def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") disable_existing_loggers() # check_send_recv_forward() @@ -120,6 +120,6 @@ def test_object_list_p2p(): spawn(check_layer, world_size) -if __name__ == '__main__': +if __name__ == "__main__": disable_existing_loggers() test_object_list_p2p() diff --git a/tests/test_legacy/test_context/configs/parallel_2d_init.py b/tests/test_legacy/test_context/configs/parallel_2d_init.py new file mode 100644 index 0000000000000000000000000000000000000000..d1203fcdc4367cbc79e09eaa4c6c6bc68288fa23 --- /dev/null +++ b/tests/test_legacy/test_context/configs/parallel_2d_init.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")) diff --git a/tests/test_legacy/test_context/configs/parallel_2p5d_init.py b/tests/test_legacy/test_context/configs/parallel_2p5d_init.py new file mode 100644 index 0000000000000000000000000000000000000000..89e8cd6039f7700d4813a54222a1655ba2305b91 --- /dev/null +++ b/tests/test_legacy/test_context/configs/parallel_2p5d_init.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, depth=2, mode="2.5d")) diff --git a/tests/test_legacy/test_context/configs/parallel_3d_init.py b/tests/test_legacy/test_context/configs/parallel_3d_init.py new file mode 100644 index 0000000000000000000000000000000000000000..f9aa52fa419942302912e38c644113a8621ef34a --- /dev/null +++ b/tests/test_legacy/test_context/configs/parallel_3d_init.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, mode="3d")) diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_legacy/test_context/test_hybrid_parallel.py similarity index 82% rename from tests/test_context/test_hybrid_parallel.py rename to tests/test_legacy/test_context/test_hybrid_parallel.py index 9f26a5af53ce6d13fc0d00e4c4d8b949b1172b17..b9e44bb34362e204514d9d734c67e21dcce22b5b 100644 --- a/tests/test_context/test_hybrid_parallel.py +++ b/tests/test_legacy/test_context/test_hybrid_parallel.py @@ -3,17 +3,16 @@ from pathlib import Path -import pytest import torch -from colossalai import launch -from colossalai.context import reset_seeds -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as tp_env +from colossalai.legacy import launch +from colossalai.legacy.context import reset_seeds +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as tp_env from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn -CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py')) +CONFIG_PATH_LIST = list(Path(__file__).parent.glob("configs/*.py")) def check_data_parallel_rank(rank): @@ -50,11 +49,11 @@ def check_model_parallel_rank(rank): def check_tensor_parallel_rank(rank): - if tp_env.mode == '2d': + if tp_env.mode == "2d": check_2d_tensor_parallel_rank(rank) - elif tp_env == '2.5d': + elif tp_env == "2.5d": check_2p5d_tensor_parallel_rank(rank) - elif tp_env == '3d': + elif tp_env == "3d": check_3d_tensor_parallel_rank(rank) @@ -115,13 +114,9 @@ def check_3d_tensor_parallel_rank(rank): def init_context(config_path, rank, world_size, backend, port, host): - dist_args = dict(config=config_path, - rank=rank, - world_size=world_size, - backend=backend, - port=port, - host=host, - verbose=True) + dist_args = dict( + config=config_path, rank=rank, world_size=world_size, backend=backend, port=port, host=host, verbose=True + ) launch(**dist_args) check_tensor_parallel_rank(rank) @@ -134,16 +129,12 @@ def init_context(config_path, rank, world_size, backend, port, host): def run_dist(rank, world_size, port, backend, port_list, host): for config_path, current_port in zip(CONFIG_PATH_LIST, port_list): - init_context(config_path=config_path, - rank=rank, - world_size=world_size, - backend=backend, - port=current_port, - host=host) + init_context( + config_path=config_path, rank=rank, world_size=world_size, backend=backend, port=current_port, host=host + ) reset_seeds() -@pytest.mark.cpu @rerun_if_address_is_in_use() def test_context(): """ @@ -159,8 +150,8 @@ def test_context(): port_list.append(port) break - spawn(run_dist, world_size, backend='gloo', port_list=port_list, host='localhost') + spawn(run_dist, world_size, backend="gloo", port_list=port_list, host="localhost") -if __name__ == '__main__': +if __name__ == "__main__": test_context() diff --git a/tests/test_data/test_cifar10_dataset.py b/tests/test_legacy/test_data/test_cifar10_dataset.py similarity index 77% rename from tests/test_data/test_cifar10_dataset.py rename to tests/test_legacy/test_data/test_cifar10_dataset.py index 4b9ca61d9f1796353f48fd96ca2d8b1cbacfa17d..4851f1b85817e23a679b8469395f499ac1b442bf 100644 --- a/tests/test_data/test_cifar10_dataset.py +++ b/tests/test_legacy/test_data/test_cifar10_dataset.py @@ -4,19 +4,17 @@ import os from pathlib import Path -import pytest -from torchvision import transforms, datasets from torch.utils.data import DataLoader +from torchvision import datasets, transforms -@pytest.mark.cpu def test_cifar10_dataset(): # build transform transform_pipeline = [transforms.ToTensor()] transform_pipeline = transforms.Compose(transform_pipeline) # build dataset - dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ["DATA"]), train=True, download=True, transform=transform_pipeline) # build dataloader dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=2) @@ -24,5 +22,5 @@ def test_cifar10_dataset(): img, label = data_iter.next() -if __name__ == '__main__': +if __name__ == "__main__": test_cifar10_dataset() diff --git a/tests/test_legacy/test_data/test_data_parallel_sampler.py b/tests/test_legacy/test_data/test_data_parallel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..1786b4a77a8b7bcc1959aed670a09783658cdd6f --- /dev/null +++ b/tests/test_legacy/test_data/test_data_parallel_sampler.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +from pathlib import Path + +import torch +import torch.distributed as dist +from torchvision import datasets, transforms + +import colossalai +from colossalai.context import Config +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import get_dataloader +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = Config( + dict( + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=1, mode=None), + ), + seed=1024, + ) +) + + +def run_data_sampler(rank, world_size, port): + dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend="gloo", port=port, host="localhost") + colossalai.legacy.launch(**dist_args) + print("finished initialization") + + # build dataset + transform_pipeline = [transforms.ToTensor()] + transform_pipeline = transforms.Compose(transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ["DATA"]), train=True, download=True, transform=transform_pipeline) + + # build dataloader + dataloader = get_dataloader(dataset, batch_size=8, add_sampler=True) + + data_iter = iter(dataloader) + img, label = data_iter.next() + img = img[0] + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + img_to_compare = img.clone() + else: + img_to_compare = img + dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA)) + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + assert not torch.equal( + img, img_to_compare + ), "Same image was distributed across ranks but expected it to be different" + torch.cuda.empty_cache() + + +@rerun_if_address_is_in_use() +def test_data_sampler(): + spawn(run_data_sampler, 4) + + +if __name__ == "__main__": + test_data_sampler() diff --git a/tests/test_legacy/test_data/test_deterministic_dataloader.py b/tests/test_legacy/test_data/test_deterministic_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..abb442f4820397274174a11c73e5197133a10f4d --- /dev/null +++ b/tests/test_legacy/test_data/test_deterministic_dataloader.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import os +from pathlib import Path + +import torch +import torch.distributed as dist +from torchvision import datasets, transforms + +import colossalai +from colossalai.context import Config +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import get_dataloader +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = Config( + dict( + train_data=dict( + dataset=dict( + type="CIFAR10", + root=Path(os.environ["DATA"]), + train=True, + download=True, + ), + dataloader=dict(num_workers=2, batch_size=2, shuffle=True), + ), + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=1, mode=None), + ), + seed=1024, + ) +) + + +def run_data_sampler(rank, world_size, port): + dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend="gloo", port=port, host="localhost") + colossalai.legacy.launch(**dist_args) + + # build dataset + transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)] + transform_pipeline = transforms.Compose(transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ["DATA"]), train=True, download=True, transform=transform_pipeline) + + # build dataloader + dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False) + + data_iter = iter(dataloader) + img, label = data_iter.next() + img = img[0] + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + img_to_compare = img.clone() + else: + img_to_compare = img + dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA)) + + if gpc.get_local_rank(ParallelMode.DATA) != 0: + # this is without sampler + # this should be false if data parallel sampler to given to the dataloader + assert torch.equal( + img, img_to_compare + ), "Same image was distributed across ranks and expected it to be the same" + torch.cuda.empty_cache() + + +@rerun_if_address_is_in_use() +def test_data_sampler(): + spawn(run_data_sampler, 4) + + +if __name__ == "__main__": + test_data_sampler() diff --git a/tests/test_legacy/test_engine/test_engine.py b/tests/test_legacy/test_engine/test_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..b07fe8abe86e221cc3cee29e5d55455587d50029 --- /dev/null +++ b/tests/test_legacy/test_engine/test_engine.py @@ -0,0 +1,66 @@ +import pytest + +import colossalai +from colossalai.legacy.amp import AMP_TYPE +from colossalai.legacy.core import global_context as gpc +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from tests.components_to_test.registry import non_distributed_component_funcs + +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0 +) + + +@parameterize("model_name", ["repeated_computed_layers", "resnet18", "repeated_computed_layers"]) +@parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None]) +def run_train(model_name, amp_mode): + # FIXME: test bert + get_components_func = non_distributed_component_funcs.get_callable(model_name) + gpc.config.fp16["mode"] = amp_mode + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + + model = model_builder(checkpoint=False) + engine, train_dataloader, *args = colossalai.legacy.initialize( + model=model, + optimizer=optimizer_class(model.parameters(), lr=1e-3), + criterion=criterion, + train_dataloader=train_dataloader, + ) + + try: + engine.train() + for data, label in train_dataloader: + engine.zero_grad() + data = data.cuda() + label = label.cuda() + if criterion: + output = engine(data) + loss = engine.criterion(output, label) + else: + loss = engine(data, label) + engine.backward(loss) + engine.step() + break + except IndexError: + # if using apex amp, NetWithRepeatedlyComputedLayers will raise an index out of range issue + # the following check fails in apex + # if cached_x.grad_fn.next_functions[1][0].variable is not x: + pass + + +def run_engine(rank, world_size, port): + # init dist env + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) + run_train() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_engine(): + spawn(run_engine, 2) + + +if __name__ == "__main__": + test_engine() diff --git a/tests/test_legacy/test_engine/test_gradient_accumluation.py b/tests/test_legacy/test_engine/test_gradient_accumluation.py new file mode 100644 index 0000000000000000000000000000000000000000..262876e0ba4223973942edae3b5d468912c41d75 --- /dev/null +++ b/tests/test_legacy/test_engine/test_gradient_accumluation.py @@ -0,0 +1,95 @@ +import os +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from torch.optim import Adam +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 + +import colossalai +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.utils import get_dataloader +from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use, spawn + +# Config +BATCH_SIZE = 2 +NUM_CLASSES = 10 + +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), clip_grad_norm=1.0, gradient_accumulation=4 +) + + +def run_no_pipeline(rank, world_size, port): + # init dist env + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) + + # build model + model = resnet18(num_classes=10) + + # build dataloaders + train_dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))] + ), + ) + train_dataloader = get_dataloader( + dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True + ) + + # build optimizer + optimizer = Adam(model.parameters(), lr=0.001) + criterion = nn.CrossEntropyLoss() + + engine, train_dataloader, *args = colossalai.legacy.initialize( + model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader + ) + get_dist_logger() + rank = torch.distributed.get_rank() + param_track = [] + grad_track = [] + next(model.parameters()).retain_grad() + + engine.train() + step = 0 + for img, label in train_dataloader: + engine.zero_grad() + img = img.cuda() + label = label.cuda() + output = engine(img) + loss = engine.criterion(output, label) + engine.backward(loss) + engine.step() + + # check + param_track.append(next(model.parameters())[0].clone()) + grad_track.append(next(model.parameters()).grad[0].clone()) + step += 1 + if step == CONFIG["gradient_accumulation"]: + break + + assert not torch.all(grad_track[0] == grad_track[-1]), "grad should be different in different iterations" + assert torch.all(param_track[0] == param_track[1]) and not torch.all( + param_track[0] == param_track[-1] + ), "param should be the same in the first few iterations and only changed in the last iteration" + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_engine(): + spawn(run_no_pipeline, 4) + + +if __name__ == "__main__": + test_engine() diff --git a/tests/test_legacy/test_layers/test_1d/checks_1d/__init__.py b/tests/test_legacy/test_layers/test_1d/checks_1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py similarity index 91% rename from tests/test_layers/test_1d/checks_1d/check_layer_1d.py rename to tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py index 668b8a334800753d9848347a8d88cba66b605de6..8a9a73d65f389b905a40893d3cf8fdc9d7d6f19e 100644 --- a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -2,10 +2,10 @@ import torch import torch.distributed as dist from torch.nn import Parameter -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.nn import ( +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.global_variables import tensor_parallel_env as env +from colossalai.legacy.nn import ( Classifier1D, Embedding1D, Linear1D_Col, @@ -15,7 +15,8 @@ from colossalai.nn import ( VocabParallelCrossEntropyLoss1D, VocabParallelEmbedding1D, ) -from colossalai.utils import get_current_device, print_rank_0 +from colossalai.legacy.utils import print_rank_0 +from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal @@ -43,7 +44,7 @@ def check_linear_col(): W = W.clone() W.requires_grad = True - B_shape = (OUTPUT_SIZE) + B_shape = OUTPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) dist.broadcast(B_master, src=0) B = torch.chunk(B_master, DEPTH, dim=0)[i] @@ -64,7 +65,7 @@ def check_linear_col(): C = torch.chunk(C_master, DEPTH, dim=-1)[i] check_equal(out, C) - print_rank_0('linear_col forward: pass') + print_rank_0("linear_col forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -86,7 +87,7 @@ def check_linear_col(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, layer.bias.grad) - print_rank_0('linear_col backward: pass') + print_rank_0("linear_col backward: pass") def check_linear_row(): @@ -113,7 +114,7 @@ def check_linear_row(): W = W.clone() W.requires_grad = True - B_shape = (INPUT_SIZE) + B_shape = INPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) dist.broadcast(B_master, src=0) B = B_master.clone() @@ -133,7 +134,7 @@ def check_linear_row(): C = C_master.clone() check_equal(out, C) - print_rank_0('linear_row forward: pass') + print_rank_0("linear_row forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -154,7 +155,7 @@ def check_linear_row(): B_grad = B_master.grad check_equal(B_grad, layer.bias.grad) - print_rank_0('linear_row backward: pass') + print_rank_0("linear_row backward: pass") def check_embed(): @@ -183,7 +184,7 @@ def check_embed(): C_master = embed_master(A_master) C = C_master.clone() check_equal(out, C) - print_rank_0('embed forward: pass') + print_rank_0("embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -196,7 +197,7 @@ def check_embed(): B_grad = embed_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('embed backward: pass') + print_rank_0("embed backward: pass") def check_vocab_parallel_embed(): @@ -225,7 +226,7 @@ def check_vocab_parallel_embed(): C_master = embed_master(A_master) C = C_master.clone() check_equal(out, C) - print_rank_0('vocab parallel embed forward: pass') + print_rank_0("vocab parallel embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -238,7 +239,7 @@ def check_vocab_parallel_embed(): B_grad = embed_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('vocab parallel embed backward: pass') + print_rank_0("vocab parallel embed backward: pass") def check_classifier_no_given_weight(): @@ -282,7 +283,7 @@ def check_classifier_no_given_weight(): C = C_master.clone() check_equal(out, C) - print_rank_0('classifier (no given weight) forward: pass') + print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -304,7 +305,7 @@ def check_classifier_no_given_weight(): B_grad = layer_master.bias.grad check_equal(B_grad, layer.bias.grad) - print_rank_0('classifier (no given weight) backward: pass') + print_rank_0("classifier (no given weight) backward: pass") def check_vocab_parallel_classifier_no_given_weight(): @@ -342,7 +343,7 @@ def check_vocab_parallel_classifier_no_given_weight(): C = torch.chunk(C_master, DEPTH, dim=-1)[i] check_equal(out, C) - print_rank_0('vocab parallel classifier (no given weight) forward: pass') + print_rank_0("vocab parallel classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -364,7 +365,7 @@ def check_vocab_parallel_classifier_no_given_weight(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, layer.bias.grad) - print_rank_0('vocab parallel classifier (no given weight) backward: pass') + print_rank_0("vocab parallel classifier (no given weight) backward: pass") def check_classifier_given_embed_weight(): @@ -400,7 +401,7 @@ def check_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = C_master.clone() check_equal(out, C) - print_rank_0('classifier (given embed weight) forward: pass') + print_rank_0("classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -415,7 +416,7 @@ def check_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('classifier (given embed weight) backward: pass') + print_rank_0("classifier (given embed weight) backward: pass") def check_vocab_parallel_classifier_given_embed_weight(): @@ -451,7 +452,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=-1)[i] check_equal(out, C) - print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + print_rank_0("vocab parallel classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -467,7 +468,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + print_rank_0("vocab parallel classifier (given embed weight) backward: pass") def check_vocab_parallel_loss(): @@ -494,7 +495,7 @@ def check_vocab_parallel_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('vocab parallel loss forward: pass') + print_rank_0("vocab parallel loss forward: pass") loss.backward() loss_master.backward() @@ -502,7 +503,7 @@ def check_vocab_parallel_loss(): out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i] check_equal(out_grad, out.grad) - print_rank_0('vocab parallel loss backward: pass') + print_rank_0("vocab parallel loss backward: pass") @torch.no_grad() @@ -530,7 +531,7 @@ def check_linear_row_stream_inference(): W = torch.chunk(W_master, DEPTH, dim=-1)[i] W = W.clone() - B_shape = (INPUT_SIZE) + B_shape = INPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) dist.broadcast(B_master, src=0) B = B_master.clone() @@ -549,4 +550,4 @@ def check_linear_row_stream_inference(): C = C_master.clone() check_equal(out, C) - print_rank_0('linear_row forward: pass') + print_rank_0("linear_row forward: pass") diff --git a/tests/test_legacy/test_layers/test_1d/checks_1d/common.py b/tests/test_legacy/test_layers/test_1d/checks_1d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..29a9a3d203307e0dc47cbe7eeb95a0dbd4b95af5 --- /dev/null +++ b/tests/test_legacy/test_layers/test_1d/checks_1d/common.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +DEPTH = 4 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +IMG_SIZE = 16 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +VOCAB_SIZE = 16 + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_legacy/test_layers/test_1d/test_1d.py similarity index 77% rename from tests/test_layers/test_1d/test_1d.py rename to tests/test_legacy/test_layers/test_1d/test_1d.py index 89151254247522e4b34dfae1155fca9dc73d0572..cebbedd303eea1b703817ad498d7b3343074ec62 100644 --- a/tests/test_layers/test_1d/test_1d.py +++ b/tests/test_legacy/test_layers/test_1d/test_1d.py @@ -5,17 +5,19 @@ import pytest import torch from checks_1d.check_layer_1d import * -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode="1d")), +) def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_linear_col() check_linear_row() @@ -39,5 +41,5 @@ def test_1d(): spawn(check_layer, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_1d() diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/__init__.py b/tests/test_legacy/test_layers/test_2d/checks_2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py similarity index 91% rename from tests/test_layers/test_2d/checks_2d/check_layer_2d.py rename to tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py index e030e473a36311250bd0310b94261a46b77d561a..0bbc72eca8091b211668a56717740b5e31f25e95 100644 --- a/tests/test_layers/test_2d/checks_2d/check_layer_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -1,12 +1,24 @@ import torch -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn import (Classifier2D, CrossEntropyLoss2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, - VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2D, - VocabParallelCrossEntropyLoss2D, VocabParallelEmbedding2D) -from colossalai.utils import get_current_device, print_rank_0 -from .common import (BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal) +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn import ( + Classifier2D, + CrossEntropyLoss2D, + Embedding2D, + LayerNorm2D, + Linear2D, + PatchEmbedding2D, + VanillaClassifier, + VanillaPatchEmbedding, + VocabParallelClassifier2D, + VocabParallelCrossEntropyLoss2D, + VocabParallelEmbedding2D, +) +from colossalai.legacy.utils import print_rank_0 +from colossalai.utils import get_current_device + +from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear(): @@ -36,7 +48,7 @@ def check_linear(): W = W.clone() W.requires_grad = True - B_shape = (OUTPUT_SIZE) + B_shape = OUTPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, DEPTH, dim=-1)[j] @@ -59,7 +71,7 @@ def check_linear(): C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('linear forward: pass') + print_rank_0("linear forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -87,7 +99,7 @@ def check_linear(): # if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('linear backward: pass') + print_rank_0("linear backward: pass") def check_layernorm(): @@ -124,7 +136,7 @@ def check_layernorm(): C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('layer norm forward: pass') + print_rank_0("layer norm forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -138,7 +150,7 @@ def check_layernorm(): A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] check_equal(A_grad, A.grad) - print_rank_0('layer norm backward: pass') + print_rank_0("layer norm backward: pass") def check_embed(): @@ -169,7 +181,7 @@ def check_embed(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('embed forward: pass') + print_rank_0("embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -185,7 +197,7 @@ def check_embed(): B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('embed backward: pass') + print_rank_0("embed backward: pass") def check_patch_embed(): @@ -226,7 +238,7 @@ def check_patch_embed(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('patch embed forward: pass') + print_rank_0("patch embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -258,7 +270,7 @@ def check_patch_embed(): bias_grad = torch.chunk(bias_grad, DEPTH)[j] bias_grad = torch.chunk(bias_grad, DEPTH)[i] check_equal(bias_grad, layer.bias.grad) - print_rank_0('patch embed backward: pass') + print_rank_0("patch embed backward: pass") def check_vocab_parallel_embed(): @@ -289,7 +301,7 @@ def check_vocab_parallel_embed(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel embed forward: pass') + print_rank_0("vocab parallel embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -305,7 +317,7 @@ def check_vocab_parallel_embed(): B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('vocab parallel embed backward: pass') + print_rank_0("vocab parallel embed backward: pass") def check_classifier_no_given_weight(): @@ -336,7 +348,7 @@ def check_classifier_no_given_weight(): layer.weight.data.copy_(W) # W.requires_grad = True - B_shape = (OUTPUT_SIZE, ) + B_shape = (OUTPUT_SIZE,) B_master = torch.randint(5, B_shape, dtype=dtype, device=device) torch.distributed.broadcast(B_master, src=0) # B = torch.chunk(B_master, DEPTH, dim=0)[j] @@ -356,7 +368,7 @@ def check_classifier_no_given_weight(): # C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('classifier (no given weight) forward: pass') + print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -383,7 +395,7 @@ def check_classifier_no_given_weight(): # if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('classifier (no given weight) backward: pass') + print_rank_0("classifier (no given weight) backward: pass") def check_vocab_parallel_classifier_no_given_weight(): @@ -425,7 +437,7 @@ def check_vocab_parallel_classifier_no_given_weight(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (no given weight) forward: pass') + print_rank_0("vocab parallel classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -451,7 +463,7 @@ def check_vocab_parallel_classifier_no_given_weight(): B_grad = torch.chunk(B_grad, DEPTH)[j] B_grad = torch.chunk(B_grad, DEPTH)[i] check_equal(B_grad, layer.bias.grad) - print_rank_0('vocab parallel classifier (no given weight) backward: pass') + print_rank_0("vocab parallel classifier (no given weight) backward: pass") def check_classifier_given_embed_weight(): @@ -487,7 +499,7 @@ def check_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=0)[i] check_equal(out, C) - print_rank_0('classifier (given embed weight) forward: pass') + print_rank_0("classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -503,7 +515,7 @@ def check_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('classifier (given embed weight) backward: pass') + print_rank_0("classifier (given embed weight) backward: pass") def check_vocab_parallel_classifier_given_embed_weight(): @@ -540,7 +552,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + print_rank_0("vocab parallel classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -557,14 +569,14 @@ def check_vocab_parallel_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + print_rank_0("vocab parallel classifier (given embed weight) backward: pass") def check_loss(): device = get_current_device() dtype = torch.float32 - j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) criterion = CrossEntropyLoss2D() @@ -572,7 +584,7 @@ def check_loss(): out_shape = (BATCH_SIZE, NUM_CLASSES) out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) out = torch.chunk(out_master, DEPTH, dim=0)[i] @@ -584,7 +596,7 @@ def check_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('cross entropy loss forward: pass') + print_rank_0("cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -592,7 +604,7 @@ def check_loss(): out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] check_equal(out_grad, out.grad) - print_rank_0('cross entropy loss backward: pass') + print_rank_0("cross entropy loss backward: pass") def check_vocab_parallel_loss(): @@ -607,7 +619,7 @@ def check_vocab_parallel_loss(): out_shape = (BATCH_SIZE, NUM_CLASSES) out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) out = torch.chunk(out_master, DEPTH, dim=0)[i] @@ -620,7 +632,7 @@ def check_vocab_parallel_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('vocab parallel cross entropy loss forward: pass') + print_rank_0("vocab parallel cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -629,7 +641,7 @@ def check_vocab_parallel_loss(): out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[j] check_equal(out_grad, out.grad) - print_rank_0('vocab parallel cross entropy loss backward: pass') + print_rank_0("vocab parallel cross entropy loss backward: pass") # def check_attention(): diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..9c126cefeba80ae62adf73ef1e4f92daba3d7ffb --- /dev/null +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D +from colossalai.legacy.utils import print_rank_0 +from colossalai.utils import get_current_device + +from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal + + +def check_AB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = torch.chunk(B, DEPTH, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH) + + out = Matmul_AB_2D.apply( + A, + B, + DEPTH, + out_shape, + i, + j, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, B_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + # check forward correctness + check_equal(out, C) + print_rank_0("AB forward: pass") + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + + out.backward(grad) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + # check backward correctness + check_equal(A_grad, A.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + # check backward correctness + check_equal(B_grad, B.grad) + print_rank_0("AB backward: pass") + + +def check_ABT(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + device = get_current_device() + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = torch.chunk(B, DEPTH, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out = Matmul_ABT_2D.apply( + C, + B, + DEPTH, + (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), + i, + j, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + C_master = C_master.clone() + C_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + A_master = torch.matmul(C_master, B_master.transpose(0, 1)) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + check_equal(out, A) + print_rank_0("ABT forward: pass") + + grad_shape = A_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + + # backward + out.backward(grad) + + A_master.backward(grad_master) + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i] + C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j] + check_equal(C_grad, C.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + check_equal(B_grad, B.grad) + print_rank_0("ABT backward: pass") + + +def check_ATB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + device = get_current_device() + dtype = torch.float + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + out = Matmul_ATB_2D.apply( + A, + C, + DEPTH, + (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), + i, + j, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = C_master.clone() + C_master.requires_grad = True + B_master = torch.matmul( + A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]) + ) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = torch.chunk(B, DEPTH, dim=-1)[j] + check_equal(out, B) + print_rank_0("ATB forward: pass") + + grad_shape = B_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + + out.backward(grad) + + B_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i] + C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j] + check_equal(C_grad, C.grad) + print_rank_0("ATB backward: pass") diff --git a/tests/test_layers/test_2d/checks_2d/common.py b/tests/test_legacy/test_layers/test_2d/checks_2d/common.py similarity index 100% rename from tests/test_layers/test_2d/checks_2d/common.py rename to tests/test_legacy/test_layers/test_2d/checks_2d/common.py diff --git a/tests/test_layers/test_2d/test_2d.py b/tests/test_legacy/test_layers/test_2d/test_2d.py similarity index 85% rename from tests/test_layers/test_2d/test_2d.py rename to tests/test_legacy/test_layers/test_2d/test_2d.py index bcea5ce7b25dbad1fa9b84c38651feb347ead0cf..77a4b281a746efbc8568f36c10ca973c99ea53d9 100644 --- a/tests/test_layers/test_2d/test_2d.py +++ b/tests/test_legacy/test_layers/test_2d/test_2d.py @@ -18,12 +18,14 @@ from checks_2d.check_layer_2d import ( ) from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),) +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode="2d")), +) def check_operations(): @@ -48,7 +50,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False @@ -65,5 +67,5 @@ def test_2d(): spawn(check_layer_and_operation, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_2d() diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/__init__.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py similarity index 90% rename from tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py rename to tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py index a8f551093b1ef782f8bef64d4241c1400aa6bdde..283e7f68374fbd45f4b7ed75131983069bf23570 100644 --- a/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -1,12 +1,24 @@ import torch -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn import (Classifier2p5D, CrossEntropyLoss2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, - PatchEmbedding2p5D, VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2p5D, - VocabParallelCrossEntropyLoss2p5D, VocabParallelEmbedding2p5D) -from colossalai.utils import get_current_device, print_rank_0 from torch.nn import Parameter +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn import ( + Classifier2p5D, + CrossEntropyLoss2p5D, + Embedding2p5D, + LayerNorm2p5D, + Linear2p5D, + PatchEmbedding2p5D, + VanillaClassifier, + VanillaPatchEmbedding, + VocabParallelClassifier2p5D, + VocabParallelCrossEntropyLoss2p5D, + VocabParallelEmbedding2p5D, +) +from colossalai.legacy.utils import print_rank_0 +from colossalai.utils import get_current_device + from .common import * @@ -18,7 +30,7 @@ def check_linear(): i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layer = Linear2p5D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, skip_bias_add=False) @@ -38,7 +50,7 @@ def check_linear(): W = W.clone() W.requires_grad = True - B_shape = (OUTPUT_SIZE) + B_shape = OUTPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j] @@ -48,7 +60,7 @@ def check_linear(): layer.weight = Parameter(W) layer.bias = Parameter(B) out = layer(A) - bias = layer.bias + layer.bias A_master = A_master.clone() A_master.requires_grad = True @@ -61,7 +73,7 @@ def check_linear(): C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('linear forward: pass') + print_rank_0("linear forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -88,7 +100,7 @@ def check_linear(): if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('linear backward: pass') + print_rank_0("linear backward: pass") def check_layernorm(): @@ -99,7 +111,7 @@ def check_layernorm(): i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layernorm = LayerNorm2p5D(INPUT_SIZE, dtype=dtype) @@ -126,7 +138,7 @@ def check_layernorm(): C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('layer norm forward: pass') + print_rank_0("layer norm forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -140,7 +152,7 @@ def check_layernorm(): A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] check_equal(A_grad, A.grad) - print_rank_0('layer norm backward: pass') + print_rank_0("layer norm backward: pass") def check_embed(): @@ -148,7 +160,7 @@ def check_embed(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -172,7 +184,7 @@ def check_embed(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('embed forward: pass') + print_rank_0("embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -188,7 +200,7 @@ def check_embed(): B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('embed backward: pass') + print_rank_0("embed backward: pass") def check_patch_embed(): @@ -196,7 +208,7 @@ def check_patch_embed(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layer = PatchEmbedding2p5D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) torch.nn.init.ones_(layer.cls_token) @@ -230,7 +242,7 @@ def check_patch_embed(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('patch embed forward: pass') + print_rank_0("patch embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -262,7 +274,7 @@ def check_patch_embed(): bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[j] bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[i] check_equal(bias_grad, layer.bias.grad) - print_rank_0('patch embed backward: pass') + print_rank_0("patch embed backward: pass") def check_vocab_parallel_embed(): @@ -270,7 +282,7 @@ def check_vocab_parallel_embed(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -294,7 +306,7 @@ def check_vocab_parallel_embed(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel embed forward: pass') + print_rank_0("vocab parallel embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -310,7 +322,7 @@ def check_vocab_parallel_embed(): B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('vocab parallel embed backward: pass') + print_rank_0("vocab parallel embed backward: pass") def check_classifier_no_given_weight(): @@ -342,7 +354,7 @@ def check_classifier_no_given_weight(): layer.weight.data.copy_(W) # W.requires_grad = True - B_shape = (OUTPUT_SIZE, ) + B_shape = (OUTPUT_SIZE,) B_master = torch.randint(5, B_shape, dtype=dtype, device=device) torch.distributed.broadcast(B_master, src=0) # B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j] @@ -362,7 +374,7 @@ def check_classifier_no_given_weight(): # C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('classifier (no given weight) forward: pass') + print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -389,7 +401,7 @@ def check_classifier_no_given_weight(): # if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('classifier (no given weight) backward: pass') + print_rank_0("classifier (no given weight) backward: pass") def check_vocab_parallel_classifier_no_given_weight(): @@ -397,7 +409,7 @@ def check_vocab_parallel_classifier_no_given_weight(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) layer = layer.to(dtype).to(device) @@ -430,7 +442,7 @@ def check_vocab_parallel_classifier_no_given_weight(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (no given weight) forward: pass') + print_rank_0("vocab parallel classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -456,7 +468,7 @@ def check_vocab_parallel_classifier_no_given_weight(): B_grad = torch.chunk(B_grad, TESSERACT_DIM)[j] if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('vocab parallel classifier (no given weight) backward: pass') + print_rank_0("vocab parallel classifier (no given weight) backward: pass") def check_classifier_given_embed_weight(): @@ -464,7 +476,7 @@ def check_classifier_given_embed_weight(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -492,7 +504,7 @@ def check_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] check_equal(out, C) - print_rank_0('classifier (given embed weight) forward: pass') + print_rank_0("classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -508,7 +520,7 @@ def check_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('classifier (given embed weight) backward: pass') + print_rank_0("classifier (given embed weight) backward: pass") def check_vocab_parallel_classifier_given_embed_weight(): @@ -516,7 +528,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -545,7 +557,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + print_rank_0("vocab parallel classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -562,22 +574,22 @@ def check_vocab_parallel_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + print_rank_0("vocab parallel classifier (given embed weight) backward: pass") def check_loss(): device = get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) criterion = CrossEntropyLoss2p5D() criterion_master = torch.nn.CrossEntropyLoss() out_shape = (BATCH_SIZE, NUM_CLASSES) out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i] @@ -589,7 +601,7 @@ def check_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('cross entropy loss forward: pass') + print_rank_0("cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -597,7 +609,7 @@ def check_loss(): out_grad = out_master.grad out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] check_equal(out_grad, out.grad) - print_rank_0('cross entropy loss backward: pass') + print_rank_0("cross entropy loss backward: pass") def check_vocab_parallel_loss(): @@ -605,14 +617,14 @@ def check_vocab_parallel_loss(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) criterion = VocabParallelCrossEntropyLoss2p5D() criterion_master = torch.nn.CrossEntropyLoss() out_shape = (BATCH_SIZE, NUM_CLASSES) out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i] @@ -625,7 +637,7 @@ def check_vocab_parallel_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('vocab parallel cross entropy loss forward: pass') + print_rank_0("vocab parallel cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -634,7 +646,7 @@ def check_vocab_parallel_loss(): out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=-1)[j] check_equal(out_grad, out.grad) - print_rank_0('vocab parallel cross entropy loss backward: pass') + print_rank_0("vocab parallel cross entropy loss backward: pass") # def check_attention(): diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py new file mode 100644 index 0000000000000000000000000000000000000000..992bd6107f0868adebb777341bd6fcb5eac48ecf --- /dev/null +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py @@ -0,0 +1,257 @@ +import torch + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D +from colossalai.legacy.utils import print_rank_0 +from colossalai.utils import get_current_device + +from .common import * + + +def check_AB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] + B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM) + out = Matmul_AB_2p5D.apply( + A, + B, + TESSERACT_DIM, + out_shape, + i, + j, + k, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, B_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + # check forward correctness + check_equal(out, C) + print_rank_0("AB forward: pass") + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + + out.backward(grad) + + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + # check backward correctness + check_equal(A_grad, A.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + # check backward correctness + check_equal(B_grad, B.grad) + print_rank_0("AB backward: pass") + + +def check_ABT(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + dtype = torch.float + device = get_current_device() + + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] + B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] + B = B.clone() + B.requires_grad = True + + out = Matmul_ABT_2p5D.apply( + C, + B, + TESSERACT_DIM, + (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), + i, + j, + k, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + C_master = C_master.clone() + C_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + A_master = torch.matmul(C_master, B_master.transpose(0, 1)) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + check_equal(out, A) + print_rank_0("ABT forward: pass") + + grad_shape = A_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + + # backward + out.backward(grad) + + A_master.backward(grad_master) + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i] + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(C_grad, C.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(B_grad, B.grad) + print_rank_0("ABT backward: pass") + + +def check_ATB(): + data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) + tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) + + device = get_current_device() + dtype = torch.float + + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + C_master = torch.randn(C_shape, dtype=dtype, device=device) + torch.distributed.broadcast(C_master, src=0) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + C = C.clone() + C.requires_grad = True + + out = Matmul_ATB_2p5D.apply( + A, + C, + TESSERACT_DIM, + (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM), + i, + j, + k, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = C_master.clone() + C_master.requires_grad = True + B_master = torch.matmul( + A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]) + ) + B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] + B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] + check_equal(out, B) + print_rank_0("ATB forward: pass") + + grad_shape = B_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + + out.backward(grad) + + B_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + + C_grad = C_master.grad + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i] + C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(C_grad, C.grad) + print_rank_0("ATB backward: pass") diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/common.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c90d8fc086bd5dc269ed293dbb3ddfe4d005ef2f --- /dev/null +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/common.py @@ -0,0 +1,14 @@ +import torch + +TESSERACT_DIM = 2 +TESSERACT_DEP = 2 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +VOCAB_SIZE = 16 +IMG_SIZE = 16 + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) diff --git a/tests/test_layers/test_2p5d/test_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py similarity index 80% rename from tests/test_layers/test_2p5d/test_2p5d.py rename to tests/test_legacy/test_layers/test_2p5d/test_2p5d.py index 373d834d003286a0f40cd8e4063fa1e4782cdb9c..437a8f8a7265c509babb3570b29e5295b92b3e79 100644 --- a/tests/test_layers/test_2p5d/test_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py @@ -3,15 +3,17 @@ import torch from checks_2p5d.check_layer_2p5d import * from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict( - pipeline=dict(size=1), - tensor=dict(size=4, mode='2.5d', depth=1), -),) +CONFIG = dict( + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=4, mode="2.5d", depth=1), + ), +) def check_operations(): @@ -36,7 +38,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False @@ -53,5 +55,5 @@ def test_2p5d(): spawn(check_layer_and_operation, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_2p5d() diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/__init__.py b/tests/test_legacy/test_layers/test_3d/checks_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py similarity index 78% rename from tests/test_layers/test_3d/checks_3d/check_layer_3d.py rename to tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py index e946a1f5912d69cebcbe64bc3c317fc2c9774207..a4a4ae9a5ba480c942fe86cb3a3ace0e649344e9 100644 --- a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py +++ b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -5,10 +5,9 @@ import time import torch -from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D -from colossalai.core import global_context -from colossalai.logging import get_dist_logger -from colossalai.nn import ( +from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.legacy.core import global_context +from colossalai.legacy.nn import ( Classifier3D, CrossEntropyLoss3D, Embedding3D, @@ -21,8 +20,10 @@ from colossalai.nn import ( VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D, ) -from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env -from colossalai.utils import get_current_device, print_rank_0 +from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env +from colossalai.legacy.utils import print_rank_0 +from colossalai.logging import get_dist_logger +from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal @@ -72,14 +73,15 @@ def check_linear(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'linear forward: {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "linear forward: {0} --> {1} | {2:.3f} s".format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger + ) A_master = A_master.clone() A_master.requires_grad = True C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] C = torch.chunk(C, DEPTH, dim=0)[k] - logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} linear forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=get_current_device()) @@ -92,24 +94,24 @@ def check_linear(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('linear backward: {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("linear backward: {:.3f} s".format(bwd_end - bwd_start), logger) C_master.backward(grad_master) A_grad = A_master.grad A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} linear backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) + logger.info("Rank {} linear backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad))) B_grad = layer_master.weight.grad.transpose(0, 1) B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] - logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) + logger.info("Rank {} linear backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad))) bias_grad = layer_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[j] - logger.info('Rank {} linear backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad))) + logger.info("Rank {} linear backward (bias_grad): {}".format(rank, check_equal(bias_grad, layer.bias.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -156,8 +158,11 @@ def check_layernorm(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'layer norm forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), - fwd_end - fwd_start), logger) + "layer norm forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() A_master.requires_grad = True @@ -165,7 +170,7 @@ def check_layernorm(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} layernorm forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -178,22 +183,22 @@ def check_layernorm(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('layer norm backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("layer norm backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) C_master.backward(grad_master) A_grad = A_master.grad A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} layernorm backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) + logger.info("Rank {} layernorm backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad))) bias_grad = norm_master.weight.grad bias_grad = torch.chunk(bias_grad, DEPTH)[k] - logger.info('Rank {} layernorm backward (weight_grad): {}'.format(rank, check_equal(bias_grad, norm.weight.grad))) + logger.info("Rank {} layernorm backward (weight_grad): {}".format(rank, check_equal(bias_grad, norm.weight.grad))) bias_grad = norm_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[k] - logger.info('Rank {} layernorm backward (bias_grad): {}'.format(rank, check_equal(bias_grad, norm.bias.grad))) + logger.info("Rank {} layernorm backward (bias_grad): {}".format(rank, check_equal(bias_grad, norm.bias.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -240,14 +245,17 @@ def check_classifier_no_given_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() A_master.requires_grad = True C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} classifier (no given weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=get_current_device()) @@ -260,7 +268,7 @@ def check_classifier_no_given_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("classifier (no given weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -268,21 +276,29 @@ def check_classifier_no_given_weight(): A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format( - rank, check_equal(A_grad, A.grad))) + logger.info( + "Rank {} classifier (no given weight) backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad)) + ) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] if j == k: - logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( - rank, check_equal(B_grad, layer.weight.grad))) + logger.info( + "Rank {} classifier (no given weight) backward (weight_grad): {}".format( + rank, check_equal(B_grad, layer.weight.grad) + ) + ) else: - logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( - rank, layer.weight.grad is None)) + logger.info( + "Rank {} classifier (no given weight) backward (weight_grad): {}".format(rank, layer.weight.grad is None) + ) bias_grad = layer_master.bias.grad - logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format( - rank, check_equal(bias_grad, layer.bias.grad))) + logger.info( + "Rank {} classifier (no given weight) backward (bias_grad): {}".format( + rank, check_equal(bias_grad, layer.bias.grad) + ) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -332,15 +348,18 @@ def check_vocab_parallel_classifier_no_given_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() A_master.requires_grad = True C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] C = torch.chunk(C, DEPTH, dim=0)[k] - logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} vocab parallel classifier (no given weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -354,8 +373,9 @@ def check_vocab_parallel_classifier_no_given_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), - logger) + print_rank_0( + "vocab parallel classifier (no given weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger + ) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -363,20 +383,29 @@ def check_vocab_parallel_classifier_no_given_weight(): A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format( - rank, check_equal(A_grad, A.grad))) + logger.info( + "Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}".format( + rank, check_equal(A_grad, A.grad) + ) + ) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format( - rank, check_equal(B_grad, layer.weight.grad))) + logger.info( + "Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}".format( + rank, check_equal(B_grad, layer.weight.grad) + ) + ) bias_grad = layer_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[j] - logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format( - rank, check_equal(bias_grad, layer.bias.grad))) + logger.info( + "Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}".format( + rank, check_equal(bias_grad, layer.bias.grad) + ) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -422,13 +451,16 @@ def check_classifier_given_embed_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} classifier (given embed weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -441,7 +473,7 @@ def check_classifier_given_embed_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("classifier (given embed weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -449,11 +481,15 @@ def check_classifier_given_embed_weight(): B_grad = embed_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] if j == k: - logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( - rank, check_equal(B_grad, embed.weight.grad))) + logger.info( + "Rank {} classifier (given embed weight) backward (weight_grad): {}".format( + rank, check_equal(B_grad, embed.weight.grad) + ) + ) else: - logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( - rank, embed.weight.grad is None)) + logger.info( + "Rank {} classifier (given embed weight) backward (weight_grad): {}".format(rank, embed.weight.grad is None) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -500,14 +536,17 @@ def check_vocab_parallel_classifier_given_embed_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] C = torch.chunk(C, DEPTH, dim=0)[k] - logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} vocab parallel classifier (given embed weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -521,8 +560,9 @@ def check_vocab_parallel_classifier_given_embed_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), - logger) + print_rank_0( + "vocab parallel classifier (given embed weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger + ) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -531,9 +571,9 @@ def check_vocab_parallel_classifier_given_embed_weight(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, - check_equal(B_grad, - embed.weight.grad))) + logger.info( + "Rank {} vocab parallel embed backward (weight_grad): {}".format(rank, check_equal(B_grad, embed.weight.grad)) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -542,7 +582,7 @@ def check_patch_embed(): rank = torch.distributed.get_rank() device = get_current_device() logger = get_dist_logger() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -581,15 +621,18 @@ def check_patch_embed(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), - fwd_end - fwd_start), logger) + "patch embed forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} patch embed forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -603,29 +646,32 @@ def check_patch_embed(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("patch embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) grad_master = grad_master.clone() C_master.backward(grad_master) cls_grad_master = layer_master.cls_token.grad cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k] - logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad))) + logger.info("Rank {} patch embed backward (cls_grad): {}".format(rank, check_equal(cls_grad, layer.cls_token.grad))) pos_grad_master = layer_master.pos_embed.grad pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k] - logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank, - check_equal(pos_grad, layer.pos_embed.grad))) + logger.info( + "Rank {} patch embed backward (pos_embed_grad): {}".format(rank, check_equal(pos_grad, layer.pos_embed.grad)) + ) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] - logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank, - check_equal(B_grad, layer.weight.grad))) + logger.info( + "Rank {} patch embed backward (proj_weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad)) + ) bias_grad = layer_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[k] - logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank, - check_equal(bias_grad, layer.bias.grad))) + logger.info( + "Rank {} patch embed backward (proj_bias_grad): {}".format(rank, check_equal(bias_grad, layer.bias.grad)) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -634,7 +680,7 @@ def check_embed(): rank = torch.distributed.get_rank() device = get_current_device() logger = get_dist_logger() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -663,16 +709,17 @@ def check_embed(): out = layer(A) torch.cuda.synchronize() fwd_end = time.time() - logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), - fwd_end - fwd_start), - ranks=[0]) + logger.info( + "embed forward: pass | {0} --> {1} | {2:.3f} s".format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), + ranks=[0], + ) A_master = A_master.clone() C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} embed forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -685,14 +732,14 @@ def check_embed(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) grad_master = grad_master.clone() C_master.backward(grad_master) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) + logger.info("Rank {} embed backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -701,7 +748,7 @@ def check_vocab_parallel_embed(): rank = torch.distributed.get_rank() device = get_current_device() logger = get_dist_logger() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -732,16 +779,19 @@ def check_vocab_parallel_embed(): out = layer(A) torch.cuda.synchronize() fwd_end = time.time() - logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), - ranks=[0]) + logger.info( + "vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + ranks=[0], + ) A_master = A_master.clone() C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} vocab parallel embed forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -754,7 +804,7 @@ def check_vocab_parallel_embed(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("vocab parallel embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -763,9 +813,9 @@ def check_vocab_parallel_embed(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, - check_equal(B_grad, - layer.weight.grad))) + logger.info( + "Rank {} vocab parallel embed backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad)) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -797,25 +847,28 @@ def check_loss(): fwd_start = time.time() loss = criterion(out, target_master) fwd_end = time.time() - logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape), - fwd_end - fwd_start), - ranks=[0]) + logger.info( + "cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start + ), + ranks=[0], + ) out_master = out_master.clone() out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) - logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) + logger.info("Rank {} cross entropy loss forward: {}".format(rank, check_equal(loss, loss_master))) bwd_start = time.time() loss.backward() bwd_end = time.time() - logger.info('cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("cross entropy loss backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) loss_master.backward() out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] - logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + logger.info("Rank {} cross entropy loss backward: {}".format(rank, check_equal(out_grad, out.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -824,7 +877,7 @@ def check_vocab_parallel_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() device = get_current_device() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -851,25 +904,28 @@ def check_vocab_parallel_loss(): fwd_start = time.time() loss = criterion(out, target_master) fwd_end = time.time() - logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start), - ranks=[0]) + logger.info( + "vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start + ), + ranks=[0], + ) out_master = out_master.clone() out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) - logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) + logger.info("Rank {} vocab parallel cross entropy loss forward: {}".format(rank, check_equal(loss, loss_master))) bwd_start = time.time() loss.backward() bwd_end = time.time() - logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("vocab parallel cross entropy loss backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) loss_master.backward() out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k] out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] - logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + logger.info("Rank {} vocab parallel cross entropy loss backward: {}".format(rank, check_equal(out_grad, out.grad))) return fwd_end - fwd_start, bwd_end - bwd_start diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/common.py b/tests/test_legacy/test_layers/test_3d/checks_3d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..509fc2cecf598461516d33e857c960ccb8c7a1e0 --- /dev/null +++ b/tests/test_legacy/test_layers/test_3d/checks_3d/common.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +DEPTH = 2 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +NUM_BLOCKS = 2 +IMG_SIZE = 16 +VOCAB_SIZE = 16 + + +def check_equal(A, B): + eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2) + assert eq, f"\nA = {A}\nB = {B}" + return eq diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_legacy/test_layers/test_3d/test_3d.py similarity index 87% rename from tests/test_layers/test_3d/test_3d.py rename to tests/test_legacy/test_layers/test_3d/test_3d.py index fde71a4a0d26c172b45c6d0ff48497d55174b893..7057e2308b39f0899a23ed27087d94a3cd1d5f35 100644 --- a/tests/test_layers/test_3d/test_3d.py +++ b/tests/test_legacy/test_layers/test_3d/test_3d.py @@ -15,15 +15,15 @@ from checks_3d.check_layer_3d import ( check_vocab_parallel_loss, ) -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn CONFIG = dict( parallel=dict( pipeline=1, - tensor=dict(mode='3d', size=8), + tensor=dict(mode="3d", size=8), ), seed=42, ) @@ -44,7 +44,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.deterministic = True @@ -60,5 +60,5 @@ def test_3d(): spawn(check_layer_and_operation, 8) -if __name__ == '__main__': +if __name__ == "__main__": test_3d() diff --git a/tests/test_legacy/test_layers/test_cache_embedding.py b/tests/test_legacy/test_layers/test_cache_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..d64ff56b8a65b71e9b1fa3ff9872f767714c766d --- /dev/null +++ b/tests/test_legacy/test_layers/test_cache_embedding.py @@ -0,0 +1,396 @@ +import random +from typing import List + +import numpy as np +import pytest +import torch + +import colossalai +from colossalai.legacy.nn.parallel.layers import ( + CachedEmbeddingBag, + CachedParamMgr, + EvictionStrategy, + ParallelCachedEmbeddingBag, + ParallelCachedEmbeddingBagTablewise, + TablewiseEmbeddingBagConfig, +) +from colossalai.legacy.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.tensor import ColoTensor +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +NUM_EMBED, EMBED_DIM = 10, 8 +BATCH_SIZE = 8 + + +def set_seed(seed): + """ + To achieve reproducible results, it's necessary to fix random seeds + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def synthesize_1d_sparse_feature( + batch_size, + num_embed, + device, +): + indices_in_batch = batch_size * 2 + indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long) + offsets = ( + torch.from_numpy( + np.array( + [ + 0, + *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), + indices_in_batch, + ] + ) + ) + .to(device) + .long() + ) + return indices, offsets + + +@pytest.mark.skip +@clear_cache_before_run() +def test_cachemgr(): + model = torch.nn.EmbeddingBag(10000, 128) + # 10 chunks, 5 in cuda + mgr = CachedParamMgr(model.weight.detach(), 5) + assert mgr.cuda_row_num == 5 + + mgr._admit(1) + assert not mgr._chunk_in_cuda(2) + assert mgr._chunk_in_cuda(1) + + # print(mgr.cached_chunk_table) + mgr._admit(8) + + # now 3 chunk is available + assert mgr.cuda_available_chunk_num == 3 + + mgr._evict() + assert mgr.cuda_available_chunk_num == 4 + + mgr._prepare_rows_on_cuda(torch.tensor([9, 6, 5], dtype=torch.long, device=0)) + mgr._prepare_rows_on_cuda(torch.tensor([3, 4, 5], dtype=torch.long, device=0)) + # print(mgr.cached_chunk_table) + # mgr.print_comm_stats() + + mgr.flush() + assert mgr.cuda_available_chunk_num == 5 + + +@clear_cache_before_run() +def test_reorder_with_freq(): + num_embed = 100 + chunk_size = 1 + num_chunk = 5 + + idx_map = torch.randint(10000, size=(num_embed,)) + sorted_idx = torch.argsort(idx_map, descending=True).tolist() + chunkid, offset_in_chunk = [], [] + for i in range(num_embed): + idx = sorted_idx.index(i) + chunkid.append(idx // chunk_size) + offset_in_chunk.append(idx % chunk_size) + + dev = torch.device("cuda") + chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev) + offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev) + + weight = torch.rand(num_embed, 2) + mgr = CachedParamMgr(weight, num_chunk) + + mgr.reorder(idx_map) + + indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev)) + mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode="floor") + mgr_offsets = torch.remainder(indices, chunk_size) + assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}" + assert torch.allclose(offset_in_chunk, mgr_offsets), f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" + + +@clear_cache_before_run() +@parameterize("use_LFU", [True, False]) +def test_freq_aware_embed(use_LFU: bool): + device = torch.device("cuda", 0) + evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET + model = CachedEmbeddingBag( + NUM_EMBED, + EMBED_DIM, + mode="mean", + include_last_offset=True, + cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0), + ids_freq_mapping=None, + evict_strategy=evict_strategy, + ).to(device) + + assert model.weight.shape[0] == NUM_EMBED + ref_model = torch.nn.EmbeddingBag.from_pretrained( + model.weight.detach().to(device), mode="mean", include_last_offset=True, freeze=False + ) + + assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device)) + + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3) + + for i in range(5): + indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device) + res = model(indices, offsets) + ref_res = ref_model(indices, offsets) + assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}" + + grad = torch.rand_like(res) + # comparing gradient here is nontrivial + res.backward(grad) + ref_res.backward(grad) + optimizer.step() + optimizer.zero_grad() + + ref_optimizer.step() + ref_optimizer.zero_grad() + + model.cache_weight_mgr.flush() + model_weight = model.weight.detach().to(device) + ref_weight = ref_model.weight.detach() + assert torch.allclose( + model_weight, ref_weight + ), f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}" + + +@clear_cache_before_run() +@parameterize("init_freq", [True, False]) +def test_lfu_strategy(init_freq: bool): + # minimal test to check behavior + Bag = CachedEmbeddingBag( + 5, + 5, + cache_ratio=3 / 5, + buffer_size=0, + pin_weight=True, + ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None, + warmup_ratio=1.0, + evict_strategy=EvictionStrategy.LFU, + ) + + # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map) + offsets = torch.tensor([0], device="cuda:0") + + # prepare frequency learning info: + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) + + # check strategy + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) + Bag.forward(torch.tensor([3], device="cuda:0"), offsets) # miss, evict 1 + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([4], device="cuda:0"), offsets) # miss, evict 3 + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) # hit + + assert torch.allclose( + torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1]) + ), "LFU strategy behavior failed" + + +def gather_tensor(tensor, rank, world_size): + gather_list = [] + if rank == 0: + gather_list = [torch.empty_like(tensor) for _ in range(world_size)] + + torch.distributed.gather(tensor, gather_list, dst=0) + return gather_list + + +def run_parallel_freq_aware_embed_tablewise(rank, world_size): + if world_size != 2: + return + device = torch.device("cuda", torch.cuda.current_device()) + + # initialize weight + # 3 feature tables. idx: 0~5, 6~10, 11~17 + weight_tables = torch.rand(18, 5) + weight_table1 = weight_tables[0:6] + weight_table2 = weight_tables[6:11] + weight_table3 = weight_tables[11:18] + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] + embedding_bag_config_list.append( + TablewiseEmbeddingBagConfig( + num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu() + ) + ) + embedding_bag_config_list.append( + TablewiseEmbeddingBagConfig( + num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu() + ) + ) + embedding_bag_config_list.append( + TablewiseEmbeddingBagConfig( + num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu() + ) + ) + if rank == 0: + _weight = torch.cat([weight_table1, weight_table2], 0) + else: + _weight = weight_table3 + model = ParallelCachedEmbeddingBagTablewise( + embedding_bag_config_list, + embedding_dim=5, + _weight=_weight, + include_last_offset=True, + cache_ratio=0.5, + buffer_size=0, + evict_strategy=EvictionStrategy.LFU, + ) + # explain + """ + batch feature 1 feature 2 feature 3 + input0 [1,2,3] [6,7] [] + input1 [] [9] [13,15] + input2 [1,5] [6,8] [11] + ↑ ↑ ↑ + rank 0 rank 0 rank 1 + in KJT format + """ + res = model( + torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), + torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device), + already_split_along_rank=False, + ) + optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) + rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device) + if rank == 0: + fake_grad = rand_grad[0:2] + else: + fake_grad = rand_grad[2:] + res.backward(fake_grad) + optimizer.step() + optimizer.zero_grad() + + # check correctness + if rank == 0: + ref_model = torch.nn.EmbeddingBag.from_pretrained( + weight_tables.detach().clone(), include_last_offset=True, freeze=False + ).to(device) + ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2) + ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0) + ref_res = ref_model( + torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), + torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device), + ) + ref_res.backward(ref_fake_grad) + ref_optimizer.step() + ref_optimizer.zero_grad() + + model.cache_weight_mgr.flush() + recover_weight = model.cache_weight_mgr.weight.to(device) + ref_weight = ref_model.weight.detach()[:11] + assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}" + + +def run_parallel_freq_aware_embed_columnwise(rank, world_size): + device = torch.device("cuda", torch.cuda.current_device()) + + num_embed = 100 + embed_dim = 16 + batch_size = 4 + + set_seed(4321) + weight = torch.rand(num_embed, embed_dim) + coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None) + + # initialize the tensor spec for the embedding weight parameter, + # which is an ColoParameter. + coloweight.set_process_group(ProcessGroup(tp_degree=world_size)) + coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D)) + + model = ParallelCachedEmbeddingBag.from_pretrained( + coloweight, + include_last_offset=True, + freeze=False, + cache_ratio=batch_size * 2 / num_embed, + ) + + assert model.cache_weight_mgr.weight.device.type == "cpu" + assert model.cache_weight_mgr.cuda_cached_weight.requires_grad + weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank] + print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}") + assert torch.allclose( + weight_in_rank, model.cache_weight_mgr.weight.detach() + ), f"{weight_in_rank - model.cache_weight_mgr.weight}" + + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + + if rank == 0: + ref_model = torch.nn.EmbeddingBag.from_pretrained( + weight.detach().clone(), include_last_offset=True, freeze=False + ).to(device) + ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3) + + set_seed(4321) + for i in range(5): + indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device) + res = model(indices, offsets) + + grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device) + grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank] + res.backward(grad_in_rank) + + optimizer.step() + optimizer.zero_grad() + + res_list = gather_tensor(res.detach(), rank, world_size) + + if rank == 0: + ref_res = ref_model(indices, offsets) + recover_res = torch.cat(res_list, dim=0) + + assert torch.allclose(ref_res, recover_res) + + ref_res.backward(grad) + ref_optimizer.step() + ref_optimizer.zero_grad() + + model.cache_weight_mgr.flush() + weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size) + if rank == 0: + recover_weight = torch.cat(weight_list, dim=1) + assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}" + + +def run_dist(rank, world_size, port): + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + # run_parallel_freq_aware_embed_columnwise(rank, world_size) + run_parallel_freq_aware_embed_tablewise(rank, world_size) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 4]) +@rerun_if_address_is_in_use() +def test_parallel_freq_aware_embed(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + # test_freq_aware_embed(True) + test_parallel_freq_aware_embed(2) + # test_lfu_strategy(False) diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/__init__.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py new file mode 100644 index 0000000000000000000000000000000000000000..aa4d5d6ceeb37bbb622acc193a9f3340a8c4da8d --- /dev/null +++ b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py @@ -0,0 +1,22 @@ +import torch + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn import TransformerSelfAttentionRing +from colossalai.utils import get_current_device + + +def check_selfattention(): + WORLD_SIZE = gpc.get_world_size(ParallelMode.SEQUENCE) + SUB_SEQ_LENGTH = 8 + BATCH = 4 + HIDDEN_SIZE = 16 + + layer = TransformerSelfAttentionRing(16, 8, 8, 0.1) + layer = layer.to(get_current_device()) + + hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device()) + attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to( + get_current_device() + ) + layer(hidden_states, attention_mask) diff --git a/tests/test_legacy/test_layers/test_sequence/test_sequence.py b/tests/test_legacy/test_layers/test_sequence/test_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd3e04c6479555e026f2a9fb84a2ffbead478b6 --- /dev/null +++ b/tests/test_legacy/test_layers/test_sequence/test_sequence.py @@ -0,0 +1,137 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.nn.layer.parallel_sequence import RingAV, RingQK +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(tensor=dict(size=4, mode="sequence"))) + + +def check_ring_qk(rank, world_size): + # params + batch_size = 4 + num_heads = 4 + seq_length = 32 + attention_head_size = 32 + sub_seq_length = seq_length // world_size + + # create master tensors + q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() + k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() + dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + + # create distributed tensors + sub_q = q.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() + sub_k = k.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() + + # set autograd attributes + q.requires_grad = True + k.requires_grad = True + q.retain_grad() + k.retain_grad() + sub_q.requires_grad = True + sub_k.requires_grad = True + sub_q.retain_grad() + sub_k.retain_grad() + + # compute master attention scores + a = torch.matmul(q, k.transpose(2, 1)) + + # compute distributed attention scores + ring_qk = RingQK.apply + sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length) + + # check master and distributed attention scores + sub_master_a = a[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] + assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2) + + # run master backward + a.retain_grad() + a.mean().backward() + + # run distributed backward + partial_master_a_grad = a.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] + torch.autograd.backward(sub_a, partial_master_a_grad) + + # check master and distributed grads + partial_master_q_grad = q.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] + assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), "attention score cannot match" + + +def check_ring_av(rank, world_size): + # params + batch_size = 4 + num_heads = 4 + seq_length = 16 + attention_head_size = 32 + sub_seq_length = seq_length // world_size + + # create master tensors + a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda() + v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda() + dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + + # create distributed tensors + sub_a = a.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() + sub_v = v.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() + + # set autograd attributes + a.requires_grad = True + v.requires_grad = True + a.retain_grad() + v.retain_grad() + sub_a.requires_grad = True + sub_v.requires_grad = True + sub_a.retain_grad() + sub_v.retain_grad() + + # compute master attention scores + out = torch.matmul(a, v) + + # compute distributed attention scores + ring_av = RingAV.apply + sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length) + + # print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}') + + # check master and distributed output + sub_master_out = out[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] + assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2) + + # # run master backward + out.retain_grad() + out.mean().backward() + + # # run distributed backward + partial_master_out_grad = out.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] + torch.autograd.backward(sub_out, partial_master_out_grad) + + # # check master and distributed grads + partial_master_a_grad = a.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] + assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), "attention output cannot match" + + +def run_test(rank, world_size, port): + colossalai.legacy.launch(rank=rank, world_size=world_size, config=CONFIG, host="localhost", port=port) + + # check_ring_qk(rank, world_size) + check_ring_av(rank, world_size) + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_sequence(): + spawn(run_test, 4) + + +if __name__ == "__main__": + test_sequence() diff --git a/tests/test_legacy/test_pipeline/rpc_test_utils.py b/tests/test_legacy/test_pipeline/rpc_test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e59f22062cfcde36566433dc03a1b42d1cec8215 --- /dev/null +++ b/tests/test_legacy/test_pipeline/rpc_test_utils.py @@ -0,0 +1,147 @@ +import argparse +import os +import warnings + +import torch +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp +from torch import nn +from torch._C._distributed_rpc import _is_current_rpc_agent_set + +from colossalai.legacy import launch +from colossalai.legacy.pipeline.pipeline_process_group import ppg +from colossalai.logging import disable_existing_loggers + +rpc_is_initialized = _is_current_rpc_agent_set + + +def color_debug(text, prefix=" ", color="blue"): + color = color.upper() + print(getattr(Back, color), prefix, Style.RESET_ALL, text) + + +class MLP(nn.Module): + def __init__(self, dim: int, layers: int): + super().__init__() + self.layers = torch.nn.ModuleList() + + for _ in range(layers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x.sum() + + +class DAG_MLP(nn.Module): + def __init__(self, dim: int, layers: int): + super().__init__() + self.layers = torch.nn.ModuleList() + self.dag_layer = nn.Linear(dim, dim, bias=False) + + for _ in range(layers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + + def forward(self, x, y): + for layer in self.layers: + x = layer(x) + y = self.dag_layer(y) + return x.sum(), y.sum() + + +class RpcTestModel(nn.Module): + def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: + super().__init__() + self.rank = stage_id + self.is_last_rank = stage_id == actual_stage_num - 1 + self.linear_name = f"linear_{stage_id}" + + if stage_id == 0: + linear = nn.Linear(feat_num, h) + elif stage_id == actual_stage_num - 1: + linear = nn.Linear(h, 1) + else: + linear = nn.Linear(h, h) + + setattr(self, self.linear_name, linear) + + def forward(self, x) -> torch.Tensor: + linear: nn.Module = getattr(self, self.linear_name) + out: torch.Tensor = linear(x) + + if self.is_last_rank: + out = out.sum() + return out + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--epoch", type=int, default=1) + parser.add_argument("--world_size", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--num_microbatches", type=int, default=2) + parser.add_argument("--chunk", type=int, default=1) + parser.add_argument("--use_checkpoint", action="store_true") + parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "RMSprop"], default="SGD") + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29020") + parser.add_argument("--num_worker_threads", type=str, default=128) + return parser.parse_args() + + +def pg_parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--world_size", type=int, default=4) + parser.add_argument("--dp_degree", type=int, default=2) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--chunk", type=int, default=1) + parser.add_argument("--num_worker_threads", type=str, default=128) + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29020") + return parser.parse_args() + + +def run_worker(rank, args, master_func): + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port + + device = args.device + world_size = args.world_size + dp_degree = args.dp_degree + tp_degree = args.tp_degree + num_worker_threads = args.num_worker_threads + host = args.master_addr + port = args.master_port + backend = "nccl" if device == "cuda" else "gloo" + + disable_existing_loggers() + + launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + ppg.set_global_info( + rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device, + ) + + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + master_func(args) + # barrier here + if rpc_is_initialized(): + rpc.shutdown() + else: + warnings.warn("RPC has not been initialized") + + +def rpc_run(args, master_func): + world_size = args.world_size + assert args.num_microbatches >= args.world_size, "num_microbatches cannot be fewer than world_size!" + mp.spawn(run_worker, args=(args, master_func), nprocs=world_size) diff --git a/tests/test_pipeline/test_cuda_rpc_chimera.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py similarity index 78% rename from tests/test_pipeline/test_cuda_rpc_chimera.py rename to tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py index 45ad8f828e61506649295b88ecd7e2fbfd5dcd3c..f6c077136607736c69939a0f27134b2b321ac1ab 100644 --- a/tests/test_pipeline/test_cuda_rpc_chimera.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py @@ -1,10 +1,9 @@ import torch -from torch import nn import torch.autograd as autograd +from rpc_test_utils import RpcTestModel, parse_args, rpc_run +from torch import nn -from colossalai.pipeline.rpc import ChimeraPipelineEngine -from colossalai.testing import assert_close -from rpc_test_utils import rpc_run, parse_args, RpcTestModel +from colossalai.legacy.pipeline.rpc import ChimeraPipelineEngine # global variable for model created feat_num = 100 @@ -20,7 +19,7 @@ def partition(pp_rank: int, chunk: int, stage_num: int): def run_master(args): torch.manual_seed(100) - epoch = args.epoch + args.epoch device = args.device stage_num = args.world_size chunk = 1 @@ -32,11 +31,13 @@ def run_master(args): assert sample_num % batch_size == 0 - engine = ChimeraPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - checkpoint=use_checkpoint) + engine = ChimeraPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + checkpoint=use_checkpoint, + ) engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) input_sample = torch.randn((sample_num, feat_num), device=device) @@ -56,7 +57,8 @@ def run_master(args): # compute forward result and backward grad of parameters just in rank_0 test_model = nn.Sequential( - *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)] + ).to(device) # input_sample = input_sample[len(input_sample) // 2:] input_sample = input_sample.requires_grad_() out_val = test_model(input_sample).sum() diff --git a/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..806f24a64511c9fbb91570292366c4746bfd6dbe --- /dev/null +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py @@ -0,0 +1,83 @@ +import torch +from rpc_test_utils import RpcTestModel, parse_args, rpc_run +from torch import autograd, nn +from torch.optim import Optimizer + +from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine +from colossalai.testing import assert_close + +# global variable for model created +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + + +def run_master(args): + torch.manual_seed(100) + + device = args.device + stage_num = args.world_size + chunk = args.chunk + actual_stage_num = stage_num * chunk + use_checkpoint = args.use_checkpoint + num_microbatches = args.num_microbatches + optimizer_class = globals()[args.optimizer] + + lr = 1e-3 + sample_num = 1024 + batch_size = 1024 + + assert sample_num % batch_size == 0 + + input_sample = torch.randn((sample_num, feat_num), device=device) + + engine = OneFOneBPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) + + engine.initialize_optimizer(optimizer_class, lr=lr) + + _ = engine.forward_backward(input_sample) + + cuda_rpc_result = [] + single_result = [] + actual_stage_num = engine._get_actual_stage_num() + + # compute parameters after updating in cuda rpc + parameters = engine.remote_parameters() + for stage_id in range(actual_stage_num): + for p in parameters[stage_id]: + cuda_rpc_result.append(p) + + # compute forward result and backward grad of parameters just in rank_0 + test_model = nn.Sequential( + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)] + ).to(device) + optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr) + input_sample = input_sample.requires_grad_() + out_val = test_model(input_sample).sum() + autograd.backward(out_val) + optimizer.step() + optimizer.zero_grad() + + for p in test_model.parameters(): + single_result.append(p) + + assert len(cuda_rpc_result) == len(single_result) + for r_c, r_s in zip(cuda_rpc_result, single_result): + assert_close(r_c, r_s, 0.001, 0.001) + + +if __name__ == "__main__": + args = parse_args() + rpc_run(args, run_master) diff --git a/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e8fc6e6b51417752863bf7602c9c7cf323e254 --- /dev/null +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py @@ -0,0 +1,49 @@ +import torch +from rpc_test_utils import RpcTestModel, parse_args, rpc_run + +from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine + +# global variable for model created +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + + +def run_master(args): + torch.manual_seed(100) + + epoch = args.epoch + device = args.device + stage_num = args.world_size + chunk = args.chunk + num_microbatches = args.num_microbatches + use_checkpoint = args.use_checkpoint + + sample_num = 1024 + batch_size = 1024 + + assert sample_num % batch_size == 0 + + input_sample = torch.randn((sample_num, feat_num), device=device) + + engine = OneFOneBPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) + + for _ in range(epoch): + _ = engine.forward_backward(input_sample, forward_only=False) + + +if __name__ == "__main__": + args = parse_args() + rpc_run(args, run_master) diff --git a/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py new file mode 100644 index 0000000000000000000000000000000000000000..09c9b84a99075341a039de101edb298d34e128b7 --- /dev/null +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py @@ -0,0 +1,75 @@ +import torch +from rpc_test_utils import RpcTestModel, parse_args, rpc_run +from torch import autograd, nn + +from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine +from colossalai.testing import assert_close + +feat_num = 100 +h = 100 + + +def partition(pp_rank: int, chunk: int, stage_num: int): + torch.manual_seed(1024) + partition = RpcTestModel(pp_rank, stage_num, feat_num, h) + return partition + + +def run_master(args): + torch.manual_seed(100) + + device = args.device + stage_num = args.world_size + chunk = args.chunk + actual_stage_num = stage_num * chunk + use_checkpoint = args.use_checkpoint + num_microbatches = args.num_microbatches + + sample_num = 1024 + batch_size = 1024 + + assert sample_num % batch_size == 0 + + input_sample = torch.randn((sample_num, feat_num), device=device) + + engine = OneFOneBPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) + + forward_result = engine.forward_backward(input_sample) + + cuda_rpc_result = [] + single_result = [] + actual_stage_num = engine._get_actual_stage_num() + + # compute forward result and backward grad of parameters in cuda rpc + cuda_rpc_result.append(sum(forward_result[0])) + grad = engine.remote_grad() + for stage_id in range(actual_stage_num): + for p in grad[stage_id]: + cuda_rpc_result.append(p) + + # compute forward result and backward grad of parameters just in rank_0 + test_model = nn.Sequential( + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)] + ).to(device) + input_sample = input_sample.requires_grad_() + out_val = test_model(input_sample).sum() + autograd.backward(out_val) + single_result.append(out_val) + for p in test_model.parameters(): + single_result.append(p.grad) + + assert len(cuda_rpc_result) == len(single_result) + for r_c, r_s in zip(cuda_rpc_result, single_result): + assert_close(r_c, r_s, 0.001, 0.001) + + +if __name__ == "__main__": + args = parse_args() + rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_legacy/test_pipeline/test_middleware_1f1b.py similarity index 76% rename from tests/test_pipeline/test_middleware_1f1b.py rename to tests/test_legacy/test_pipeline/test_middleware_1f1b.py index 5b3aad70327598341d7196ae97b0d2e791f57e39..dff04c3ebba14ea5a5773afa077e21743f4016f5 100644 --- a/tests/test_pipeline/test_middleware_1f1b.py +++ b/tests/test_legacy/test_pipeline/test_middleware_1f1b.py @@ -7,13 +7,13 @@ import torch.distributed.rpc as rpc from rpc_test_utils import DAG_MLP, MLP from torch._C._distributed_rpc import _is_current_rpc_agent_set -from colossalai import launch from colossalai.fx import ColoTracer from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass +from colossalai.legacy import launch +from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology +from colossalai.legacy.pipeline.pipeline_process_group import ppg +from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.middleware.adaptor import get_fx_topology -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn # global variable for model created @@ -25,7 +25,7 @@ rpc_is_initialized = _is_current_rpc_agent_set def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): model.eval() tracer = ColoTracer() - meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + meta_args = {k: v.to("meta") for k, v in data_kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) annotated_model = balanced_split_pass(gm, stage_num) @@ -33,7 +33,7 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): topo = get_fx_topology(top_module) for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): - setattr(submodule, '_topo', topo) + setattr(submodule, "_topo", topo) return split_submodules[pp_rank + 1] @@ -47,11 +47,11 @@ def run_master(model_cls, world_size, forward_only): torch.manual_seed(100) epoch = 3 - device = 'cuda' + device = "cuda" stage_num = world_size chunk = 1 num_microbatches = 8 - use_checkpoint = 'store_true' + use_checkpoint = "store_true" if model_cls == MLP: @@ -92,29 +92,26 @@ def run_master(model_cls, world_size, forward_only): checkpoint=use_checkpoint, ) if not forward_only: - engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3) + engine.initialize_optimizer(getattr(torch.optim, "SGD"), lr=1e-3) for _ in range(epoch): input_x = torch.randn((batch_size, dim), device=device) input_y = torch.randn((batch_size, dim), device=device) - logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only) + logits = engine.forward_backward({"x": input_x, "y": input_y}, labels=labels, forward_only=forward_only) def run_worker(rank, world_size, port, model_cls, forward_only, master_func): - master_addr = 'localhost' + master_addr = "localhost" master_port = 29020 - os.environ['MASTER_ADDR'] = master_addr - os.environ['MASTER_PORT'] = str(master_port) + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) disable_existing_loggers() - launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False) - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=1, - tp_degree=1, - num_worker_threads=128, - device='cuda') + launch(dict(), rank, world_size, master_addr, master_port, "nccl", verbose=False) + ppg.set_global_info( + rank=rank, world_size=world_size, dp_degree=1, tp_degree=1, num_worker_threads=128, device="cuda" + ) # in rpc mode, only rank 0 is needed to be coded if rank == 0: @@ -125,8 +122,8 @@ def run_worker(rank, world_size, port, model_cls, forward_only, master_func): @pytest.mark.skip("skip due to CI torch version 1.11") -@parameterize('model_cls', [MLP, DAG_MLP]) -@parameterize('forward_only', [True, False]) +@parameterize("model_cls", [MLP, DAG_MLP]) +@parameterize("forward_only", [True, False]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_pp_middleware_fwd(model_cls, forward_only): diff --git a/tests/test_pipeline/test_pipelinable.py b/tests/test_legacy/test_pipeline/test_pipelinable.py similarity index 87% rename from tests/test_pipeline/test_pipelinable.py rename to tests/test_legacy/test_pipeline/test_pipelinable.py index 627cb5ac6f51968ae29c2743d5ec5a730af70814..950cc68036ae4b1df2cecad44d0a3a9990b811a6 100644 --- a/tests/test_pipeline/test_pipelinable.py +++ b/tests/test_legacy/test_pipeline/test_pipelinable.py @@ -1,14 +1,14 @@ +import pytest import torch -from colossalai.pipeline.pipelinable import PipelinableContext -from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn +from colossalai.legacy.pipeline.pipelinable import PipelinableContext +from colossalai.testing import rerun_if_address_is_in_use, spawn NUM_CHUNKS = 1 PIPELINE_SIZE = 2 class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): super().__init__() intermediate_dim = dim * 4 @@ -48,10 +48,11 @@ def run_pipelinable(rank, world_size, port): assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count +@pytest.mark.skip(reason="this is useless") @rerun_if_address_is_in_use() def test_pipelinable(): spawn(run_pipelinable, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_pipelinable() diff --git a/tests/test_legacy/test_pipeline/test_pipeline_process_group.py b/tests/test_legacy/test_pipeline/test_pipeline_process_group.py new file mode 100644 index 0000000000000000000000000000000000000000..627aafb18e6155de9abed63b929ae5a0837f508f --- /dev/null +++ b/tests/test_legacy/test_pipeline/test_pipeline_process_group.py @@ -0,0 +1,44 @@ +import os + +import torch.distributed.rpc as rpc +from rpc_test_utils import pg_parse_args, rpc_is_initialized + +from colossalai.legacy.initialize import launch +from colossalai.legacy.pipeline.pipeline_process_group import ppg +from colossalai.logging import disable_existing_loggers +from colossalai.testing import spawn + + +def run_worker(rank, args): + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port + + device = args.device + world_size = args.world_size + dp_degree = args.dp_degree + tp_degree = args.tp_degree + num_worker_threads = args.num_worker_threads + host = args.master_addr + port = args.master_port + backend = "nccl" if device == "cuda" else "gloo" + + disable_existing_loggers() + launch(dict(), rank, world_size, host, int(port), backend, verbose=False) + + ppg.set_global_info( + rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device, + ) + + if rpc_is_initialized(): + rpc.shutdown() + + +if __name__ == "__main__": + args = pg_parse_args() + world_size = args.world_size + spawn(run_worker, world_size, args=args) diff --git a/tests/test_legacy/test_tensor/common_utils/__init__.py b/tests/test_legacy/test_tensor/common_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a35d02ce5edd0e2f4a16831621041009928f129 --- /dev/null +++ b/tests/test_legacy/test_tensor/common_utils/__init__.py @@ -0,0 +1 @@ +from ._utils import * diff --git a/tests/test_legacy/test_tensor/common_utils/_utils.py b/tests/test_legacy/test_tensor/common_utils/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..78bea66583647f769854ea24cccdd5526a9ede4e --- /dev/null +++ b/tests/test_legacy/test_tensor/common_utils/_utils.py @@ -0,0 +1,88 @@ +import os +import random + +import numpy as np +import torch +import torch.distributed as dist +from torch.testing import assert_close + +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.tensor import ComputePattern, ComputeSpec, ShardSpec + + +def set_seed(seed): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True + + +def replace_parameter_add_grad(layer, weight=None, bias=None): + if weight is not None: + delattr(layer, "weight") + setattr(layer, "weight", weight) + layer.weight.requires_grad = True + if bias is not None: + delattr(layer, "bias") + setattr(layer, "bias", bias) + layer.bias.requires_grad = True + + +def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0): + dist.broadcast(tensor, src=0) + tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank] + return tensor_chunk.clone() + + +def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1): + assert_close(t_a, t_b, rtol=rtol, atol=atol) + return True + + +def tensor_shard_equal( + tensor: torch.Tensor, shard: torch.Tensor, rank: int, world_size: int, rtol: float = 1e-3, atol: float = 1e-1 +): + assert tensor.ndim == shard.ndim + if tensor.shape == shard.shape: + return tensor_equal(tensor, shard, rtol, atol) + else: + dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)) + if dims_not_eq.numel() == 1: + # 1D shard + dim = dims_not_eq.item() + if world_size is None: + world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + if rank is None: + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol) + else: + raise NotImplementedError + + +def split_param_single_dim_tp1d(dim, param, pg): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + if param.process_group.tp_world_size() == 1: + param.set_process_group(pg) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param, pg): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param, pg): + split_param_single_dim_tp1d(-1, param, pg) + + +def debug_print(ranks, *args): + if dist.get_rank() in ranks: + print(*args) + dist.barrier() diff --git a/tests/test_tensor/core/test_dist_spec_mgr.py b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py similarity index 88% rename from tests/test_tensor/core/test_dist_spec_mgr.py rename to tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py index 89476a35b63a0575c1f2cb30c0c7a74e0f4c0596..506244447054756520c75e33399b2117d632e955 100644 --- a/tests/test_tensor/core/test_dist_spec_mgr.py +++ b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py @@ -5,7 +5,7 @@ import torch import torch.distributed as dist import colossalai -from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.legacy.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -48,17 +48,17 @@ def check_mem(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_mem() run() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_dist_spec_mgr(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_dist_spec_mgr(4) diff --git a/tests/test_legacy/test_tensor/test_parameter.py b/tests/test_legacy/test_tensor/test_parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..5217e22cc4225e64df6ab2b9abe34b04aaebdb35 --- /dev/null +++ b/tests/test_legacy/test_tensor/test_parameter.py @@ -0,0 +1,35 @@ +import pytest +import torch +from common_utils import tensor_equal + +import colossalai +from colossalai.tensor import ColoParameter, ColoTensor +from colossalai.testing import free_port + + +@pytest.mark.skip +def test_multiinheritance(): + colossalai.legacy.launch(config={}, rank=0, world_size=1, host="localhost", port=free_port(), backend="nccl") + colo_param = ColoParameter(None, requires_grad=True) + assert colo_param.dist_spec.placement.value == "r" + assert isinstance(colo_param, ColoTensor) + assert isinstance(colo_param, torch.nn.Parameter) + + # __deepcopy__ overload + import copy + + colo_param2 = copy.deepcopy(colo_param) + assert isinstance(colo_param2, ColoParameter) + assert tensor_equal(colo_param.data, colo_param2.data) + assert colo_param.requires_grad == colo_param2.requires_grad + + # __repr__ overload + assert "ColoParameter" in str(colo_param) + + # __torch_function__ + clone_param = torch.clone(colo_param) + assert isinstance(clone_param, ColoTensor) + + +if __name__ == "__main__": + test_multiinheritance() diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py new file mode 100644 index 0000000000000000000000000000000000000000..a5a2d38577dc5880789b9cea585d41fa8ef9ccb3 --- /dev/null +++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pytest +import torch +import torch.distributed as dist + +from colossalai.legacy.communication import ( + recv_backward, + recv_forward, + send_backward, + send_backward_recv_forward, + send_forward, + send_forward_recv_backward, +) +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + +BATCH_SIZE = 4 +SEQ_LENGTH = 2 +HIDDEN_SIZE = 16 + +CONFIG = dict(parallel=dict(pipeline=dict(size=4), tensor=dict(size=1, mode=None)), seed=1024) + + +def check_equal(A, B): + return torch.allclose(A, B, rtol=1e-5, atol=1e-3) + + +def check_forward(output_tensor, rank, logger): + dist.barrier() + if gpc.is_first_rank(ParallelMode.PIPELINE): + tensor = output_tensor.clone() + else: + tensor = recv_forward(output_tensor.shape) + logger.info("Rank {} received forward. Correct tensor: {}".format(rank, check_equal(tensor, output_tensor))) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + send_forward(tensor) + logger.info("Rank {} sent forward.".format(rank)) + + +def check_backward(output_grad, rank, logger): + dist.barrier() + if gpc.is_last_rank(ParallelMode.PIPELINE): + grad = output_grad.clone() + else: + grad = recv_backward(output_grad.shape) + logger.info("Rank {} received backward. Correct grad: {}".format(rank, check_equal(grad, output_grad))) + if not gpc.is_first_rank(ParallelMode.PIPELINE): + send_backward(grad) + logger.info("Rank {} sent backward.".format(rank)) + + +def check_forward_backward(output_tensor, output_grad, rank, logger): + dist.barrier() + if not gpc.is_first_rank(ParallelMode.PIPELINE): + tensor = send_backward_recv_forward(output_grad, output_tensor.shape) + logger.info( + "Rank {} sent backward received forward. Correct tensor: {}".format( + rank, check_equal(tensor, output_tensor) + ) + ) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + grad = send_forward_recv_backward(output_tensor, output_grad.shape) + logger.info( + "Rank {} sent forward received backward. Correct grad: {}".format(rank, check_equal(grad, output_grad)) + ) + + +def check_comm(size, rank, prev_rank, next_rank, logger): + dtype = torch.float32 + device = get_current_device() + tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + tensor = torch.randn(tensor_shape, dtype=dtype, device=device) + dist.all_reduce(tensor) + grad = torch.randn(grad_shape, dtype=dtype, device=device) + dist.all_reduce(grad) + check_forward(tensor, rank, logger) + check_backward(grad, rank, logger) + check_forward_backward(tensor, grad, rank, logger) + + +def run_check(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + logger = get_dist_logger() + rank = gpc.get_global_rank() + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + logger.info("Rank {0}: prev rank {1}, next rank {2}".format(rank, prev_rank, next_rank)) + logger.info("Distributed environment is initialized.") + + check_comm(world_size, rank, prev_rank, next_rank, logger) + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_p2p(): + world_size = 4 + spawn(run_check, world_size) + + +if __name__ == "__main__": + test_p2p() diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..cd7fcfe5635dea1180e5db5dc1c3dc9f189d9203 --- /dev/null +++ b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -0,0 +1,90 @@ +# referenced from Megatron and used to testify communication + +import os +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 + +import colossalai +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.legacy.utils import get_dataloader, print_rank_0 +from colossalai.testing import rerun_if_address_is_in_use, spawn + +BATCH_SIZE = 8 + +CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode=None))) + + +def run_schedule(rank, world_size, port): + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + # build model + model = resnet18(num_classes=10) + + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2) + elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: + + class Flatten(nn.Module): + def forward(self, x): + return torch.flatten(x, 1) + + model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) + + print_rank_0("model is created") + + train_dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ] + ), + ) + + train_dataloader = get_dataloader( + dataset=train_dataset, + shuffle=True, + add_sampler=True, + batch_size=BATCH_SIZE, + pin_memory=True, + ) + + # build criterion + criterion = torch.nn.CrossEntropyLoss() + + # optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0) + + # initialize + engine, train_dataloader, _, _ = colossalai.legacy.initialize(model, optimizer, criterion, train_dataloader) + + # build pipeline schedule + schedule = engine.schedule + + # run schedule + data_iter = iter(train_dataloader) + schedule.forward_backward_step(engine, data_iter) + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pipeline_schedule(): + world_size = 2 + spawn(run_schedule, world_size) + + +if __name__ == "__main__": + test_pipeline_schedule() diff --git a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..d19b12a5b044e8a73f2a302a2cde2290c4fe79a2 --- /dev/null +++ b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -0,0 +1,62 @@ +import pytest +import torch + +import colossalai +from colossalai.legacy.amp.amp_type import AMP_TYPE +from colossalai.legacy.trainer import Trainer +from colossalai.logging import get_dist_logger +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import MultiTimer +from tests.components_to_test.registry import non_distributed_component_funcs + +BATCH_SIZE = 4 +IMG_SIZE = 32 +NUM_EPOCHS = 200 + +CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH)) + + +@parameterize("model_name", ["repeated_computed_layers", "resnet18", "nested_model"]) +def run_trainer(model_name): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model = model_builder() + optimizer = optimizer_class(model.parameters(), lr=1e-3) + engine, train_dataloader, *_ = colossalai.legacy.initialize( + model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader + ) + + logger = get_dist_logger() + logger.info("engine is built", ranks=[0]) + + timer = MultiTimer() + trainer = Trainer(engine=engine, logger=logger, timer=timer) + logger.info("trainer is built", ranks=[0]) + + logger.info("start training", ranks=[0]) + trainer.fit( + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + epochs=NUM_EPOCHS, + max_steps=3, + display_progress=True, + test_interval=5, + ) + torch.cuda.empty_cache() + + +def run_dist(rank, world_size, port): + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_trainer_no_pipeline(): + world_size = 4 + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_trainer_no_pipeline() diff --git a/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..0b34a79f96dd4fb328b154bafc6d2b97ded2bf0d --- /dev/null +++ b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py @@ -0,0 +1,97 @@ +import os +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from torch.optim import Adam +from torchvision import transforms +from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 + +import colossalai +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.trainer import Trainer +from colossalai.legacy.utils import get_dataloader +from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import MultiTimer + +BATCH_SIZE = 4 +IMG_SIZE = 32 +NUM_EPOCHS = 200 + +CONFIG = dict( + NUM_MICRO_BATCHES=2, + parallel=dict(pipeline=2), +) + + +def run_trainer_with_pipeline(rank, world_size, port): + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) + + # build model + model = resnet18(num_classes=10) + + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: + model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2) + elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: + + class Flatten(nn.Module): + def forward(self, x): + return torch.flatten(x, 1) + + model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) + + # build dataloaders + train_dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + transform=transforms.Compose( + [ + transforms.Resize(size=(IMG_SIZE, IMG_SIZE)), + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + ] + ), + ) + + train_dataloader = get_dataloader( + dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True + ) + + # build optimizer + optimizer = Adam(model.parameters(), lr=0.001) + criterion = nn.CrossEntropyLoss() + + engine, train_dataloader, *args = colossalai.legacy.initialize( + model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader + ) + + logger = get_dist_logger() + logger.info("engine is built", ranks=[0]) + timer = MultiTimer() + trainer = Trainer(engine=engine, logger=logger, timer=timer) + logger.info("trainer is built", ranks=[0]) + + logger.info("start training", ranks=[0]) + + trainer.fit( + train_dataloader=train_dataloader, epochs=NUM_EPOCHS, max_steps=3, display_progress=True, test_interval=5 + ) + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_trainer_with_pipeline(): + world_size = 4 + spawn(run_trainer_with_pipeline, world_size) + + +if __name__ == "__main__": + test_trainer_with_pipeline() diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_legacy/test_utils/test_activation_checkpointing.py similarity index 81% rename from tests/test_utils/test_activation_checkpointing.py rename to tests/test_legacy/test_utils/test_activation_checkpointing.py index 59a8acd4b21022fae583e6737e75f47fb27c651e..3303f610ee8221d08d0dd940bb3f53b1e5f49750 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_legacy/test_utils/test_activation_checkpointing.py @@ -1,14 +1,13 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import pytest import torch import torch.nn.functional as F -from colossalai.context.parallel_mode import ParallelMode -from colossalai.context.random import add_seed, reset_seeds, seed, set_mode +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.context.random import add_seed, reset_seeds, seed, set_mode +from colossalai.legacy.utils.activation_checkpoint import checkpoint from colossalai.testing import clear_cache_before_run, parameterize -from colossalai.utils.activation_checkpoint import checkpoint def forward(x, weight): @@ -40,25 +39,23 @@ def forward_inplace(x, weight): return out -@pytest.mark.gpu @clear_cache_before_run() @parameterize("use_reentrant", [True, False]) @parameterize("cpu_offload", [True, False]) def test_activation_checkpointing(cpu_offload, use_reentrant): - # as seed manager is singleton # if we don't reset seeds here, # other tests might affect this test reset_seeds() - # We put initilization here to avoid change cuda rng state below - inputs = torch.rand(2, 2, requires_grad=True, device='cuda') - weight = torch.rand(2, 4, requires_grad=True, device='cuda') + # We put initialization here to avoid change cuda rng state below + inputs = torch.rand(2, 2, requires_grad=True, device="cuda") + weight = torch.rand(2, 4, requires_grad=True, device="cuda") # Get a copy of input tensors - inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda') + inputs_ = torch.empty(2, 2, requires_grad=True, device="cuda") inputs_.data.copy_(inputs.data) - weight_ = torch.empty(2, 4, requires_grad=True, device='cuda') + weight_ = torch.empty(2, 4, requires_grad=True, device="cuda") weight_.data.copy_(weight.data) add_seed(ParallelMode.GLOBAL, 1024) @@ -84,7 +81,7 @@ def test_activation_checkpointing(cpu_offload, use_reentrant): loss = out.sum() loss.backward() - assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match" torch.cuda.empty_cache() # Extra test for use_reentrant=False @@ -111,7 +108,7 @@ def test_activation_checkpointing(cpu_offload, use_reentrant): loss = out.sum() loss.backward() - assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match" torch.cuda.empty_cache() # as seed manager is singleton diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py similarity index 78% rename from tests/test_utils/test_checkpoint/test_checkpoint_1d.py rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py index 335be61359ed2f802d6f1544e232dd9fd3c36dd7..c07ff132b79eb7d87ea68980d571e06e0036bb93 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -7,18 +7,18 @@ import pytest import torch import torch.nn as nn -import colossalai.nn as col_nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +import colossalai.legacy.nn as col_nn +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.legacy.utils import is_using_pp +from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn -from colossalai.utils import is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform + from colossalai.legacy.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_1d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) + config = dict( + parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py similarity index 78% rename from tests/test_utils/test_checkpoint/test_checkpoint_2d.py rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py index 175d9ef6ceb9170e6bf0687e0bb8e88ffdfcb5f8..2ec1facf21b107b134a7b53cebbc9f4f95a2d9ed 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -7,18 +7,18 @@ import pytest import torch import torch.nn as nn -import colossalai.nn as col_nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +import colossalai.legacy.nn as col_nn +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.legacy.utils import is_using_pp +from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn -from colossalai.utils import is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform + from colossalai.legacy.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_2d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) + config = dict( + parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py similarity index 78% rename from tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py index 33cb3a65d184f5ae14627010bc91de9d967c12df..a6bf702a84827024a61a336472ec2f1d29af7fb1 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -7,18 +7,18 @@ import pytest import torch import torch.nn as nn -import colossalai.nn as col_nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +import colossalai.legacy.nn as col_nn +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.legacy.utils import is_using_pp +from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn -from colossalai.utils import is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform + from colossalai.legacy.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_2p5d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) + config = dict( + parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py similarity index 78% rename from tests/test_utils/test_checkpoint/test_checkpoint_3d.py rename to tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py index 73ac2dd5fe1837ae9ee4f1a7b95e43fc6c51b482..12d9283129690910b984cb93b8a583ed51df9888 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -7,18 +7,18 @@ import pytest import torch import torch.nn as nn -import colossalai.nn as col_nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch +import colossalai.legacy.nn as col_nn +from colossalai.legacy.context.parallel_mode import ParallelMode +from colossalai.legacy.core import global_context as gpc +from colossalai.legacy.initialize import launch +from colossalai.legacy.utils import is_using_pp +from colossalai.legacy.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn -from colossalai.utils import is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform + from colossalai.legacy.pipeline.utils import partition_uniform pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_3d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) + config = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..9416ac86e3253359dd3d6f3cd5bbdc09e2a79df8 --- /dev/null +++ b/tests/test_legacy/test_utils/test_memory.py @@ -0,0 +1,28 @@ +import pytest + +import colossalai +from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction +from colossalai.testing import spawn +from colossalai.utils.cuda import get_current_device + + +def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): + frac1 = colo_device_memory_capacity(get_current_device()) + colo_set_process_memory_fraction(0.5) + frac2 = colo_device_memory_capacity(get_current_device()) + assert frac2 * 2 == frac1 + + +def run_dist(rank, world_size, port): + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [3, 4]) +def test_memory_utils(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_memory_utils(world_size=2) diff --git a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py new file mode 100644 index 0000000000000000000000000000000000000000..b5f2be705890df3c03d03ee790ca6bf8a1ed6a35 --- /dev/null +++ b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py @@ -0,0 +1,78 @@ +import pytest +import torch +from torch.nn.parameter import Parameter +from torch.nn.utils import clip_grad_norm_ + +import colossalai +from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup, distspec +from colossalai.legacy.utils.common import clip_grad_norm +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + + +def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): + return abs(num - other) <= atol + rtol * other + + +def shard_param(p: ColoParameter) -> None: + pg = p.get_process_group() + p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()])) + p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach() + + +def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None: + pg = colo_p.get_process_group() + if p.shape != colo_p.shape: + grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()] + else: + grad = p.grad + assert torch.allclose(grad, colo_p.grad), f"diff: {torch.abs(grad - colo_p.grad)}" + + +@parameterize("dtype", [torch.float]) +@parameterize("device", ["mixed", "cuda", "cpu"]) +@parameterize("norm_type", [2.0, 3.0, float("inf")]) +def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float): + print(f"{world_size}, {dtype}, {device}, {norm_type}") + cuda_device = get_current_device() + devices = [cuda_device] * 4 + if device == "cpu": + devices = [torch.device("cpu")] * 4 + elif device == "mixed": + devices = [cuda_device] * 2 + [torch.device("cpu")] * 2 + pg = ProcessGroup(tp_degree=world_size) + params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)] + colo_params = [ + ColoParameter(torch.empty(4, 4, dtype=dtype, device=devices[i]), spec=ColoTensorSpec(pg)) for i in range(4) + ] + for p, colo_p in zip(params, colo_params): + grad = torch.rand_like(p) + p.grad = grad + colo_p.grad = grad.clone().detach() + shard_param(colo_params[0]) + shard_param(colo_params[2]) + torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type) + colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type) + assert close(torch_norm, colo_norm), f"diff: {abs(torch_norm-colo_norm)}" + for p, colo_p in zip(params, colo_params): + check_grad_equal(p, colo_p) + + +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_grad_clip_norm(world_size=world_size) + + +@pytest.mark.skip("this need to be updated") +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_zero_clip_grad(world_size: int): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_zero_clip_grad(2) diff --git a/tests/test_legacy/test_zero/test_commons.py b/tests/test_legacy/test_zero/test_commons.py new file mode 100644 index 0000000000000000000000000000000000000000..741f519e13764b3bada596ff3c494ab8a970fc07 --- /dev/null +++ b/tests/test_legacy/test_zero/test_commons.py @@ -0,0 +1,41 @@ +import torch + +import colossalai +from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline +from colossalai.legacy.zero.sharded_param import ShardedTensor +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def run_tensor_move(rank, world_size, port): + colossalai.legacy.launch(config={}, rank=0, world_size=world_size, host="localhost", port=port, backend="nccl") + + src_t = torch.ones(2, 3).cuda() + tgt_t = torch.zeros(2, 3) + + colo_model_data_tensor_move(src_t, tgt_t) + assert torch.sum(tgt_t) == 6.0, f"{torch.sum(tgt_t.payload)} vs. 6.0" + + src_t = torch.ones(2, 3) + tgt_t = torch.zeros(2, 3).cuda().half() + colo_model_data_tensor_move(src_t, tgt_t) + # the src_t has been removed + assert src_t.numel() == 0 + assert torch.sum(tgt_t) == 6.0, f"{torch.sum(tgt_t.payload)} vs. 6.0" + + src_t = ShardedTensor(torch.ones(2, 3)) + tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half()) + colo_model_data_tensor_move(src_t, tgt_t) + assert torch.sum(tgt_t.payload) == 6.0, f"{torch.sum(tgt_t.payload)} vs. 6.0" + + assert tgt_t.device.type == "cuda" + colo_model_data_tensor_move_inline(tgt_t, torch.device("cpu")) + assert tgt_t.device.type == "cpu" + + +@rerun_if_address_is_in_use() +def test_tensor_move(): + spawn(run_tensor_move, 1) + + +if __name__ == "__main__": + test_tensor_move() diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index e7002a75f3f7b1d7dd5eff4fac0eba0d5de9f95d..8742e5f411366e68bc4a03846d5f84d5c138b465 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -5,7 +5,7 @@ import torch.nn as nn import colossalai from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.legacy.engine.gradient_handler import MoeGradientHandler from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @@ -17,11 +17,11 @@ CONFIG = dict() def run_test(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") expert_module = nn.Linear expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device()) - MOE_CONTEXT.setup(42) # MOE initialization + MOE_CONTEXT.setup(42) # MOE initialization noisy_func = UniformNoiseGenerator() router = Top1Router(noisy_func=noisy_func) num_experts_list = [1, 2, 4] @@ -67,5 +67,5 @@ def test_grad_handler(): spawn(run_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_grad_handler() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index ad9a172b72aa479dd85b50483b79c4b7b22fa023..7a9c551d679d56bb3fae6f04bf93f5877ec4990b 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -3,9 +3,9 @@ import torch import torch.nn as nn import colossalai -from colossalai.context import ParallelMode from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.core import global_context as gpc +from colossalai.legacy.context import ParallelMode +from colossalai.legacy.core import global_context as gpc from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @@ -23,12 +23,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # Here we do not need TF32, since it brings absolute error on results torch.backends.cuda.matmul.allow_tf32 = False - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") local_rank = gpc.get_local_rank(ParallelMode.GLOBAL) - MOE_CONTEXT.setup(42) # MOE environment initialization + MOE_CONTEXT.setup(42) # MOE environment initialization MOE_CONTEXT.reset_loss() - torch.manual_seed(rs + local_rank) # set each process has different random seed + torch.manual_seed(rs + local_rank) # set each process has different random seed # get randomized data tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) @@ -41,12 +41,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f if data_type == torch.float16: layer = layer.half() - # use matrix multiplication instead of COL_MOE_KERNL in MOE dispatch and combine + # use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine layer.use_kernel = False old_out, _ = layer(tokens) ech = old_out.shape grad = torch.randn(ech, device=get_current_device()) - old_out.backward(grad) # get gradient + old_out.backward(grad) # get gradient # save all results o_tk_grad = tokens.grad.data.clone() @@ -57,7 +57,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f layer.gate_weight.grad.zero_() layer.use_kernel = True - new_out, _ = layer(tokens) # get ouputs through colossal kernel + new_out, _ = layer(tokens) # get outputs through colossal kernel if data_type == torch.float32: check_equal(old_out, new_out) @@ -65,7 +65,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f check_equal(old_out, new_out, 1e-2) # forward function passed - new_out.backward(grad) # get new type gradient + new_out.backward(grad) # get new type gradient n_tk_grad = tokens.grad.data.clone() n_gt_grad = layer.gate_weight.grad.data.clone() @@ -92,5 +92,5 @@ def test_moe_kernel(rs, hidden_size, data_type, router): spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_kernel(2, 256, torch.float16, Top2Router) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 8a0283ba71fc9c6798c5800e4cc529acee47de99..b7024f32b1cfebbaa05cbb740c351102d14da7c0 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -17,11 +17,11 @@ from tests.test_zero.test_legacy.common import CONFIG def exam_moe_checkpoint(): with ColoInitContext(device=get_current_device()): model = MoeModel(checkpoint=True) - save_moe_model(model, 'temp_path.pth') + save_moe_model(model, "temp_path.pth") with ColoInitContext(device=get_current_device()): other_model = MoeModel(checkpoint=True) - load_moe_model(other_model, 'temp_path.pth') + load_moe_model(other_model, "temp_path.pth") state_0 = model.state_dict() state_1 = other_model.state_dict() @@ -30,11 +30,11 @@ def exam_moe_checkpoint(): assert torch.equal(u.data, v.data) if dist.get_rank() == 0: - os.remove('temp_path.pth') + os.remove("temp_path.pth") def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) exam_moe_checkpoint() @@ -46,5 +46,5 @@ def test_moe_checkpoint(world_size): spawn(_run_dist) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_checkpoint(world_size=4) diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py index 555338fcf9fcde69bec2c19efdd26d97977ada66..488573b733b1e773bd489613759cd38a4eb56f4e 100644 --- a/tests/test_moe/test_moe_colo_init.py +++ b/tests/test_moe/test_moe_colo_init.py @@ -9,17 +9,16 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_tensor.common_utils import debug_print from tests.test_zero.test_legacy.common import CONFIG -@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("init_device_type", ["cpu", "cuda"]) def exam_moe_colo_init(init_device_type): world_size = dist.get_world_size() - if init_device_type == 'cuda': + if init_device_type == "cuda": init_device = get_current_device() - elif init_device_type == 'cpu': + elif init_device_type == "cpu": init_device = torch.device("cpu") else: raise NotImplementedError("Unknown device found.") @@ -40,7 +39,7 @@ def exam_moe_colo_init(init_device_type): def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) exam_moe_colo_init() @@ -52,5 +51,5 @@ def test_moe_colo_init(world_size): spawn(_run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_colo_init(world_size=4) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 6dc3f5f18b6df04764800814c8d02e70205347af..300fb6c99b7b0cbfe57b32fc31c4a91f629f27f0 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -16,11 +16,11 @@ CONFIG = dict() def run_test(rank, world_size, port): world_size = 4 - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") expert_module = nn.Linear expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device()) - MOE_CONTEXT.setup(42) # MOE environment initialization + MOE_CONTEXT.setup(42) # MOE environment initialization exp0 = Experts(expert_module, 1, **expert_factor) exp1 = Experts(expert_module, 2, **expert_factor) exp2 = Experts(expert_module, 4, **expert_factor) @@ -64,5 +64,5 @@ def test_moe_initialization(): spawn(run_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_initialization() diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 79722f9f40560bfa78ee2d871da94fdab2a185a3..c48f9a3557ce36c90527a629818a26e3f84bf7b2 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -15,20 +15,15 @@ from tests.test_zero.test_legacy.common import CONFIG class MoeModel(nn.Module): - def __init__(self, checkpoint: bool = False): - class TestSubModule(CheckpointModule): - def __init__(self): super().__init__(checkpoint) expert_cls = nn.Linear expert_args_dict = dict(in_features=16, out_features=16) - self.moe = MoeModule(dim_model=16, - num_experts=8, - use_residual=True, - expert_cls=expert_cls, - **expert_args_dict) + self.moe = MoeModule( + dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict + ) self.proj = nn.Linear(16, 4) def _forward(self, x): @@ -50,49 +45,52 @@ class MoeModel(nn.Module): return x -@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("init_device_type", ["cpu", "cuda"]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) def run_moe_zero_init(init_device_type, shard_strategy_class): - logger = get_dist_logger("test_moe_zero_init") + get_dist_logger("test_moe_zero_init") - if init_device_type == 'cuda': + if init_device_type == "cuda": init_device = get_current_device() - elif init_device_type == 'cpu': + elif init_device_type == "cpu": init_device = torch.device("cpu") else: raise NotImplementedError("Unknown device found.") model_numel_tensor = torch.zeros(1, dtype=torch.int) - with ZeroInitContext(target_device=init_device, - shard_strategy=shard_strategy_class(), - shard_param=True, - model_numel_tensor=model_numel_tensor): + with ZeroInitContext( + target_device=init_device, + shard_strategy=shard_strategy_class(), + shard_param=True, + model_numel_tensor=model_numel_tensor, + ): model = MoeModel(checkpoint=True) for name, param in model.named_parameters(): - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") # the parameters in moe experts and its gate should not be sharded - if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): + if ("experts" in name) or ("gate" in name) or ("residual_combine" in name): assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) else: assert param.colo_attr.sharded_data_tensor.is_sharded # the parameters in moe experts is not replicated - if 'experts' in name: + if "experts" in name: assert not param.colo_attr.is_replicated else: assert param.colo_attr.is_replicated if param.colo_attr.param_is_sharded: - assert param.colo_attr.data_payload.device.type == init_device.type, \ - f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' + assert ( + param.colo_attr.data_payload.device.type == init_device.type + ), f"{param.colo_attr.data_payload.device.type} vs. {init_device.type}" else: - assert param.colo_attr.data_payload.device.type == 'cuda' + assert param.colo_attr.data_payload.device.type == "cuda" def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) run_moe_zero_init() @@ -104,5 +102,5 @@ def test_moe_zero_init(world_size): spawn(_run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_init(world_size=2) diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py index ec37967f18c5a0b2fc735e38de062c4a1967e4f9..724d70d77bc64068a4106368c09f3a3132745a07 100644 --- a/tests/test_moe/test_moe_zero_model.py +++ b/tests/test_moe/test_moe_zero_model.py @@ -3,7 +3,7 @@ import torch import colossalai from colossalai.context import MOE_CONTEXT -from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.legacy.engine.gradient_handler import MoeGradientHandler from colossalai.nn import MoeLoss from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.init_ctx import ZeroInitContext @@ -21,13 +21,13 @@ from tests.test_zero.test_legacy.common import CONFIG, check_grads_padding, run_ def run_model_test(enable_autocast, shard_strategy_class): shard_strategy = shard_strategy_class() - get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model") _, train_dataloader, _, optimizer_class, _ = get_components_func() criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), - shard_strategy=shard_strategy, - shard_param=True): + with ZeroInitContext( + target_device=torch.device("cuda", torch.cuda.current_device()), shard_strategy=shard_strategy, shard_param=True + ): zero_model = MoeModel(checkpoint=True) zero_model = ShardedModelV2(zero_model, shard_strategy) @@ -54,7 +54,7 @@ def run_model_test(enable_autocast, shard_strategy_class): def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) run_model_test() @@ -66,5 +66,5 @@ def test_moe_zero_model(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_model(world_size=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index efc6e9ddae27311ab5e16771b366b2baf87d2e00..bb9822daee05d97f7ecc56e481536d87ebbd4fdb 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -2,9 +2,9 @@ import pytest import torch import colossalai -from colossalai.amp import convert_to_apex_amp from colossalai.context import MOE_CONTEXT -from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.legacy.amp import convert_to_apex_amp +from colossalai.legacy.engine.gradient_handler import MoeGradientHandler from colossalai.nn import MoeLoss from colossalai.nn.optimizer import CPUAdam from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn @@ -43,31 +43,33 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler): @parameterize("cpu_offload", [True]) -@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug +@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug @parameterize("reuse_fp16_shard", [True, False]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def _run_test_sharded_optim_v2(cpu_offload, - shard_strategy_class, - use_cpuadam, - reuse_fp16_shard, - gpu_margin_mem_ratio=0.0): +def _run_test_sharded_optim_v2( + cpu_offload, shard_strategy_class, use_cpuadam, reuse_fp16_shard, gpu_margin_mem_ratio=0.0 +): shard_strategy = shard_strategy_class() if use_cpuadam and cpu_offload is False: return MOE_CONTEXT.reset_loss() - get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model") _, train_dataloader, _, optimizer_class, _ = get_components_func() criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): + with ZeroInitContext( + target_device=torch.device("cpu") if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True, + ): zero_model = MoeModel(checkpoint=True) - zero_model = ShardedModelV2(zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', - reuse_fp16_shard=reuse_fp16_shard) + zero_model = ShardedModelV2( + zero_model, + shard_strategy, + tensor_placement_policy="cpu" if cpu_offload else "cuda", + reuse_fp16_shard=reuse_fp16_shard, + ) # check whether parameters are identical in ddp for name, p in zero_model.named_parameters(): @@ -82,12 +84,11 @@ def _run_test_sharded_optim_v2(cpu_offload, optimizer_class = CPUAdam optim = optimizer_class(model.parameters(), lr=1e-3) sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, - sharded_optim, - initial_scale=2**5, - gpu_margin_mem_ratio=gpu_margin_mem_ratio) + sharded_optim = ShardedOptimizerV2( + zero_model, sharded_optim, initial_scale=2**5, gpu_margin_mem_ratio=gpu_margin_mem_ratio + ) - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False) apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) apex_grad_handler = MoeGradientHandler(model) @@ -103,7 +104,7 @@ def _run_test_sharded_optim_v2(cpu_offload, def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) _run_test_sharded_optim_v2() @@ -116,5 +117,5 @@ def test_moe_zero_optim(world_size): spawn(_run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_optim(world_size=4) diff --git a/tests/test_ops/test_addmm_tp.py b/tests/test_ops/test_addmm_tp.py deleted file mode 100644 index ecd3721b902e252de6c765a6f8b1c96215d984a9..0000000000000000000000000000000000000000 --- a/tests/test_ops/test_addmm_tp.py +++ /dev/null @@ -1,73 +0,0 @@ -import pytest -import torch -import torch.nn as nn - -import colossalai -from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal - - -class Conv1D(nn.Module): - """ - 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). - Basically works like a linear layer but the weights are transposed. - Args: - nf (`int`): The number of output features. - nx (`int`): The number of input features. - """ - - def __init__(self, nf, nx): - super().__init__() - self.nf = nf - w = torch.empty(nx, nf) - nn.init.normal_(w, std=0.02) - self.weight = nn.Parameter(w) - self.bias = nn.Parameter(torch.ones(nf)) - - def forward(self, x): - size_out = x.size()[:-1] + (self.nf,) - x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) - x = x.view(size_out) - return x - - -def run_with_spec(spec_init_func, split_bias): - model = Conv1D(4, 16).cuda() - world_size = torch.distributed.get_world_size() - pg = ProcessGroup(tp_degree=world_size) - - weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg)) - bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg)) - - spec_init_func(weight, pg) - if split_bias: - spec_init_func(bias, pg) - - x = torch.rand(2, 16).cuda() - out = model(x) - colo_out = torch.addmm(bias, x, weight) - colo_out = colo_out.to_replicate() - assert tensor_equal(out, colo_out) - grad = torch.rand_like(out) - out.backward(grad) - colo_out.backward(grad) - tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) - tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size()) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=False) - run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=True) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_addmm_1d(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_addmm_1d(4) diff --git a/tests/test_ops/test_embedding_bag_tp.py b/tests/test_ops/test_embedding_bag_tp.py deleted file mode 100644 index d3d3dcf7e2c9a2737c26e97ad4e7ae487c201917..0000000000000000000000000000000000000000 --- a/tests/test_ops/test_embedding_bag_tp.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest -import torch -from torch.nn import functional as F - -import colossalai -from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal - - -def run_with_spec(spec_init_func): - pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) - model = torch.nn.EmbeddingBag(10, 4).cuda() - weight = ColoParameter(model.weight.clone(), True, ColoTensorSpec(pg)) - - spec_init_func(weight, pg) - - inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda() - offsets = torch.tensor([0, 4]).cuda() - out = model(inputs, offsets=offsets) - colo_out = F.embedding_bag(inputs, weight, offsets=offsets) - assert tensor_equal(out, colo_out) - grad = torch.rand_like(out) - out.backward(grad) - colo_out.backward(grad) - assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) - - -def run_dist(rank, world_size, port): - config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_with_spec(split_param_col_tp1d) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_embedding_bag_1d(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_embedding_bag_1d(4) diff --git a/tests/test_ops/test_embedding_tp.py b/tests/test_ops/test_embedding_tp.py deleted file mode 100644 index c0b376e2c92a298bf2e8c42257ba8b39a14e9a35..0000000000000000000000000000000000000000 --- a/tests/test_ops/test_embedding_tp.py +++ /dev/null @@ -1,44 +0,0 @@ -import pytest -import torch -from torch.nn import functional as F - -import colossalai -from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal - - -def run_with_spec(spec_init_func, pg: ProcessGroup): - model = torch.nn.Embedding(12, 32).cuda() - weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg)) - - spec_init_func(weight, pg) - - x = torch.tensor((0, 3, 6, 9)).cuda() - out = model(x) - colo_out = F.embedding(x, weight) - assert tensor_equal(out, colo_out) - grad = torch.rand_like(out) - out.backward(grad) - colo_out.backward(grad) - # compare grad inside a TP group - assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) - - -def run_dist(rank, world_size, port): - # config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(tp_degree=world_size) - run_with_spec(split_param_row_tp1d, pg) - run_with_spec(split_param_col_tp1d, pg) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_embedding_1d(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_embedding_1d(4) diff --git a/tests/test_ops/test_linear_tp.py b/tests/test_ops/test_linear_tp.py deleted file mode 100644 index c88adfdd9a7757cd7eabc8f27d40adae6374c61c..0000000000000000000000000000000000000000 --- a/tests/test_ops/test_linear_tp.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F - -import colossalai -from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal - - -def run_with_spec(spec_init_func, split_bias): - pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) - model = torch.nn.Linear(4, 8).cuda() - weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg)) - bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg)) - - spec_init_func(weight, pg) - if split_bias: - spec_init_func(bias, pg) - - x = torch.rand(2, 4).cuda() - out = model(x) - colo_out = F.linear(x, weight, bias) - colo_out = colo_out.to_replicate() - assert tensor_equal(out, colo_out) - grad = torch.rand_like(out) - out.backward(grad) - colo_out.backward(grad) - assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size()) - assert tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size()) - - -def run_dist(rank, world_size, port): - config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=False) - run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=True) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_linear_1d(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_linear_1d(4) diff --git a/tests/test_ops/test_loss_func.py b/tests/test_ops/test_loss_func.py deleted file mode 100644 index fc55c7f7725412ec75411e60b51791c654848ca3..0000000000000000000000000000000000000000 --- a/tests/test_ops/test_loss_func.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F - -import colossalai -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device - - -def check_cross_entropy(): - input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True) - input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True) - with torch.no_grad(): - input_ct.copy_(input_t) - - target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device()) - - world_size = torch.distributed.get_world_size() - pg = ProcessGroup(tp_degree=world_size) - input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg)) - input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()])) - input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) - - output = F.cross_entropy(input_t, target) - output_colo = F.cross_entropy(input_shard, target) - assert torch.allclose(output_colo, output) - - output.backward() - output_colo.backward() - - assert torch.allclose(input_t.grad, input_ct.grad) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_cross_entropy() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@rerun_if_address_is_in_use() -def test_loss_func(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_loss_func(1) diff --git a/tests/test_ops/test_op.py b/tests/test_ops/test_op.py deleted file mode 100644 index 4176d3b64d90e4c560ba48a6de42a9c9e26b73d1..0000000000000000000000000000000000000000 --- a/tests/test_ops/test_op.py +++ /dev/null @@ -1,87 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F -from torch.nn import Parameter - -import colossalai -from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device - - -def _run_layer_norm(): - ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device()) - - input_t = torch.randn(3, 2, device=get_current_device()) - - pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) - input_t_colo = ColoTensor.from_torch_tensor(input_t.clone().detach(), ColoTensorSpec(pg)) - - # prepare colossalai LN - weight = ColoTensor(Parameter(ln_op.weight.detach()), ColoTensorSpec(pg)) - bias = ColoTensor(Parameter(ln_op.bias.detach()), ColoTensorSpec(pg)) - - output = ln_op(input_t) - output_colo = F.layer_norm(input_t_colo, ln_op.normalized_shape, weight, bias, ln_op.eps) - - assert torch.allclose(output_colo, output) - - torch.mean(output).backward() - torch.mean(output_colo).backward() - - assert torch.allclose(ln_op.weight.grad, weight.grad) - - -def check_spec_eq(tensor, other): - assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor) - for k in dir(tensor.dist_spec): - if not k.startswith('__'): - assert hasattr(other.dist_spec, k), f"{k}" - assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k) - - -def check_element_wise_ops(): - world_size = torch.distributed.get_world_size() - pg = ProcessGroup(tp_degree=world_size) - t = torch.rand(2, 2) - x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()]))) - - check_spec_eq(x, x.cuda()) - assert torch.equal(x.cuda(), t.cuda()) - check_spec_eq(x, torch.abs(x)) - assert torch.equal(torch.abs(x), torch.abs(t)) - check_spec_eq(x, F.sigmoid(x)) - assert torch.equal(F.sigmoid(x), F.sigmoid(t)) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_element_wise_ops() - _run_layer_norm() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@rerun_if_address_is_in_use() -def test_element_wise_ops(world_size): - spawn(run_dist, world_size) - - -def run_dist2(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_layer_norm() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1]) -@rerun_if_address_is_in_use() -def test_ln(world_size): - spawn(run_dist2, world_size) - - -def check_all(): - test_element_wise_ops(2) - - -if __name__ == '__main__': - check_all() diff --git a/tests/test_ops/test_view.py b/tests/test_ops/test_view.py deleted file mode 100644 index a9f2033201c7ac6a1797b92dd6e834d97a90b5d1..0000000000000000000000000000000000000000 --- a/tests/test_ops/test_view.py +++ /dev/null @@ -1,97 +0,0 @@ -import pytest -import torch -import torch.distributed as dist - -import colossalai -from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec -from colossalai.tensor.distspec import DistPlacementPattern -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d - - -def exam_view_core(pg): - # the case of replicated ColoTensors - x = torch.randn(4, 4).cuda() - x_colo = ColoTensor(x, ColoTensorSpec(pg)) - - y = x.view(2, -1, 2) - y_colo = x_colo.view(2, -1, 2) - - assert torch.all(y == y_colo) - assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE - # the perfect case of col-sliced ColoTensors - split_param_col_tp1d(x_colo, pg) - - z = x.view(torch.Size((2, 1, 2, -1))) - z_colo = x_colo.view(torch.Size((2, 1, 2, -1))) - if dist.get_rank() == 0: - z = z[:, :, :, 0:2] - else: - z = z[:, :, :, 2:] - assert torch.all(z == z_colo) - assert z_colo.dist_spec == x_colo.dist_spec - # the perfect case of row-sliced ColoTensors - split_param_row_tp1d(x_colo, pg) - - z = x.view(torch.Size((-1, 2, 2))) - z_colo = x_colo.view(torch.Size((-1, 2, 2))) - if dist.get_rank() == 0: - z = z[0:2, :, :] - else: - z = z[2:, :, :] - assert torch.all(z == z_colo) - assert z_colo.dist_spec == x_colo.dist_spec - # the normal case of row-sliced ColoTensors - z = x.view(-1, 2, 2, 2) - z_colo = x_colo.view(-1, 2, 2, 2) - assert torch.all(z == z_colo) - assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE - - -def exam_view_autograd(pg): - x = torch.randn(8, 2, device=get_current_device(), requires_grad=True) - y = torch.randn(8, 2, device=get_current_device(), requires_grad=True) - with torch.no_grad(): - y.copy_(x) - y = ColoTensor(y, ColoTensorSpec(pg)) - y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()])) - - xx = x.view(2, 2, -1) - yy_slice = y_slice.view(2, 2, -1) - yy = yy_slice.to_replicate() - grad = torch.randn(2, 2, 4, device=get_current_device()) - - xx.backward(grad) - yy.backward(grad) - assert torch.all(x.grad == y.grad) - - -def exam_view_errors(pg): - x = torch.randn(8, 2, device=get_current_device()) - x = ColoTensor(x, ColoTensorSpec(pg)) - split_param_row_tp1d(x, pg) - - x.view('a', 'b', 'c') - x.view(8, -1) - x.view([-2, -2, -2]) - x.view((-1, -1, -1)) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) - exam_view_core(pg) - exam_view_autograd(pg) - # exam_view_errors(pg) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@rerun_if_address_is_in_use() -def test_view(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_view(2) diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..8131ea3234d8da1da8315e88e4b36e53fa69192e --- /dev/null +++ b/tests/test_optimizer/test_adam_kernel.py @@ -0,0 +1,171 @@ +# This test checks adam kernels +# Baseline is pure fp32 torch adam optimizer +import math +from abc import abstractmethod +from typing import Type + +import pytest +import torch +from torch import Tensor + +from colossalai.utils import get_current_device, multi_tensor_applier + +_FUSED_ALLOWED_P_G_TYPES = [ + (torch.float, torch.half), + (torch.float, torch.float), + (torch.half, torch.float), + (torch.half, torch.half), + (torch.bfloat16, torch.float), + (torch.float, torch.bfloat16), + (torch.bfloat16, torch.bfloat16), +] + +_CPU_ALLOWED_P_G_TYPES = [ + (torch.float, torch.half), + (torch.float, torch.float), + (torch.half, torch.float), + (torch.half, torch.half), +] + + +class AdamKernel: + def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: + self.lr = lr + self.beta1 = beta1 + self.beta2 = beta2 + self.eps = eps + self.weight_decay = weight_decay + self.use_adamw = use_adamw + + @abstractmethod + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + pass + + +class TorchAdamKernel(AdamKernel): + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + bias_correction1 = 1 - self.beta1**step + bias_correction2 = 1 - self.beta2**step + + if self.weight_decay != 0: + if self.use_adamw: + # Perform stepweight decay + param.mul_(1 - self.lr * self.weight_decay) + else: + grad = grad.add(param, alpha=self.weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1) + exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps) + + step_size = self.lr / bias_correction1 + + param.addcdiv_(exp_avg, denom, value=-step_size) + + +class FusedAdamKernel(AdamKernel): + def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: + super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) + from colossalai.kernel.op_builder import FusedOptimBuilder + + fused_optim = FusedOptimBuilder().load() + self.fused_adam = fused_optim.multi_tensor_adam + self.dummy_overflow_buf = torch.cuda.IntTensor([0]) + + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + multi_tensor_applier( + self.fused_adam, + self.dummy_overflow_buf, + [[grad], [param], [exp_avg], [exp_avg_sq]], + self.lr, + self.beta1, + self.beta2, + self.eps, + step, + self.use_adamw, + True, + self.weight_decay, + -1, + ) + + +class CPUAdamKernel(AdamKernel): + def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: + super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) + from colossalai.kernel.op_builder import CPUAdamBuilder + + cpu_optim = CPUAdamBuilder().load() + + self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw) + + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + self.cpu_adam_op.step( + step, + self.lr, + self.beta1, + self.beta2, + self.eps, + self.weight_decay, + True, + param.view(-1), + grad.view(-1), + exp_avg.view(-1), + exp_avg_sq.view(-1), + -1, + ) + + +def check_adam_kernel( + kernel: Type[AdamKernel], + adamw: bool, + weight_decay: float, + p_dtype: torch.dtype, + g_dtype: torch.dtype, + device: torch.device, + n_steps: int, + rtol: float, + atol: float, +): + lr = 1e-3 + beta1, beta2 = 0.9, 0.999 + eps = 1e-8 + torch_adam = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw) + adam_kernel = kernel(lr, beta1, beta2, eps, weight_decay, adamw) + master_p = torch.rand(64, device=device) + master_g = torch.rand_like(master_p) + master_exp_avg = torch.zeros_like(master_p) + master_exp_avg_sq = torch.zeros_like(master_p) + p = master_p.clone().to(p_dtype) + g = master_g.clone().to(g_dtype) + exp_avg = master_exp_avg.clone() + exp_avg_sq = master_exp_avg_sq.clone() + + for step in range(1, 1 + n_steps): + torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq) + adam_kernel.update(step, p, g, exp_avg, exp_avg_sq) + # if overflow, the weight won't be updated. so there will be no nan in p + assert not torch.isnan(p).any() + assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("adamw", [False, True]) +@pytest.mark.parametrize("weight_decay", [0.0, 0.1]) +@pytest.mark.parametrize("p_dtype, g_dtype", _FUSED_ALLOWED_P_G_TYPES) +def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): + rtol, atol = 1e-5, 1e-8 + if p_dtype is torch.float16 or g_dtype is torch.float16: + rtol, atol = 1e-3, 1e-3 + if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16: + rtol, atol = 4e-3, 4e-3 + check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol) + + +@pytest.mark.parametrize("adamw", [False, True]) +@pytest.mark.parametrize("weight_decay", [0.0, 0.1]) +@pytest.mark.parametrize("p_dtype, g_dtype", _CPU_ALLOWED_P_G_TYPES) +def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): + rtol, atol = 1e-5, 1e-8 + if p_dtype is torch.float16 or g_dtype is torch.float16: + rtol, atol = 1e-3, 1e-3 + check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device("cpu"), 3, rtol, atol) diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..59b40a0afa3c9a9af72fc9e50e3a3ef2272f4f05 --- /dev/null +++ b/tests/test_optimizer/test_adam_optim.py @@ -0,0 +1,91 @@ +from copy import deepcopy +from typing import Type, Union + +import pytest +import torch +import torch.nn as nn +from torch.optim import Adam, AdamW + +from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam +from tests.kit.model_zoo import model_zoo + +_ALLOWED_OPTIM_DEVICES = [ + (FusedAdam, torch.device("cuda:0")), + (CPUAdam, torch.device("cpu")), + (CPUAdam, torch.device("cuda:0")), + (HybridAdam, torch.device("cpu")), + (HybridAdam, torch.device("cuda:0")), +] + +_ALLOWED_P_G_TYPES = [ + (torch.float, torch.float), # pure fp32 + (torch.float, torch.half), # fp16 amp + (torch.float, torch.bfloat16), # bfloat16 amp + # (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16 + # (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16 +] + +N_STEPS = 3 + + +def setup_param_groups(bert_model: nn.Module) -> list: + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": 0.1, + }, + { + "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None: + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + torch_p.grad = torch.rand_like(torch_p) + # avoid inconsistent grad and param dtype error + orig_p = p.data + p.data = torch_p.grad.clone().to(g_dtype) + p.grad = p.data + p.data = orig_p + + +@pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES) +@pytest.mark.parametrize("adamw", [False, True]) +@pytest.mark.parametrize("p_dtype, g_dtype", _ALLOWED_P_G_TYPES) +def test_adam_optim_on_bert( + optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], + device: torch.device, + adamw: bool, + p_dtype: torch.dtype, + g_dtype: torch.dtype, +) -> None: + model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_bert_for_sequence_classification").values())) + torch_model = model_fn().to(device) + model = deepcopy(torch_model).to(p_dtype) + lr = 1e-3 + beta1, beta2 = 0.9, 0.999 + eps = 1e-8 + torch_optim_cls = AdamW if adamw else Adam + torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps) + optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw) + + rtol, atol = 1e-5, 1e-5 + if p_dtype is torch.float16 or g_dtype is torch.float16: + rtol, atol = 2e-3, 2e-3 + if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16: + rtol, atol = 4e-3, 4e-3 + + for _ in range(N_STEPS): + set_grad(model, torch_model, g_dtype) + torch_optim.step() + optim.step() + torch_optim.zero_grad() + optim.zero_grad() + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + # if overflow, the weight won't be updated. so there will be no nan in p + assert not torch.isnan(p).any() + assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol) diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py deleted file mode 100644 index 8b3ecf8517f7cd13cc67555bb81cf388efdd0654..0000000000000000000000000000000000000000 --- a/tests/test_optimizer/test_cpu_adam.py +++ /dev/null @@ -1,121 +0,0 @@ -import math - -import torch - -from colossalai.testing import clear_cache_before_run, parameterize - - -def torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - param, - grad, - exp_avg, - exp_avg_sq, - use_adamw, -): - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - if weight_decay != 0: - if use_adamw: - # Perform stepweight decay - param.mul_(1 - lr * weight_decay) - else: - grad = grad.add(param, alpha=weight_decay) - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) - - step_size = lr / bias_correction1 - - param.addcdiv_(exp_avg, denom, value=-step_size) - - -def assertLess(data_diff, threshold, msg): - assert data_diff < threshold, msg - - -def assertTrue(condition, msg): - assert condition, msg - - -@clear_cache_before_run() -@parameterize('adamw', [True, False]) -@parameterize('step', [1, 2]) -@parameterize('p_dtype', [torch.float, torch.half]) -@parameterize('g_dtype', [torch.float, torch.half]) -def test_cpu_adam(adamw, step, p_dtype, g_dtype): - lr = 1e-3 - beta1, beta2 = 0.9, 0.999 - eps = 1e-8 - weight_decay = 0 - - for i in range(3): - p_data = torch.rand(64, dtype=p_dtype) - p_data_copy = p_data.clone().float() - p_grad = torch.rand(64, dtype=g_dtype) - p_grad_copy = p_grad.clone().float() - exp_avg = torch.rand(p_data.shape) - exp_avg_copy = exp_avg.clone() - exp_avg_sq = torch.rand(p_data.shape) - exp_avg_sq_copy = exp_avg_sq.clone() - - from colossalai.kernel.op_builder import CPUAdamBuilder - cpu_optim = CPUAdamBuilder().load() - - cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw) - - cpu_adam_op.step( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - True, - p_data.view(-1), # fp32 data - p_grad.view(-1), # fp32 grad - exp_avg.view(-1), - exp_avg_sq.view(-1), - -1, - ) - - torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - p_data_copy, # fp32 data - p_grad_copy, # fp32 grad - exp_avg_copy, - exp_avg_sq_copy, - adamw, - ) - var = p_data_copy - p_data - data_diff = torch.max(torch.abs(var)) - threshold = 1e-3 - assertLess( - data_diff, - threshold, - f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps " - f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}", - ) - max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad)) - assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}") - max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg)) - assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}") - max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq)) - assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}") - - -if __name__ == '__main__': - test_cpu_adam() diff --git a/tests/test_optimizer/test_fused_adam.py b/tests/test_optimizer/test_fused_adam.py deleted file mode 100644 index 114d5293dad96063d9cdd7103be01a73fe60524a..0000000000000000000000000000000000000000 --- a/tests/test_optimizer/test_fused_adam.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -import torch.nn as nn -from torch.optim import AdamW -from torch.optim.adam import Adam - -from colossalai.nn.optimizer.fused_adam import FusedAdam -from colossalai.testing import clear_cache_before_run, parameterize - - -class FC(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Sequential(nn.Linear(64, 64)) - - def forward(self, x): - return self.fc(x) - - -@clear_cache_before_run() -@parameterize('adamw', [False, True]) -@parameterize('p_dtype', [torch.float, torch.half]) -@parameterize('g_dtype', [torch.float, torch.half]) -def test_adam(adamw, p_dtype, g_dtype): - model = FC().cuda().to(p_dtype) - state = model.state_dict() - model_copy = FC().cuda().to(p_dtype) - model_copy.load_state_dict(state.copy()) - - if adamw: - optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True) - torch_optim = AdamW(model_copy.parameters(), lr=1e-3) - else: - optim = FusedAdam(model.parameters(), lr=1e-3) - torch_optim = Adam(model_copy.parameters(), lr=1e-3) - - data = torch.rand(1024, 64).cuda().to(p_dtype) - data_copy = data.clone() - label = torch.rand(1024, 64).cuda().to(p_dtype) - - for d, l in zip(data, label): - y = model(d) - loss = ((l - y)**2).sum() - optim.zero_grad() - loss.backward() - if p_dtype != g_dtype: - for i in range(len(optim.param_groups[0]['params'])): - optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype) - optim.step() - - for d, l in zip(data_copy, label): - y = model_copy(d) - loss = ((l - y)**2).sum() - torch_optim.zero_grad() - loss.backward() - torch_optim.step() - - assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params']) - - for i in range(len(optim.param_groups[0]['params'])): - if torch.isnan(optim.param_groups[0]['params'][i]).any() \ - or torch.isnan(torch_optim.param_groups[0]['params'][i]).any(): - continue - assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3) diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py deleted file mode 100644 index 4afa13349c1be4a1424855886d95e17a14047369..0000000000000000000000000000000000000000 --- a/tests/test_optimizer/test_fused_adam_kernel.py +++ /dev/null @@ -1,95 +0,0 @@ -import math - -import torch -import torch.nn as nn -from numpy import dtype - -from colossalai.testing import clear_cache_before_run, parameterize -from colossalai.utils import multi_tensor_applier - - -def torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - param, - grad, - exp_avg, - exp_avg_sq, - use_adamw, -): - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - if weight_decay != 0: - if use_adamw: - # Perform stepweight decay - param.mul_(1 - lr * weight_decay) - else: - grad = grad.add(param, alpha=weight_decay) - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) - - step_size = lr / bias_correction1 - - param.addcdiv_(exp_avg, denom, value=-step_size) - - -@clear_cache_before_run() -@parameterize('adamw', [False, True]) -@parameterize('step', [1, 2]) -@parameterize('p_dtype', [torch.float, torch.half]) -@parameterize('g_dtype', [torch.float, torch.half]) -def test_adam(adamw, step, p_dtype, g_dtype): - from colossalai.kernel.op_builder import FusedOptimBuilder - fused_optim = FusedOptimBuilder().load() - fused_adam = fused_optim.multi_tensor_adam - - dummy_overflow_buf = torch.cuda.IntTensor([0]) - - count = 0 - - for i in range(3): - p = torch.rand(64, dtype=p_dtype).cuda() - p_copy = p.clone().float() - g = torch.rand(p.shape, dtype=g_dtype).cuda() - g_copy = g.clone().float() - m = torch.rand(p.shape).cuda() - m_copy = m.clone() - v = torch.rand(p.shape).cuda() - v_copy = v.clone() - - lr = 1e-3 - beta1, beta2 = 0.9, 0.999 - eps = 1e-8 - weight_decay = 0 - - multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw, - True, weight_decay, -1) - - torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - p_copy, # fp32 data - g_copy, # fp32 grad - m_copy, - v_copy, - adamw, - ) - - if torch.isnan(p).any() or torch.isnan(p_copy).any(): - count += 1 - continue - assert count < 200, "too many nans" - assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, - 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}" diff --git a/tests/test_optimizer/test_hybrid_adam.py b/tests/test_optimizer/test_hybrid_adam.py deleted file mode 100644 index d075149dfcb1e6a7c3d6357f1987a68750004810..0000000000000000000000000000000000000000 --- a/tests/test_optimizer/test_hybrid_adam.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -import torch.nn as nn -from torch.optim import AdamW -from torch.optim.adam import Adam - -from colossalai.nn.optimizer.hybrid_adam import HybridAdam -from colossalai.testing import clear_cache_before_run, parameterize - -RE = 3 - - -@clear_cache_before_run() -@parameterize('adamw', [False, True]) -@parameterize('device', ['cpu', 'cuda:0']) -@parameterize('p_dtype', [torch.float]) -@parameterize('g_dtype', [torch.float, torch.half]) -def test_adam(adamw, device, p_dtype, g_dtype): - rng_state = torch.get_rng_state() - p = nn.Parameter(torch.rand(64).to(device, p_dtype)) - torch.set_rng_state(rng_state) - p_copy = nn.Parameter(torch.rand(64).to(device).float()) - - if adamw: - optim = HybridAdam([p], lr=1e-3, adamw_mode=True) - torch_optim = AdamW([p_copy], lr=1e-3) - else: - optim = HybridAdam([p], lr=1e-3) - torch_optim = Adam([p_copy], lr=1e-3) - - print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}") - for i in range(RE): - p.grad = torch.rand(64).to(device, p_dtype) - p_copy.grad = p.grad.clone().float() - p.grad.data = p.grad.data.to(g_dtype) - - optim.step() - torch_optim.step() - - if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any(): - continue - assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \ - f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}" diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py index 5d794ac2dd1a92123fd8fde23c40fcc529570aa7..a68a9c51855f7de4f49f3901a319d85972b94d74 100644 --- a/tests/test_optimizer/test_nvme.py +++ b/tests/test_optimizer/test_nvme.py @@ -1,4 +1,3 @@ -import pytest import torch from colossalai.nn.optimizer import CPUAdam, HybridAdam @@ -15,23 +14,22 @@ def move_some_params_to_cuda(model, torch_model): def check_params_equal(model, torch_model): for p, torch_p in zip(model.parameters(), torch_model.parameters()): - assert torch.allclose(p, torch_p, atol=1e-3), f'diff: {torch.abs(p - torch_p)}' + assert torch.allclose(p, torch_p, atol=1e-3), f"diff: {torch.abs(p - torch_p)}" @clear_cache_before_run() -@parameterize('nvme_offload_fraction', [0.0, 0.5, 1.0]) -@parameterize('nvme_offload_dir', ['./offload', None]) -@parameterize('adam_cls', [CPUAdam, HybridAdam]) +@parameterize("nvme_offload_fraction", [0.0, 0.5, 1.0]) +@parameterize("nvme_offload_dir", ["./offload", None]) +@parameterize("adam_cls", [CPUAdam, HybridAdam]) def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): - get_components_func = non_distributed_component_funcs.get_callable('simple_net') + get_components_func = non_distributed_component_funcs.get_callable("simple_net") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model = model_builder() torch_model = model_builder() move_some_params_to_cuda(model, torch_model) - optimizer = adam_cls(model.parameters(), - lr=0.1, - nvme_offload_fraction=nvme_offload_fraction, - nvme_offload_dir=nvme_offload_dir) + optimizer = adam_cls( + model.parameters(), lr=0.1, nvme_offload_fraction=nvme_offload_fraction, nvme_offload_dir=nvme_offload_dir + ) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.1) with torch.no_grad(): for p, torch_p in zip(model.parameters(), torch_model.parameters()): @@ -45,5 +43,5 @@ def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): check_params_equal(model, torch_model) -if __name__ == '__main__': - test_nvme_adam(0.5, './offload', CPUAdam) +if __name__ == "__main__": + test_nvme_adam(0.5, "./offload", CPUAdam) diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py deleted file mode 100644 index dab474a4ee21156d4ebdcd6e19570878776b1efe..0000000000000000000000000000000000000000 --- a/tests/test_pipeline/rpc_test_utils.py +++ /dev/null @@ -1,150 +0,0 @@ -import argparse -import os -import warnings - -import torch -import torch.distributed as dist -import torch.distributed.rpc as rpc -import torch.multiprocessing as mp -from torch import nn -from torch._C._distributed_rpc import _is_current_rpc_agent_set -from torch.optim import SGD, Adam, Optimizer, RMSprop - -from colossalai import launch -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.pipeline_process_group import ppg - -rpc_is_initialized = _is_current_rpc_agent_set - - -def color_debug(text, prefix=' ', color='blue'): - color = color.upper() - print(getattr(Back, color), prefix, Style.RESET_ALL, text) - - -class MLP(nn.Module): - - def __init__(self, dim: int, layers: int): - super().__init__() - self.layers = torch.nn.ModuleList() - - for _ in range(layers): - self.layers.append(nn.Linear(dim, dim, bias=False)) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x.sum() - - -class DAG_MLP(nn.Module): - - def __init__(self, dim: int, layers: int): - super().__init__() - self.layers = torch.nn.ModuleList() - self.dag_layer = nn.Linear(dim, dim, bias=False) - - for _ in range(layers): - self.layers.append(nn.Linear(dim, dim, bias=False)) - - def forward(self, x, y): - for layer in self.layers: - x = layer(x) - y = self.dag_layer(y) - return x.sum(), y.sum() - - -class RpcTestModel(nn.Module): - - def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: - super().__init__() - self.rank = stage_id - self.is_last_rank = stage_id == actual_stage_num - 1 - self.linear_name = f'linear_{stage_id}' - - if stage_id == 0: - linear = nn.Linear(feat_num, h) - elif stage_id == actual_stage_num - 1: - linear = nn.Linear(h, 1) - else: - linear = nn.Linear(h, h) - - setattr(self, self.linear_name, linear) - - def forward(self, x) -> torch.Tensor: - linear: nn.Module = getattr(self, self.linear_name) - out: torch.Tensor = linear(x) - - if self.is_last_rank: - out = out.sum() - return out - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--epoch', type=int, default=1) - parser.add_argument('--world_size', type=int, default=2) - parser.add_argument('--batch_size', type=int, default=16) - parser.add_argument('--dp_degree', type=int, default=1) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--num_microbatches', type=int, default=2) - parser.add_argument('--chunk', type=int, default=1) - parser.add_argument('--use_checkpoint', action='store_true') - parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29020') - parser.add_argument('--num_worker_threads', type=str, default=128) - return parser.parse_args() - - -def pg_parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--world_size', type=int, default=4) - parser.add_argument('--dp_degree', type=int, default=2) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--chunk', type=int, default=1) - parser.add_argument('--num_worker_threads', type=str, default=128) - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29020') - return parser.parse_args() - - -def run_worker(rank, args, master_func): - os.environ['MASTER_ADDR'] = args.master_addr - os.environ['MASTER_PORT'] = args.master_port - - device = args.device - world_size = args.world_size - dp_degree = args.dp_degree - tp_degree = args.tp_degree - num_worker_threads = args.num_worker_threads - host = args.master_addr - port = args.master_port - backend = 'nccl' if device == 'cuda' else 'gloo' - - disable_existing_loggers() - - launch(dict(), rank, world_size, host, int(port), backend, verbose=False) - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=dp_degree, - tp_degree=tp_degree, - num_worker_threads=num_worker_threads, - device=device) - - # in rpc mode, only rank 0 is needed to be coded - if rank == 0: - master_func(args) - # barrier here - if rpc_is_initialized(): - rpc.shutdown() - else: - warnings.warn("RPC has not been initialized") - - -def rpc_run(args, master_func): - world_size = args.world_size - assert args.num_microbatches >= args.world_size, "num_microbatches cannot be fewer than world_size!" - mp.spawn(run_worker, args=(args, master_func), nprocs=world_size) diff --git a/tests/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_pipeline/test_cuda_rpc_optimizer.py deleted file mode 100644 index 842566730caf2454bdeb6e27b864e58beb3ba7a1..0000000000000000000000000000000000000000 --- a/tests/test_pipeline/test_cuda_rpc_optimizer.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -from torch import nn -from torch import autograd -from torch.optim import SGD, Adam, RMSprop, Optimizer - -from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine -from colossalai.testing import assert_close -from rpc_test_utils import rpc_run, parse_args, RpcTestModel - -# global variable for model created -feat_num = 100 -h = 100 - - -def partition(pp_rank: int, chunk: int, stage_num: int): - torch.manual_seed(1024) - partition = RpcTestModel(pp_rank, stage_num, feat_num, h) - return partition - - -def run_master(args): - torch.manual_seed(100) - - device = args.device - stage_num = args.world_size - chunk = args.chunk - actual_stage_num = stage_num * chunk - use_checkpoint = args.use_checkpoint - num_microbatches = args.num_microbatches - optimizer_class = globals()[args.optimizer] - - lr = 1e-3 - sample_num = 1024 - batch_size = 1024 - - assert sample_num % batch_size == 0 - - input_sample = torch.randn((sample_num, feat_num), device=device) - - engine = OneFOneBPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint) - - engine.initialize_optimizer(optimizer_class, lr=lr) - - _ = engine.forward_backward(input_sample) - - cuda_rpc_result = [] - single_result = [] - actual_stage_num = engine._get_actual_stage_num() - - # compute parameters after updating in cuda rpc - parameters = engine.remote_parameters() - for stage_id in range(actual_stage_num): - for p in parameters[stage_id]: - cuda_rpc_result.append(p) - - # compute forward result and backward grad of parameters just in rank_0 - test_model = nn.Sequential( - *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) - optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr) - input_sample = input_sample.requires_grad_() - out_val = test_model(input_sample).sum() - autograd.backward(out_val) - optimizer.step() - optimizer.zero_grad() - - for p in test_model.parameters(): - single_result.append(p) - - assert len(cuda_rpc_result) == len(single_result) - for r_c, r_s in zip(cuda_rpc_result, single_result): - assert_close(r_c, r_s, 0.001, 0.001) - - -if __name__ == "__main__": - args = parse_args() - rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_cuda_rpc_performance.py b/tests/test_pipeline/test_cuda_rpc_performance.py deleted file mode 100644 index 6a0509555862ea0b1be5ba10c961075af854720a..0000000000000000000000000000000000000000 --- a/tests/test_pipeline/test_cuda_rpc_performance.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -from typing import Callable, List, Optional, Type, Union -import time - -import pytest -import torch -import torch.nn as nn -from titans.dataloader.cifar10 import build_cifar -from torchvision.models import resnet50 -from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 -from tqdm import tqdm - -from rpc_test_utils import rpc_run, parse_args -import colossalai -import colossalai.nn as col_nn -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.trainer import Trainer, hooks -from colossalai.utils import MultiTimer, get_dataloader -from colossalai.context import ParallelMode -from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel -from colossalai.pipeline.rpc import OneFOneBPipelineEngine, ChimeraPipelineEngine -from colossalai.pipeline.pipeline_process_group import ppg - - -def flatten(x): - return torch.flatten(x, 1) - - -def partition(pp_rank: int, chunk: int, stage_num: int): - pipelinable = PipelinableContext() - - # build model partitions - with pipelinable: - # input : [B, 3, 32, 32] - _ = resnet50() - - pipelinable.policy = "customized" - - exec_seq = [ - 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc' - ] - pipelinable.to_layer_list(exec_seq) - partition = pipelinable.partition(chunk, stage_num, pp_rank) - return partition - - -def run_master(args): - batch_size = args.batch_size - chunk = args.chunk - device = args.device - world_size = args.world_size - stage_num = world_size - num_microbatches = args.num_microbatches - - # build dataloader - root = os.environ.get('DATA', './data') - train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32) - criterion = nn.CrossEntropyLoss() - - pp_engine = OneFOneBPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - criterion=criterion, - checkpoint=False) - - pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) - s = time.time() - - for bx, by in tqdm(train_dataloader): - pp_engine.forward_backward(bx, labels=by, forward_only=False) - - cost_time = time.time() - s - - print("total cost time :", cost_time) - print("cost time per batch:", cost_time / len(train_dataloader)) - - -@pytest.mark.skip("Test for performance, no need for CI") -def main(): - args = parse_args() - # this is due to limitation of partition function - args.world_size = 2 - args.chunk = 1 - rpc_run(args, run_master) - - -if __name__ == '__main__': - main() diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_pipeline/test_cuda_rpc_pipeline.py deleted file mode 100644 index 8d03e79813e89976153def8ecc7873af2c913701..0000000000000000000000000000000000000000 --- a/tests/test_pipeline/test_cuda_rpc_pipeline.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -from torch import nn - -from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine -from rpc_test_utils import rpc_run, parse_args, RpcTestModel - -# global variable for model created -feat_num = 100 -h = 100 - - -def partition(pp_rank: int, chunk: int, stage_num: int): - torch.manual_seed(1024) - partition = RpcTestModel(pp_rank, stage_num, feat_num, h) - return partition - - -def run_master(args): - torch.manual_seed(100) - - epoch = args.epoch - device = args.device - stage_num = args.world_size - chunk = args.chunk - num_microbatches = args.num_microbatches - use_checkpoint = args.use_checkpoint - - sample_num = 1024 - batch_size = 1024 - - assert sample_num % batch_size == 0 - - input_sample = torch.randn((sample_num, feat_num), device=device) - - engine = OneFOneBPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint) - - for _ in range(epoch): - _ = engine.forward_backward(input_sample, forward_only=False) - - -if __name__ == "__main__": - args = parse_args() - rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_pipeline/test_cuda_rpc_value_correctness.py deleted file mode 100644 index e6713478baecae9115ac3142d5c601d850bec070..0000000000000000000000000000000000000000 --- a/tests/test_pipeline/test_cuda_rpc_value_correctness.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -from torch import nn -from torch import autograd - -from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine -from colossalai.testing import assert_close -from rpc_test_utils import rpc_run, parse_args, RpcTestModel - -feat_num = 100 -h = 100 - - -def partition(pp_rank: int, chunk: int, stage_num: int): - torch.manual_seed(1024) - partition = RpcTestModel(pp_rank, stage_num, feat_num, h) - return partition - - -def run_master(args): - torch.manual_seed(100) - - device = args.device - stage_num = args.world_size - chunk = args.chunk - actual_stage_num = stage_num * chunk - use_checkpoint = args.use_checkpoint - num_microbatches = args.num_microbatches - - sample_num = 1024 - batch_size = 1024 - - assert sample_num % batch_size == 0 - - input_sample = torch.randn((sample_num, feat_num), device=device) - - engine = OneFOneBPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint) - - forward_result = engine.forward_backward(input_sample) - - cuda_rpc_result = [] - single_result = [] - actual_stage_num = engine._get_actual_stage_num() - - # compute forward result and backward grad of parameters in cuda rpc - cuda_rpc_result.append(sum(forward_result[0])) - grad = engine.remote_grad() - for stage_id in range(actual_stage_num): - for p in grad[stage_id]: - cuda_rpc_result.append(p) - - # compute forward result and backward grad of parameters just in rank_0 - test_model = nn.Sequential( - *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) - input_sample = input_sample.requires_grad_() - out_val = test_model(input_sample).sum() - autograd.backward(out_val) - single_result.append(out_val) - for p in test_model.parameters(): - single_result.append(p.grad) - - assert len(cuda_rpc_result) == len(single_result) - for r_c, r_s in zip(cuda_rpc_result, single_result): - assert_close(r_c, r_s, 0.001, 0.001) - - -if __name__ == "__main__": - args = parse_args() - rpc_run(args, run_master) diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py new file mode 100644 index 0000000000000000000000000000000000000000..1665711ceeefb2433da33f4637d48facea39bb8d --- /dev/null +++ b/tests/test_pipeline/test_p2p_communication.py @@ -0,0 +1,59 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + + +def check_p2p_communication(): + pg_mesh = ProcessGroupMesh(2) + stage_manager = PipelineStageManager(pg_mesh, 0) + p2p = PipelineP2PCommunication(stage_manager) + + rank = dist.get_rank() + + tensor = torch.ones(1, device=get_current_device()) + + if rank == 0: + p2p.send_forward(tensor) + p2p.send_forward([tensor]) + p2p.send_forward({"tensor": tensor}) + else: + obj = p2p.recv_forward() + assert torch.equal(obj, tensor) + obj = p2p.recv_forward() + assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) + obj = p2p.recv_forward() + assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) + + if rank == 1: + p2p.send_backward(tensor) + p2p.send_backward([tensor]) + p2p.send_backward({"tensor": tensor}) + else: + obj = p2p.recv_backward() + assert torch.equal(obj, tensor) + obj = p2p.recv_backward() + assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) + obj = p2p.recv_backward() + assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + check_p2p_communication() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pipeline_p2p(): + spawn(run_dist, 2) + + +if __name__ == "__main__": + test_pipeline_p2p() diff --git a/tests/test_pipeline/test_pipeline_process_group.py b/tests/test_pipeline/test_pipeline_process_group.py deleted file mode 100644 index 2a00e3ac55b195b8b3cde6ec9c17b949dc7511af..0000000000000000000000000000000000000000 --- a/tests/test_pipeline/test_pipeline_process_group.py +++ /dev/null @@ -1,42 +0,0 @@ -import os - -import torch.distributed.rpc as rpc -from rpc_test_utils import pg_parse_args, rpc_is_initialized - -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.testing import spawn - - -def run_worker(rank, args): - os.environ['MASTER_ADDR'] = args.master_addr - os.environ['MASTER_PORT'] = args.master_port - - device = args.device - world_size = args.world_size - dp_degree = args.dp_degree - tp_degree = args.tp_degree - num_worker_threads = args.num_worker_threads - host = args.master_addr - port = args.master_port - backend = 'nccl' if device == 'cuda' else 'gloo' - - disable_existing_loggers() - launch(dict(), rank, world_size, host, int(port), backend, verbose=False) - - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=dp_degree, - tp_degree=tp_degree, - num_worker_threads=num_worker_threads, - device=device) - - if rpc_is_initialized(): - rpc.shutdown() - - -if __name__ == "__main__": - args = pg_parse_args() - world_size = args.world_size - spawn(run_worker, world_size, args=args) diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3723c9c1014a01e4880ded0b8fa5e9a07e30a085 --- /dev/null +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -0,0 +1,45 @@ +from colossalai.shardformer.policies.t5 import T5BasePolicy + + +def test_t5_pipeline_distribution(): + num_test_cases = 8 + test_dict = { + "num_encoder_layers": [2, 1, 3, 2, 3, 2, 10, 5], + "num_decoder_layers": [2, 8, 0, 2, 1, 5, 6, 22], + "num_stages": [2, 2, 2, 4, 4, 4, 8, 8], + "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], + } + + for i in range(num_test_cases): + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) + assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage + + +def test_t5_pipeline_layers(): + num_test_cases = 4 + test_dict = { + "num_encoder_layers": [2, 3, 2, 4], + "num_decoder_layers": [2, 0, 2, 8], + "num_stages": [2, 2, 4, 4], + "layers_per_stage": [ + [[0, 2], [0, 2]], + [[0, 1], [1, 3]], + [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]], + ], + } + + for i in range(num_test_cases): + layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) + + for stage in range(test_dict["num_stages"][i]): + start_idx, end_idx = test_dict["layers_per_stage"][i][stage] + predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index( + layers_per_stage, stage, decoder_starting_stage + ) + assert start_idx == predicted_start + assert end_idx == predicted_end diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f6be8f6feac2303f2df5ade3b2285154cef295aa --- /dev/null +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -0,0 +1,50 @@ +from colossalai.shardformer.policies.whisper import WhisperPolicy + + +def test_whisper_pipeline_distribution(): + num_test_cases = 8 + test_dict = { + "num_encoder_layers": [2, 1, 3, 2, 3, 2, 10, 5], + "num_decoder_layers": [2, 8, 0, 2, 1, 5, 6, 22], + "num_stages": [2, 2, 2, 4, 4, 4, 8, 8], + "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], + } + + for i in range(num_test_cases): + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) + assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage + + +def test_whisper_pipeline_layers(): + num_test_cases = 4 + test_dict = { + "num_encoder_layers": [2, 3, 2, 4], + "num_decoder_layers": [2, 0, 2, 8], + "num_stages": [2, 2, 4, 4], + "layers_per_stage": [ + [[0, 2], [0, 2]], + [[0, 1], [1, 3]], + [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]], + ], + } + + for i in range(num_test_cases): + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) + + for stage in range(test_dict["num_stages"][i]): + start_idx, end_idx = test_dict["layers_per_stage"][i][stage] + predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index( + layers_per_stage, stage, decoder_starting_stage + ) + assert start_idx == predicted_start + assert end_idx == predicted_end + + +if __name__ == "__main__": + test_whisper_pipeline_distribution() + test_whisper_pipeline_layers() diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py new file mode 100644 index 0000000000000000000000000000000000000000..f181453eaed521ead71d3b5264ac402f08cede25 --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -0,0 +1,159 @@ +import copy +from functools import partial +from types import MethodType + +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all + + +class MlpModel(nn.Module): + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(4, 8) + self.linear2 = nn.Linear(8, 8) + self.linear3 = nn.Linear(8, 8) + self.linear4 = nn.Linear(8, 8) + self.linear5 = nn.Linear(8, 8) + self.linear6 = nn.Linear(8, 8) + self.linear7 = nn.Linear(8, 8) + self.linear8 = nn.Linear(8, 4) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + x = self.linear5(x) + x = self.linear6(x) + x = self.linear7(x) + x = self.linear8(x) + return x + + +def pp_linear_fwd( + forward, + data: torch.Tensor = None, + input_obj: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, + num_chunks: int = None, + model_chunk_id: int = None, +): + if stage_mgr.is_first_stage() and model_chunk_id == 0: + return {"input_obj": forward(data)} + elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1: + return forward(input_obj) + else: + return {"input_obj": forward(input_obj)} + + +@parameterize("num_micro_batches", [4, 8, 12]) +def examine_pp(num_micro_batches): + """ + This test is to examine the correctness of interleaved 1F1B, compared with torch. + Be aware it contains some hardcodes. + """ + world_size = torch.distributed.get_world_size() + local_rank = torch.distributed.get_rank() + seed_all(1453) + + NUM_MICRO_BATCHS = num_micro_batches + BATCH_SIZE = num_micro_batches + NUM_CHUNKS = 2 + + # create model + torch_model = MlpModel().cuda() + + pp_model = copy.deepcopy(torch_model).cuda() + + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + pg_mesh = ProcessGroupMesh(1, world_size, 1) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True) + schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager) + + sharded_model = torch.nn.ModuleList() + for idx, (_, sub_model) in enumerate(pp_model.named_children()): + if idx % (world_size) == local_rank: + sub_model._forward = sub_model.forward + sub_model.forward = MethodType( + partial( + pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model) + ), + sub_model._forward, + ) + sharded_model.append(sub_model.cuda()) + + # create optimizer + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1)) + + # create + seed_all(1453) + if local_rank == 0: + input_list = [torch.rand(BATCH_SIZE, 4).cuda()] + else: + input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] + torch.distributed.all_reduce(input_list[0]) + + criterion = lambda x, y: torch.mean(x) + + # forward and backward + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output, _) + torch_loss.backward() + + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + ) + + # check loss + if stage_manager.is_last_stage(): + assert torch.allclose(torch_loss, pp_ret["loss"]) + + # check gradients + torch_grad = [] + for torch_p in torch_model.parameters(): + torch_grad.append(torch_p.grad.data) + + for idx, pp_p in enumerate(sharded_model.parameters()): + if idx < 2: + assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) + else: + assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data) + + # step + torch_optimizer.step() + pp_optimizer.step() + + # check updated param + torch_param = [] + for torch_p in torch_model.parameters(): + torch_param.append(torch_p.data) + for idx, pp_p in enumerate(sharded_model.parameters()): + if idx < 2: + assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) + else: + assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + examine_pp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_pp() diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py new file mode 100644 index 0000000000000000000000000000000000000000..1d77edc2db114fc63b1120330ba5fe4c43f69941 --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -0,0 +1,128 @@ +import copy +from functools import partial +from types import MethodType + +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all + + +class MlpModel(nn.Module): + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(4, 8) + self.linear2 = nn.Linear(8, 4) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def pp_linear_fwd( + forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None +): + if stage_mgr.is_first_stage(): + return {"input_obj": forward(data)} + elif stage_mgr.is_last_stage(): + return forward(input_obj) + else: + return {"input_obj": forward(input_obj)} + + +def examine_pp(): + """ + This test is to examine the correctness of 1F1B, compared with torch. + Be aware it contains some hardcodes. + """ + world_size = torch.distributed.get_world_size() + local_rank = torch.distributed.get_rank() + seed_all(1453) + + NUM_MICRO_BATCHS = 4 + BATCH_SIZE = 4 + + # create models + torch_model = MlpModel().cuda() + + pp_model = copy.deepcopy(torch_model).cuda() + + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + pg_mesh = ProcessGroupMesh(1, world_size, 1) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS) + + for idx, (_, sub_model) in enumerate(pp_model.named_children()): + if idx % (world_size) == local_rank: + sharded_model = sub_model.cuda() + + sharded_model._forward = sharded_model.forward + sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward) + + # create optimizer + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1)) + + # create + seed_all(1453) + if stage_manager.is_first_stage(): + input_list = [torch.rand(BATCH_SIZE, 4).cuda()] + else: + input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] + torch.distributed.all_reduce(input_list[0]) + + criterion = lambda x, y: torch.mean(x) + + # forward and backward + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output, _) + torch_loss.backward() + + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + ) + + # check loss + if stage_manager.is_last_stage(): + assert torch.allclose(torch_loss, pp_ret["loss"]) + + # check gradients + torch_grad = [] + for torch_p in torch_model.parameters(): + torch_grad.append(torch_p.grad.data) + for idx, pp_p in enumerate(sharded_model.parameters()): + assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) + + # step + torch_optimizer.step() + pp_optimizer.step() + + # check updated param + torch_param = [] + for torch_p in torch_model.parameters(): + torch_param.append(torch_p.data) + for idx, pp_p in enumerate(sharded_model.parameters()): + assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + examine_pp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp(): + spawn(run_dist, 2) + + +if __name__ == "__main__": + test_pp() diff --git a/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py b/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..462355ee470b70bbfb26746dcf2a1155afe6032a --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py @@ -0,0 +1,47 @@ +import torch + +from colossalai.pipeline.schedule._utils import get_batch_size, get_micro_batch, merge_batch + + +def test_get_batch_size(): + tensor = torch.rand(2, 3) + assert get_batch_size(tensor) == 2 + assert get_batch_size([tensor]) == 2 + assert get_batch_size((1, tensor)) == 2 + assert get_batch_size({"tensor": tensor}) == 2 + assert get_batch_size({"dummy": [1], "tensor": tensor}) == 2 + assert get_batch_size({"tensor": [tensor]}) == 2 + + +def test_get_micro_batch(): + x = torch.rand(2, 1) + y = torch.rand(2, 3) + micro_batch = get_micro_batch(x, 0, 1) + assert torch.equal(micro_batch, x[0:1]) + micro_batch = get_micro_batch(x, 1, 1) + assert torch.equal(micro_batch, x[1:2]) + micro_batch = get_micro_batch([x, y], 0, 1) + assert torch.equal(micro_batch[0], x[0:1]) + assert torch.equal(micro_batch[1], y[0:1]) + micro_batch = get_micro_batch([x, y], 1, 1) + assert torch.equal(micro_batch[0], x[1:2]) + assert torch.equal(micro_batch[1], y[1:2]) + micro_batch = get_micro_batch({"x": x, "y": y}, 0, 1) + assert torch.equal(micro_batch["x"], x[0:1]) + assert torch.equal(micro_batch["y"], y[0:1]) + micro_batch = get_micro_batch({"x": x, "y": y}, 1, 1) + assert torch.equal(micro_batch["x"], x[1:2]) + assert torch.equal(micro_batch["y"], y[1:2]) + + +def test_merge_batch(): + x = torch.rand(2, 1) + y = torch.rand(2, 3) + merged = merge_batch([x[0:1], x[1:2]]) + assert torch.equal(merged, x) + merged = merge_batch([[x[0:1], y[0:1]], [x[1:2], y[1:2]]]) + assert torch.equal(merged[0], x) + assert torch.equal(merged[1], y) + merged = merge_batch([{"x": x[0:1], "y": y[0:1]}, {"x": x[1:2], "y": y[1:2]}]) + assert torch.equal(merged["x"], x) + assert torch.equal(merged["y"], y) diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8284b3e64c48dd7d0515930d4a5b3d054f8185 --- /dev/null +++ b/tests/test_pipeline/test_stage_manager.py @@ -0,0 +1,78 @@ +import pytest +import torch.distributed as dist + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_stage_manager(): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + # check stage info + assert stage_manager.num_stages == PP_SIZE + assert stage_manager.stage == RANK_TO_COORDINATE[rank][PP_DIM] + + # check is_first_stage + ranks_in_group = PP_RANKS_IN_GROUP[rank] + is_first_stage = ranks_in_group.index(rank) == 0 + assert stage_manager.is_first_stage() == is_first_stage + + # check is_last_stage + is_last_stage = ranks_in_group.index(rank) == len(ranks_in_group) - 1 + assert stage_manager.is_last_stage() == is_last_stage + + # check prev rank + if not is_first_stage: + prev_rank = ranks_in_group[ranks_in_group.index(rank) - 1] + assert stage_manager.get_prev_rank() == prev_rank + + # check next rank + if not is_last_stage: + next_rank = ranks_in_group[ranks_in_group.index(rank) + 1] + assert stage_manager.get_next_rank() == next_rank + + # check p2p groups + for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]): + if rank in [prev, cur]: + group = stage_manager.get_p2p_process_group(prev, cur) + dist.barrier(group=group) + + # check stage groups + pg_mesh = ProcessGroupMesh(4) + stage_manager = PipelineStageManager(pg_mesh, 0) + group = stage_manager.init_process_group_by_stages([0, 2]) + if rank in [0, 2]: + dist.barrier(group=group) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + check_stage_manager() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pipeline_stage_manager(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_pipeline_stage_manager() diff --git a/tests/test_shardformer/__init__.py b/tests/test_shardformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_shardformer/test_layer/test_dist_crossentropy.py b/tests/test_shardformer/test_layer/test_dist_crossentropy.py new file mode 100644 index 0000000000000000000000000000000000000000..277a5b2bb4beb6ab6a5730cfa901f65ef370f3d1 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_dist_crossentropy.py @@ -0,0 +1,45 @@ +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer import cross_entropy_1d +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict( + parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")), +) + + +def check_dist_crossentropy(rank, world_size, port, ignore_index): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl") + + # prepare data + pred = torch.randn(2, 4, 8, requires_grad=True) + labels = torch.randint(8, (2, 4)) + # set some label to -100 to test the ignore index + labels[0, -1] = ignore_index + + org_pred = pred.view(-1, 8) + org_labels = labels.view(-1) + org_loss = F.cross_entropy(org_pred, org_labels) + + dist_pred = pred.chunk(world_size, -1)[rank] + dist_loss = cross_entropy_1d(dist_pred.to("cuda"), labels.to("cuda"), ignore_index=ignore_index) + + assert torch.allclose( + org_loss, dist_loss, atol=1e-5 + ), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_crossentropy(): + ignore_index = -100 + spawn(check_dist_crossentropy, 2, ignore_index=ignore_index) + + +if __name__ == "__main__": + test_dist_crossentropy() diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..576620e6c7f38c059017f00001ca4b4eed7499ef --- /dev/null +++ b/tests/test_shardformer/test_layer/test_dropout.py @@ -0,0 +1,70 @@ +import torch +import torch.distributed as dist +import torch.nn as nn + +import colossalai +from colossalai.shardformer.layer import DropoutForParallelInput, DropoutForReplicatedInput +from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn + + +def check_dropout_parallel_input(): + dropout = nn.Dropout().cuda() + dropout_1d = DropoutForParallelInput.from_native_module(dropout, process_group=None) + + # check computation correctness + x = torch.rand(4, 128).cuda() + + # we set seed so that dropout will generate the same mask + torch.cuda.manual_seed(1024) + out = dropout(x) + + # we set seed to simulate the same scenario + # but expect the dropout mask to be different + # due to the internal randomness control + torch.cuda.manual_seed(1024) + out_1d = dropout_1d(x) + + # ensure out is the same across all ranks + world_size = dist.get_world_size() + out_all = [torch.empty_like(out) for _ in range(world_size)] + dist.all_gather(out_all, out) + + for i in range(world_size): + assert_equal(out_all[i], out_all[0]) + + # ensure out_1d is different across ranks + out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)] + dist.all_gather(out_1d_all, out_1d) + for i in range(1, world_size): + assert_not_equal(out_1d_all[i], out_1d_all[0]) + + +def check_dropout_replicated_input(): + dropout = nn.Dropout().cuda() + dropout_replica = DropoutForReplicatedInput.from_native_module(dropout, process_group=None) + + # check computation correctness + x = torch.rand(4, 128).cuda() + out_1d = dropout_replica(x) + + # ensure out_1d is different across ranks + world_size = dist.get_world_size() + out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)] + dist.all_gather(out_1d_all, out_1d) + for i in range(1, world_size): + assert_equal(out_1d_all[i], out_1d_all[0]) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + check_dropout_parallel_input() + check_dropout_replicated_input() + + +@rerun_if_address_is_in_use() +def test_dropout(): + spawn(run_dist, nprocs=2) + + +if __name__ == "__main__": + test_dropout() diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..3dbbcd766bf4f1257a133ffe50a8727de04cd8e2 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -0,0 +1,56 @@ +from contextlib import nullcontext + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import Embedding1D +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("lazy_init", [False, True]) +def check_embedding_1d(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + embedding = nn.Embedding(32, 128).cuda() + with ctx: + embedding_copy = nn.Embedding(32, 128).cuda() + embedding_1d = Embedding1D.from_native_module(embedding_copy, process_group=None) + + assert embedding_1d.weight.shape == torch.Size([32, 64]) + assert embedding_1d.weight is embedding_copy.weight + + # ensure state dict is reversibly loadable + embedding.load_state_dict(embedding_1d.state_dict()) + embedding_1d.load_state_dict(embedding.state_dict()) + + # check computation correctness + x = torch.randint(low=0, high=32, size=(4, 32)).cuda() + out = embedding(x) + gather_out = embedding_1d(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(embedding.weight.grad, 2, dim=1)[rank] + assert_close(target_grad, embedding_1d.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + check_embedding_1d() + + +@rerun_if_address_is_in_use() +def test_embedding_1d(): + spawn(run_dist, nprocs=2) + + +if __name__ == "__main__": + test_embedding_1d() diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..10ffdcd7138c4c11dc40c5d5e38cf279b36780b3 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -0,0 +1,148 @@ +from contextlib import nullcontext + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +# This code is copied from https://github.com/huggingface/transformers +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + self.weight = nn.Parameter(torch.empty(nx, nf)) + self.bias = nn.Parameter(torch.zeros(nf)) + nn.init.normal_(self.weight, std=0.02) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def rearrange(tensor: torch.Tensor, dim: int): + tensor = tensor.clone() + world_size = 2 + order = torch.arange(world_size * 3) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim) + rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order] + rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim) + return rearanged_tensor + + +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + linear = Conv1D(192, 48).cuda() + with ctx: + linear_copy = Conv1D(192, 48).cuda() + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module( + linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, n_fused=3, overlap=overlap + ) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear.bias.shape == torch.Size([192]) + assert linear_conv_col.weight.shape == torch.Size([48, 96]) + assert linear_conv_col.bias.shape == torch.Size([96]) + assert linear_copy.weight is linear_conv_col.weight + assert linear_copy.bias is linear_conv_col.bias + + # ensure weights are reversibly loadable + linear_conv_col.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_conv_col.state_dict()) + + # check computation correctness + x = torch.rand(1, 4, 48).cuda() + out = linear(x) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + gather_out = linear_conv_col(x_for_shard) + assert_close(rearrange(out, -1), gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True) + assert_close(target_grad, linear_conv_col.weight.grad) + + +def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + linear = Conv1D(192, 48).cuda() + with ctx: + linear_copy = Conv1D(192, 48).cuda() + linear_row = GPT2FusedLinearConv1D_Row.from_native_module( + linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + ) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear_row.weight.shape == torch.Size([24, 192]) + assert linear_row.bias.shape == torch.Size([192]) + assert linear_copy.weight is linear_row.weight + assert linear_copy.bias is linear_row.bias + + # ensure weights are reversibly loadable + linear_row.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_row.state_dict()) + + # check computation correctness + x = torch.rand(1, 4, 48).cuda() + out = linear(x) + gather_out = linear_row(x) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, linear_row.weight.grad) + + +@parameterize("lazy_init", [False, True]) +@parameterize("seq_parallel", [False, True]) +@parameterize("overlap", [True]) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel, overlap) + check_linear_conv_1d_row(lazy_init, seq_parallel) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + # test for linear conv + check_gpt2_qkv_fused_linear_1d() + + +@rerun_if_address_is_in_use() +def test_linearconv(): + spawn(run_dist, nprocs=2) + + +if __name__ == "__main__": + test_linearconv() diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb3bb2e5b8d3905a33ee22b32c37c30b6998962 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -0,0 +1,54 @@ +from contextlib import nullcontext + +import torch +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import FusedLayerNorm +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("lazy_init", [False, True]) +def check_layernorm(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + norm = nn.LayerNorm(128, 0.00001).cuda() + with ctx: + norm_copy = nn.LayerNorm(128, 0.00001).cuda() + norm1d = FusedLayerNorm.from_native_module(norm_copy, process_group=None) + + assert norm1d.weight.shape == torch.Size([128]) + assert norm_copy.weight is norm1d.weight + assert norm_copy.bias is norm1d.bias + + # ensure state dict is reversibly loadable + norm.load_state_dict(norm1d.state_dict()) + norm1d.load_state_dict(norm.state_dict()) + + # check computation correctness + x = torch.rand(4, 128).cuda() + out = norm(x) + gather_out = norm1d(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + assert_close(norm.weight.grad, norm1d.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + check_layernorm() + + +@rerun_if_address_is_in_use() +def test_layernorm(): + spawn(run_dist, nprocs=2) + + +if __name__ == "__main__": + test_layernorm() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..5bacf1865c486d3fde0e57b62c31c31b421298a8 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -0,0 +1,189 @@ +from contextlib import nullcontext + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.tensor.d_tensor import is_distributed_tensor +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + linear = nn.Linear(32, 128).cuda() + with ctx: + linear_copy = nn.Linear(32, 128).cuda() + linear_col = Linear1D_Col.from_native_module( + linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, overlap=overlap + ) + + # ensure that the parameters are distributed + assert is_distributed_tensor(linear_col.weight) + assert is_distributed_tensor(linear_col.bias) + assert linear_copy.weight is linear_col.weight + assert linear_copy.bias is linear_col.bias + + # ensure the shape is correct + assert linear_col.weight.shape == torch.Size([64, 32]) + assert linear_col.bias.shape == torch.Size([64]) + + # ensure state dict is reversibly loadable + linear.load_state_dict(linear_col.state_dict()) + linear_col.load_state_dict(linear.state_dict()) + + # check computation correctness + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + x_for_shard.requires_grad_(True) + + out = linear(x_for_unshard) + gather_out = linear_col(x_for_shard) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, linear_col.weight.grad) + + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + target_unshard_gard = ( + x_for_unshard.grad + if seq_parallel is False + else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + ) + assert_close(target_unshard_gard, x_for_shard.grad) + + +def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + linear = nn.Linear(32, 128).cuda() + with ctx: + linear_copy = nn.Linear(32, 128).cuda() + linear_row = Linear1D_Row.from_native_module( + linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + ) + + assert linear_row.weight.shape == torch.Size([128, 16]) + assert linear_row.bias.shape == torch.Size([128]) + assert linear_copy.weight is linear_row.weight + assert linear_copy.bias is linear_row.bias + + linear.load_state_dict(linear_row.state_dict()) + linear_row.load_state_dict(linear.state_dict()) + + # check computation correctness + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + # run forward + out = linear(x_for_unshard) + gather_out = linear_row(x_for_shard) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] + assert_close(target_grad, linear_row.weight.grad) + + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + + +def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + linear_1 = nn.Linear(32, 128).cuda() + linear_2 = nn.Linear(128, 32).cuda() + + with ctx: + linear_1_copy = nn.Linear(32, 128).cuda() + linear_2_copy = nn.Linear(128, 32).cuda() + linear_col = Linear1D_Col.from_native_module( + linear_1_copy, process_group=None, gather_output=False, seq_parallel=seq_parallel, overlap=overlap + ) + linear_row = Linear1D_Row.from_native_module( + linear_2_copy, process_group=None, parallel_input=True, seq_parallel=seq_parallel + ) + + linear_1.load_state_dict(linear_col.state_dict()) + linear_col.load_state_dict(linear_1.state_dict()) + linear_2.load_state_dict(linear_row.state_dict()) + linear_row.load_state_dict(linear_2.state_dict()) + + # check computation correctness + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + x_for_shard.requires_grad_(True) + + # run forward + unshard_out = linear_2(linear_1(x_for_unshard)) + shard_out = linear_row(linear_col(x_for_shard)) + target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, shard_out) + + # check backward correctness + unshard_out.sum().backward() + shard_out.sum().backward() + + rank = dist.get_rank() + target_1_grad = torch.chunk(linear_1.weight.grad, 2, dim=0)[rank] + assert_close(target_1_grad, linear_col.weight.grad) + + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + target_unshard_gard = ( + x_for_unshard.grad + if seq_parallel is False + else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + ) + assert_close(target_unshard_gard, x_for_shard.grad) + + +@parameterize("lazy_init", [False, True]) +@parameterize("seq_parallel", [False, True]) +@parameterize("overlap", [True]) +def run_dist_linear_test(lazy_init, seq_parallel, overlap): + check_linear_1d_col(lazy_init, seq_parallel, overlap) + check_linear_1d_row(lazy_init, seq_parallel) + check_linear_col_plus_row(lazy_init, seq_parallel, overlap) + + +def check_dist_linear(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_dist_linear_test() + + +@rerun_if_address_is_in_use() +def test_linear(): + spawn(check_dist_linear, nprocs=2) + + +if __name__ == "__main__": + test_linear() diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..b02d581810cddf46c80fa6500934f9c3cbd06c37 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -0,0 +1,139 @@ +from contextlib import nullcontext + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +# This code is copied from https://github.com/huggingface/transformers +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + self.weight = nn.Parameter(torch.empty(nx, nf)) + self.bias = nn.Parameter(torch.zeros(nf)) + nn.init.normal_(self.weight, std=0.02) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def rearrange(tensor: torch.Tensor, dim: int): + tensor = tensor.clone() + world_size = 2 + order = torch.arange(world_size * 3) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim) + rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order] + rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim) + return rearanged_tensor + + +@parameterize("lazy_init", [False, True]) +def check_linear_conv_1d_col(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + linear = Conv1D(192, 48).cuda() + with ctx: + linear_copy = Conv1D(192, 48).cuda() + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module( + linear_copy, process_group=None, gather_output=True, n_fused=3 + ) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear.bias.shape == torch.Size([192]) + assert linear_conv_col.weight.shape == torch.Size([48, 96]) + assert linear_conv_col.bias.shape == torch.Size([96]) + assert linear_copy.weight is linear_conv_col.weight + assert linear_copy.bias is linear_conv_col.bias + + # ensure weights are reversibly loadable + linear_conv_col.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_conv_col.state_dict()) + + # check computation correctness + x = torch.rand(4, 48).cuda() + out = linear(x) + gather_out = linear_conv_col(x) + assert_close(rearrange(out, 1), gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True) + assert_close(target_grad, linear_conv_col.weight.grad) + + +@parameterize("lazy_init", [False, True]) +def check_linear_conv_1d_row(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + linear = Conv1D(192, 48).cuda() + with ctx: + linear_copy = Conv1D(192, 48).cuda() + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear_row.weight.shape == torch.Size([24, 192]) + assert linear_row.bias.shape == torch.Size([192]) + assert linear_copy.weight is linear_row.weight + assert linear_copy.bias is linear_row.bias + + # ensure weights are reversibly loadable + linear_row.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_row.state_dict()) + + # check computation correctness + x = torch.rand(4, 48).cuda() + out = linear(x) + gather_out = linear_row(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, linear_row.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + # test for linear conv + check_linear_conv_1d_col() + check_linear_conv_1d_row() + + +@rerun_if_address_is_in_use() +def test_linearconv(): + spawn(run_dist, nprocs=2) + + +if __name__ == "__main__": + test_linearconv() diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..b23a44f2dffa62d315db678c15f81bc27fb6766f --- /dev/null +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -0,0 +1,58 @@ +from contextlib import nullcontext + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import VocabParallelEmbedding1D +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("lazy_init", [False, True]) +def check_vocab_embedding_1d(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + embedding = nn.Embedding(128, 32).to("cuda") + with ctx: + embedding_copy = nn.Embedding(128, 32).to("cuda") + dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None) + + assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) + assert dist_embedding_1d.num_embeddings == 64 + assert dist_embedding_1d.embedding_dim == 32 + assert embedding_copy.weight is dist_embedding_1d.weight + + # ensure state dict is reversibly loadable + embedding.load_state_dict(dist_embedding_1d.state_dict()) + dist_embedding_1d.load_state_dict(embedding.state_dict()) + + # check embedding correctness + x = torch.randint(0, 128, (4, 32)).to("cuda") + org_out = embedding(x) + dist_out = dist_embedding_1d(x) + assert_close(org_out, dist_out) + + # check backward correctness + org_out.sum().backward() + dist_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(embedding.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, dist_embedding_1d.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + check_vocab_embedding_1d() + + +@rerun_if_address_is_in_use() +def test_vocab_embedding(): + spawn(run_dist, nprocs=2) + + +if __name__ == "__main__": + test_vocab_embedding() diff --git a/tests/test_shardformer/test_model/__init__.py b/tests/test_shardformer/test_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2b151d42741ecd313819e4a4d7d9ab036b31b4 --- /dev/null +++ b/tests/test_shardformer/test_model/_utils.py @@ -0,0 +1,342 @@ +import copy +import math +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.distributed as dist +from torch import Tensor +from torch import distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import Module +from torch.optim import Adam, Optimizer + +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule +from colossalai.lazy import LazyInitContext +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer._utils import getattr_ +from colossalai.shardformer.policies.auto_policy import Policy +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor + + +def build_model( + model_fn, + enable_fused_normalization=True, + enable_tensor_parallelism=True, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + use_lazy_init: bool = False, +): + # create new model + ctx = LazyInitContext() if use_lazy_init else nullcontext() + with ctx: + # create new model + org_model = model_fn() + model_copy = copy.deepcopy(org_model) + if use_lazy_init: + ctx.materialize(org_model) + # shard model + shard_config = ShardConfig( + enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + ) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + sharded_model, shared_params = shard_former.optimize(model_copy) + return org_model.cuda(), sharded_model.cuda() + + +def build_pipeline_model( + model_fn, + stage_manager=None, + enable_fused_normalization=False, + enable_tensor_parallelism=False, + use_lazy_init: bool = False, + policy: Optional[Policy] = None, +): + ctx = LazyInitContext() if use_lazy_init else nullcontext() + with ctx: + # create new model + org_model = model_fn() + model_copy = copy.deepcopy(org_model) + if use_lazy_init: + ctx.materialize(org_model) + + # shard model + shard_config = ShardConfig( + enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + pipeline_stage_manager=stage_manager, + ) + + shard_former = ShardFormer(shard_config=shard_config) + sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy) + return org_model.cuda(), sharded_model.cuda() + + +def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # prepare input + data = data_gen_fn() + data = {k: v.cuda() for k, v in data.items()} + # switch to train mode + original_model.train() + sharded_model.train() + # run forward + org_output = original_model(**data) + org_output = output_transform_fn(org_output) + org_loss = loss_fn(org_output) + + shard_output = sharded_model(**data) + shard_output = output_transform_fn(shard_output) + shard_loss = loss_fn(shard_output) + return org_output, org_loss, shard_output, shard_loss + + +def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""): + org_sd = org_model.state_dict() + shard_sd = sharded_model.state_dict() + for k, v in org_sd.items(): + assert k in shard_sd, f"{name} {k} not in sharded model" + shard_v = shard_sd[k] + assert v.shape == shard_v.shape, f"{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}" + assert v.dtype == shard_v.dtype, f"{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}" + assert torch.equal(v, shard_v), f"{name} {k} value mismatch" + + +def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]): + use_lazy_init = False + if "use_lazy_init" in test_config: + use_lazy_init = test_config.pop("use_lazy_init") + + ctx = LazyInitContext() if use_lazy_init else nullcontext() + with ctx: + org_model = model_fn() + sharded_model = copy.deepcopy(org_model) + if use_lazy_init: + ctx.materialize(org_model) + + org_model = org_model.cuda() + org_optimizer = Adam(org_model.parameters(), lr=1e-3) + sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3) + criterion = loss_fn + + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + + sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) + return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster + + +def run_forward_backward_with_hybrid_plugin( + org_model: Module, + sharded_model: Module, + sharded_optimizer: Optimizer, + data_gen_fn: Callable, + output_transform_fn: Callable, + criterion: Callable, + booster: Booster, +): + org_model.cuda() + sharded_model.cuda() + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + data = data_gen_fn() + + if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: + seq_len = data["input_ids"].shape[-1] + lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) + times = lcm // seq_len + input_shape = data["input_ids"].shape + for k, v in data.items(): + if v.shape == input_shape: + data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,)) + + sharded_model.train() + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to("cuda").repeat(*new_shape) + + data_iter = iter([data]) + sharded_output = booster.execute_pipeline( + data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True + ) + sharded_loss = sharded_output["loss"] + else: + data = {k: v.cuda() for k, v in data.items()} + sharded_output = sharded_model(**data) + + sharded_loss = criterion(sharded_output) + sharded_optimizer.backward(sharded_loss) + + org_model.train() + data = {k: v.cuda() for k, v in data.items()} + org_output = org_model(**data) + + org_loss = criterion(org_output) + org_loss.backward() + + return org_loss, org_output, sharded_loss, sharded_output + + +def check_output_hidden_state( + org_output: Tensor, + sharded_output: Tensor, + stage_manager: Optional[PipelineStageManager] = None, + atol: float = 1e-5, + rtol: float = 1e-3, + dim: int = 0, +): + org_hidden_state = org_output.last_hidden_state + + if stage_manager and stage_manager.is_last_stage(): + sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + else: + sharded_hidden_state = sharded_output.last_hidden_state + + assert torch.allclose( + org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol + ), f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" + + +def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3): + assert torch.allclose( + org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol + ), f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" + + +def check_weight( + org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: Optional[ProcessGroup] = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, +): + for suffix in layer_suffix: + org_weight = getattr_(org_model, suffix).weight + sharded_weight = getattr_(sharded_model, suffix).weight + + if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): + sharded_weight_list = [ + torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group)) + ] + dist.all_gather(sharded_weight_list, sharded_weight, tp_group) + sharded_weight = torch.cat(sharded_weight_list, dim=dim) + + if verbose and dist.get_rank() == 0: + print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") + + assert torch.allclose( + org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol + ), f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" + + +def get_grad_tensors_for_check( + org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: ProcessGroup = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, + name: str = None, +): + grad_to_check = {} + for suffix in layer_suffix: + org_grad = getattr_(org_model, suffix).weight.grad + shard_grad = getattr_(sharded_model, suffix).weight.grad + shard_weight = getattr_(sharded_model, suffix).weight + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))] + dist.all_gather(shard_grad_list, shard_grad, tp_group) + shard_grad = torch.cat(shard_grad_list, dim=dim) + + # embedding may be resized when using tensor parallel + if shard_grad.shape[0] > org_grad.shape[0]: + shard_grad = shard_grad[: org_grad.shape[0], :] + if verbose and dist.get_rank() == 0: + print(f"'{suffix}' grad: {org_grad}, {shard_grad}") + + grad_to_check[suffix] = { + "org_grad": org_grad.float(), + "shard_grad": shard_grad.float(), + "rtol": rtol, + "atol": atol, + } + + return grad_to_check + + +# used by sam/blip2 +def check_grad( + org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: ProcessGroup = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, +): + for suffix in layer_suffix: + org_grad = getattr_(org_model, suffix).weight.grad + shard_grad = getattr_(sharded_model, suffix).weight.grad + shard_weight = getattr_(sharded_model, suffix).weight + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))] + dist.all_gather(shard_grad_list, shard_grad, tp_group) + shard_grad = torch.cat(shard_grad_list, dim=dim) + + # embedding may be resized when using tensor parallel + if shard_grad.shape[0] > org_grad.shape[0]: + shard_grad = shard_grad[: org_grad.shape[0], :] + if verbose and dist.get_rank() == 0: + print(f"'{suffix}' grad: {org_grad}, {shard_grad}") + + assert torch.allclose( + org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol + ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" + + +def unwrap_model( + module: Module, base_model_class_name: Optional[str] = None, base_model_attribute_name: Optional[str] = None +): + if isinstance(module, HybridParallelModule): + module = module.unwrap() + if base_model_class_name is None: + return module + if module.__class__.__name__ == base_model_class_name: + return module + return getattr(module, base_model_attribute_name, None) + + +def check_all_grad_tensors(check_tensors): + """ + "org_grad": tensor to be compared from the original model + "shard_grad": tensor to be compared from the sharded model + """ + for suffix, check_info in check_tensors.items(): + org_grad = check_info["org_grad"] + shard_grad = check_info["shard_grad"] + rtol = check_info["rtol"] + atol = check_info["atol"] + assert torch.allclose( + org_grad, shard_grad, atol=atol, rtol=rtol + ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..31fd58d06f778b4535dbdfd57ef68e6d84b6c0e5 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -0,0 +1,206 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + + col_layer_for_check = ["encoder.layer[0].output.dense"] + row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + col_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + row_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if org_model.__class__.__name__ == "BertModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if test_config["precision"] == "fp32": + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": True, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_bert_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) +def run_bert_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_bert(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_bert_test() + + +def check_bert_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_bert_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bert(): + spawn(check_bert, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bert_3d(): + spawn(check_bert_3d, 8) + + +if __name__ == "__main__": + test_bert() + test_bert_3d() diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py new file mode 100644 index 0000000000000000000000000000000000000000..02c15460ecb372a8313ac6aec19d151c4b6241f6 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_blip2.py @@ -0,0 +1,81 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward( + org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn + ) + assert_hf_output_close(org_output, shard_output, ignore_keys=["past_key_values"]) + + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose( + org_loss, shard_loss, atol=1e-5 + ), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # check grad + + blip2 = org_model + sharded_blip2 = sharded_model + + # check grad + col_layer_for_check = [ + "vision_model.encoder.layers[0].self_attn.qkv", + "qformer.encoder.layer[0].attention.attention.query", + "language_model.model.decoder.layers[0].self_attn.k_proj", + ] + row_layer_for_check = [ + "vision_model.encoder.layers[0].self_attn.projection", + "qformer.encoder.layer[0].attention.output.dense", + "language_model.model.decoder.layers[0].self_attn.out_proj", + ] + check_grad(blip2, sharded_blip2, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) + check_grad(blip2, sharded_blip2, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) + + +@parameterize("enable_fused_normalization", [True, False]) +@parameterize("enable_tensor_parallelism", [True, False]) +@parameterize("enable_flash_attention", [True, False]) +@parameterize("enable_jit_fused", [True, False]) +def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): + sub_model_zoo = model_zoo.get_sub_registry("transformers_blip2") + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model( + model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused + ) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +def check_blip2(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_blip2_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_blip2(): + spawn(check_blip2, 2) + + +if __name__ == "__main__": + test_blip2() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe791db6d5e31100fbe6f0c323be30c60d46322 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -0,0 +1,202 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + bloom = unwrap_model(org_model, "BloomModel", "transformer") + sharded_bloom = unwrap_model(sharded_model, "BloomModel", "transformer") + + row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"] + col_layer_for_check = ["h[0].self_attention.dense"] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-6, 1e-5 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if org_model.__class__.__name__ == "BloomModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_bloom_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) +def run_bloom_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_bloom(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_bloom_test() + + +def check_bloom_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_bloom_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom(): + spawn(check_bloom, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom_3d(): + spawn(check_bloom_3d, 8) + + +if __name__ == "__main__": + test_bloom() + test_bloom_3d() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py new file mode 100644 index 0000000000000000000000000000000000000000..bdf5b79fc4985359e112cecec17de663be956fe6 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -0,0 +1,217 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + chatglm_model = unwrap_model(org_model, "ChatGLMModel", "transformer") + shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer") + + row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"] + col_layer_for_check = ["encoder.layers[0].self_attention.dense"] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-6, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + chatglm_model, + shard_chatglm_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, + ) + + col_layer_grads = get_grad_tensors_for_check( + chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "ChatGLMModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight( + chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + + # check grads + check_all_grad_tensors(grads_to_check) + + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_chatglm_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) +def run_chatglm_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +def check_chatglm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_chatglm_test() + + +def check_chatglm_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_chatglm_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm(): + spawn(check_chatglm, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm_3d(): + spawn(check_chatglm_3d, 8) + + +if __name__ == "__main__": + test_chatglm() + test_chatglm_3d() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py new file mode 100644 index 0000000000000000000000000000000000000000..69a15166a54ca05b908a7b5d4cb4cad88bea91e7 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -0,0 +1,226 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + gpt2 = unwrap_model(org_model, "GPT2Model", "transformer") + sharded_gpt2 = unwrap_model(sharded_model, "GPT2Model", "transformer") + + col_layer_for_check = ["h[0].mlp.c_fc"] + row_layer_for_check = ["wte", "h[0].mlp.c_proj"] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + col_layer_grads = get_grad_tensors_for_check( + gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + row_layer_grads = get_grad_tensors_for_check( + gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "GPT2Model": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + # check grads + check_all_grad_tensors(grads_to_check) + + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +@clear_cache_before_run() +def run_gpt2_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) +@clear_cache_before_run() +def run_gpt2_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +def check_gpt2(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_gpt2_test() + + +def check_gpt2_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_gpt2_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gpt2(): + spawn(check_gpt2, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gpt2_3d(): + spawn(check_gpt2_3d, 8) + + +if __name__ == "__main__": + test_gpt2() + test_gpt2_3d() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..f8f08e1d00757c817a085ac287b3f23135e2c9b3 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -0,0 +1,223 @@ +import os + +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + llama_model = unwrap_model(org_model, "LlamaModel", "model") + shard_llama_model = unwrap_model(sharded_model, "LlamaModel", "model") + + row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] + col_layer_for_check = ["layers[0].self_attn.o_proj"] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-6, 1e-4 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + llama_model, shard_llama_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "LlamaModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight( + llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_llama_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) +def run_llama_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_llama_test() + + +def check_llama_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_llama_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama_3d(): + spawn(check_llama_3d, 8) + + +if __name__ == "__main__": + test_llama() + test_llama_3d() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..d21ab264d8ab442ae5da00c036f41c513e1b8df6 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -0,0 +1,207 @@ +import os + +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + opt_model = unwrap_model(org_model, "OPTModel", "model") + shard_opt_model = unwrap_model(sharded_model, "OPTModel", "model") + + row_layer_for_check = ["decoder.layers[0].self_attn.q_proj", "decoder.embed_tokens"] # 'decoder.embed_tokens' + col_layer_for_check = ["decoder.layers[0].self_attn.out_proj"] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-6, 1e-3 + else: + atol, rtol = 4e-2, 4e-2 + row_layer_grads = get_grad_tensors_for_check( + opt_model, shard_opt_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if org_model.__class__.__name__ == "OPTModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight( + opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + + # check grads + check_all_grad_tensors(grads_to_check) + + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_opt_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_opt") + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) +def run_opt_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_opt") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +def check_OPTModel(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_opt_test() + + +def check_opt_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_opt_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_OPTModel(): + spawn(check_OPTModel, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_opt_3d(): + spawn(check_opt_3d, 8) + + +if __name__ == "__main__": + test_OPTModel() + test_opt_3d() diff --git a/tests/test_shardformer/test_model/test_shard_sam.py b/tests/test_shardformer/test_model/test_shard_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d4cb63522179e4cc75c3b04e23ff302e37d0d6 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_sam.py @@ -0,0 +1,72 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward( + org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn + ) + assert_hf_output_close(org_output, shard_output, ignore_keys=["pred_masks"]) + + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose( + org_loss, shard_loss, atol=1e-5 + ), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # check grad + + sam = org_model + sharded_sam = sharded_model + + # check grad + col_layer_for_check = ["mask_decoder.transformer.layers[0].self_attn.q_proj", "vision_encoder.layers[0].mlp.lin1"] + row_layer_for_check = ["mask_decoder.transformer.layers[0].self_attn.out_proj", "vision_encoder.layers[0].mlp.lin2"] + check_grad(sam, sharded_sam, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False) + check_grad(sam, sharded_sam, row_layer_for_check, atol=1e-3, rtol=1e-3, dim=1, verbose=False) + + +@parameterize("enable_fused_normalization", [True, False]) +@parameterize("enable_tensor_parallelism", [True, False]) +@parameterize("enable_flash_attention", [True, False]) +def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): + sub_model_zoo = model_zoo.get_sub_registry("transformers_sam") + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model( + model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention + ) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +def check_sam(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_sam_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_sam(): + spawn(check_sam, 2) + + +if __name__ == "__main__": + test_sam() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..73f203d1f0235b16b85d86f1b9153ed57a3a2e3d --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -0,0 +1,217 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + t5 = unwrap_model(org_model) + sharded_t5 = unwrap_model(sharded_model) + + row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + row_layer_grads = get_grad_tensors_for_check( + t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0 + ) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ != "T5ForConditionalGeneration": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if test_config["precision"] == "fp32": + atol, rtol = 5e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +@clear_cache_before_run() +def run_t5_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + # skip 4-stage pp test for t5_encoder + if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model": + continue + + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) +def run_t5_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +def check_t5(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_t5_test() + + +def check_t5_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_t5_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_t5(): + spawn(check_t5, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_t5_3d(): + spawn(check_t5_3d, 8) + + +if __name__ == "__main__": + test_t5() + test_t5_3d() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..1c934bd22340661383f9f745d1f49c8c850f16db --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -0,0 +1,206 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + vit_model = unwrap_model(org_model, "ViTModel", "vit") + shard_vit_model = unwrap_model(sharded_model, "ViTModel", "vit") + + # check grad + row_layer_for_check = ["encoder.layer[0].attention.attention.query", "embeddings.patch_embeddings.projection"] + col_layer_for_check = ["encoder.layer[0].attention.output.dense"] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + vit_model, shard_vit_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + vit_model, shard_vit_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "ViTModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight( + vit_model, shard_vit_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +# TODO: num_microbatch size = 2 inf loss +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": False, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_vit_test(test_config): + # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models + + sub_model_zoo = model_zoo.get_sub_registry("transformers_vit") + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + ], +) +def run_vit_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_vit") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +def check_vit(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_vit_test() + + +def check_vit_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_vit_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_vit(): + spawn(check_vit, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_vit_3d(): + spawn(check_vit_3d, 8) + + +if __name__ == "__main__": + test_vit() + test_vit_3d() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..f839bd84ab699415d91fc4cc39c2f0cd919adcc1 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -0,0 +1,220 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + # check forward + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwarp the model + if org_model.__class__.__name__ == "WhisperForConditionalGeneration": + whisper = org_model.model + sharded_whisper = sharded_model.unwrap().model + else: + whisper = org_model + sharded_whisper = sharded_model.unwrap() + + # check grad + if org_model.__class__.__name__ == "WhisperForAudioClassification": + col_layer_for_check = ["encoder.layers[0].self_attn.q_proj"] + row_layer_for_check = ["encoder.layers[0].self_attn.out_proj"] + else: + col_layer_for_check = [ + "encoder.layers[0].self_attn.q_proj", + # 'decoder.layers[0].self_attn.q_proj' + ] + row_layer_for_check = [ + "encoder.layers[0].self_attn.out_proj", + #'decoder.layers[0].self_attn.out_proj' + ] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if test_config["precision"] == "fp32": + atol, rtol = 2e-4, 2e-4 + else: + atol, rtol = 5e-3, 5e-3 + + if stage_manager is None or stage_manager.is_first_stage(): + row_layer_grads = get_grad_tensors_for_check( + whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1 + ) + col_layer_grads = get_grad_tensors_for_check( + whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0 + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 2e-4, 2e-4 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "WhisperModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if test_config["precision"] == "fp32": + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight( + whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + check_weight( + whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +# TODO fix WhisperForConditionalGeneration enable jit fused operato +# TODO(jianghai) fix fp16 +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + }, + # whisper is not supported fp16 for now. + ], +) +def run_whisper_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_whisper") + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if test_config["pp_size"] > 2 and name == "transformers_whisper_for_audio_classification": + continue + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + ], +) +def run_whisper_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_whisper") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +def check_whisper(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_whisper_test() + + +def check_whisper_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_whisper_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_whisper(): + spawn(check_whisper, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_whisper_3d(): + spawn(check_whisper_3d, 8) + + +if __name__ == "__main__": + test_whisper() + test_whisper_3d() diff --git a/tests/test_shardformer/test_shard_utils.py b/tests/test_shardformer/test_shard_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9739fad86d39d6bd2304ad3605f075f4f5fe7874 --- /dev/null +++ b/tests/test_shardformer/test_shard_utils.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + +from colossalai.shardformer.shard.utils import set_tensors_to_none + + +class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.layers = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) + self.out = nn.Linear(3, 1) + + +def test_release_layer(): + orig_cuda_allocated = torch.cuda.memory_allocated() + model = Net().cuda() + set_tensors_to_none(model, exclude={model.layers[0]}) + assert model.layers[1].weight is None + assert model.layers[1].bias is None + assert model.out.weight is None + assert model.out.bias is None + set_tensors_to_none(model) + assert model.layers[0].weight is None + assert model.layers[0].bias is None + assert len(list(model.parameters())) == 0 + assert torch.cuda.memory_allocated() == orig_cuda_allocated diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..f642a9dcada45f60dec383f528dba2567dc07d84 --- /dev/null +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -0,0 +1,86 @@ +from contextlib import nullcontext + +import pytest +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +@parameterize("lazy_init", [True, False]) +def check_shardformer_with_ddp(lazy_init: bool): + sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt") + + # create shardformer + # ranks: [0, 1, 2, 3] + # tp ranks = [0, 1], [2, 3] + # dp ranks = [0, 2], [1, 3] + dp_process_group_1 = dist.new_group([0, 2]) + dp_process_group_2 = dist.new_group([1, 3]) + tp_process_group_1 = dist.new_group([0, 1]) + tp_process_group_2 = dist.new_group([2, 3]) + + coordinator = DistCoordinator() + + if coordinator.rank in [0, 1]: + tp_process_group = tp_process_group_1 + else: + tp_process_group = tp_process_group_2 + + if coordinator.rank in [0, 2]: + dp_process_group = dp_process_group_1 + else: + dp_process_group = dp_process_group_2 + + shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True) + shardformer = ShardFormer(shard_config=shard_config) + + ctx = LazyInitContext() if lazy_init else nullcontext() + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + # create and shard model + with ctx: + model = model_fn().cuda() + sharded_model, _ = shardformer.optimize(model) + + # add ddp + sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group) + + # prepare input + data = data_gen_fn() + data = {k: v.cuda() for k, v in data.items()} + + # switch to train mode + sharded_ddp_model.train() + + # run forward + output = sharded_ddp_model(**data) + loss = loss_fn(output) + + # backward + loss.backward() + torch.cuda.empty_cache() + + +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + check_shardformer_with_ddp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gpt2(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_gpt2() diff --git a/tests/test_tensor/common_utils/__init__.py b/tests/test_tensor/common_utils/__init__.py deleted file mode 100644 index 5387db70445ffec90f2ffaafbc7bf41368ce6dc3..0000000000000000000000000000000000000000 --- a/tests/test_tensor/common_utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ._utils import * diff --git a/tests/test_tensor/common_utils/_utils.py b/tests/test_tensor/common_utils/_utils.py deleted file mode 100644 index b405f8cd2108368b336f432a9ddc41b061c9f0c5..0000000000000000000000000000000000000000 --- a/tests/test_tensor/common_utils/_utils.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import random - -import numpy as np -import torch -import torch.distributed as dist -from torch.testing import assert_close - -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.tensor import ComputePattern, ComputeSpec, ShardSpec - - -def set_seed(seed): - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True - - -def replace_parameter_add_grad(layer, weight=None, bias=None): - if weight is not None: - delattr(layer, 'weight') - setattr(layer, 'weight', weight) - layer.weight.requires_grad = True - if bias is not None: - delattr(layer, 'bias') - setattr(layer, 'bias', bias) - layer.bias.requires_grad = True - - -def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0): - dist.broadcast(tensor, src=0) - tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank] - return tensor_chunk.clone() - - -def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1): - assert_close(t_a, t_b, rtol=rtol, atol=atol) - return True - - -def tensor_shard_equal(tensor: torch.Tensor, - shard: torch.Tensor, - rank: int, - world_size: int, - rtol: float = 1e-3, - atol: float = 1e-1): - assert tensor.ndim == shard.ndim - if tensor.shape == shard.shape: - return tensor_equal(tensor, shard, rtol, atol) - else: - dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)) - if dims_not_eq.numel() == 1: - # 1D shard - dim = dims_not_eq.item() - if world_size is None: - world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - if rank is None: - rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol) - else: - raise NotImplementedError - - -def split_param_single_dim_tp1d(dim, param, pg): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - if param.process_group.tp_world_size() == 1: - param.set_process_group(pg) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param, pg): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param, pg): - split_param_single_dim_tp1d(-1, param, pg) - - -def debug_print(ranks, *args): - if dist.get_rank() in ranks: - print(*args) - dist.barrier() diff --git a/tests/test_tensor/core/test_tensor.py b/tests/test_tensor/core/test_tensor.py deleted file mode 100644 index 64d198b350a82fd3d5722cdeed81fc9d283e8867..0000000000000000000000000000000000000000 --- a/tests/test_tensor/core/test_tensor.py +++ /dev/null @@ -1,153 +0,0 @@ -import pytest -import torch -from numpy import allclose - -import colossalai -from colossalai.core import global_context as gpc -from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def _run_tensor_indexing(): - pg = ProcessGroup() - torch_t = torch.randn(2, 3) - colo_t = ColoTensor(torch_t, ColoTensorSpec(pg)) - assert allclose(torch_t[:, 1], colo_t[:, 1]) - - -def _run_wrapped_tensor_func(): - pg = ProcessGroup() - t_ref = torch.randn(4, 5) - t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg)) - - # non-func attr - assert t.is_cuda == t_ref.is_cuda - - # return 1 torch.Tensor - t_abs = t.abs() - assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs()) - - # return 1 non-torch.Tensor - assert t.dim() == t_ref.dim() - - # return >1 torch.Tensor - assert isinstance(t, ColoTensor) - t_split1, t_split2 = t.split(2) - assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}" - - -def _run_operand(world_size): - pg = ProcessGroup() - t_ref = torch.randn(4, 5) - t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg)) - - t_ref_res = t_ref + t_ref - t_res = t + t - - assert isinstance(t_res, ColoTensor) - assert torch.allclose(t_ref_res, t_res) - - pg = ProcessGroup(tp_degree=world_size) - t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg)) - t.set_dist_spec(ShardSpec([0], [world_size])) - t_new = torch.zeros_like(t) - assert isinstance(t_new, ColoTensor) - assert t_new.is_sharded() - - -#### Test Distributed init a Colotensor - - -def _run_view(world_size): - t_ref = torch.randn(4, 5) - rank = gpc.get_global_rank() - pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size) - t = ColoTensor.from_torch_tensor( - t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()]))) - - assert t.size_global()[0] == 4 * world_size - assert t.size_global(1) == 5 - assert t.size_global() == torch.Size([4 * world_size, 5]) - - t = t.view(4 * 5 * world_size) - assert t.shape == torch.Size([4 * 5 * world_size]) - - -def _run_tensor_shard_init(world_size): - t_ref = torch.randn(4, 5) - pg = ProcessGroup(tp_degree=world_size) - shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()]) - tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr) - t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) - t.set_dist_spec(ReplicaSpec()) - - assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})" - - -def _run_tensor_replicated_init(world_size): - t_ref = torch.randn(4 * world_size, 5) - pg = ProcessGroup() - spec = ColoTensorSpec(pg) - t = ColoTensor.from_torch_tensor(t_ref.clone(), spec) - - assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}" - - -def _run_process_group(world_size): - pg1 = ProcessGroup() - pg2 = ProcessGroup() - assert pg1 == pg2 - - -def _run_redistributed(world_size): - if world_size != 4: - return - pg1 = ProcessGroup(tp_degree=2, dp_degree=2) - pg2 = ProcessGroup(tp_degree=4, dp_degree=1) - - spec1 = ColoTensorSpec(pg1) - t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1) - t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()])) - assert t1.is_sharded() - t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2) - assert t1.is_sharded() - pg3 = ProcessGroup(tp_degree=1, dp_degree=4) - t1 = t1.redistribute(ReplicaSpec(), pg3) - assert t1.is_replicate() - - -def _run_set_tensor_spec(world_size): - if world_size != 4: - return - pg = ProcessGroup(tp_degree=2, dp_degree=2) - spec1 = ColoTensorSpec(pg) - t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1) - - dist_spec2 = ShardSpec([-1], [pg.tp_world_size()]) - assert t1.is_replicate() - t1.set_dist_spec(dist_spec2) - assert t1.is_shard_1dcol() - - -def run_dist_tests(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_tensor_shard_init(world_size) - _run_tensor_replicated_init(world_size) - _run_view(world_size) - _run_process_group(world_size) - _run_tensor_indexing() - _run_operand(world_size) - _run_wrapped_tensor_func() - _run_redistributed(world_size) - _run_set_tensor_spec(world_size) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@rerun_if_address_is_in_use() -def test_dist_cases(world_size): - spawn(run_dist_tests, world_size) - - -if __name__ == '__main__': - test_dist_cases(4) diff --git a/tests/test_tensor/model/test_gpt2.py b/tests/test_tensor/model/test_gpt2.py deleted file mode 100644 index 337bfa840d5da1866a3f822803ace857bdefc4b9..0000000000000000000000000000000000000000 --- a/tests/test_tensor/model/test_gpt2.py +++ /dev/null @@ -1,148 +0,0 @@ -import pytest -import torch -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.nn.parallel.data_parallel import ColoDDP -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import ( - debug_print, - set_seed, - split_param_col_tp1d, - split_param_row_tp1d, - tensor_equal, - tensor_shard_equal, -) - - -def init_1d_row_spec(model, pg: ProcessGroup): - tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - for n, p in model.named_parameters(): - p.set_process_group(pg) - if 'weight' in n and 'ln' not in n: - p.set_tensor_spec(*tensor_spec) - - -def init_1d_col_spec(model, pg: ProcessGroup): - spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - - for n, p in model.named_parameters(): - p.set_process_group(pg) - if 'ln' not in n and ('weight' in n or 'bias' in n): - p.set_tensor_spec(*spec) - - -def init_megatron_spec(model, pg: ProcessGroup): - for mn, module in model.named_modules(): - # debug_print([0], mn) - for pn, param in module.named_parameters(recurse=False): - # debug_print([0], '\t', pn, param.compute_spec, param.shape) - param.set_process_group(pg) - - if 'mlp.c_fc' in mn: - if 'weight' in pn or 'bias' in pn: - split_param_col_tp1d(param, pg) - param.compute_spec.set_output_replicate(False) - else: - raise RuntimeError - elif 'mlp.c_proj' in mn: - if 'weight' in pn: - split_param_row_tp1d(param, pg) - else: - assert 'bias' in pn - elif 'wte' in mn or 'wpe' in mn: - assert 'weight' in pn - split_param_col_tp1d(param, pg) - elif 'c_attn' in mn or 'c_proj' in mn: - split_param_col_tp1d(param, pg) - # debug_print([0], '\t', param.compute_spec, param.shape) - - -def check_param_equal(model, torch_model, pg: ProcessGroup): - for p, torch_p in zip(model.parameters(), torch_model.parameters()): - assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1" - assert pg.tp_world_size() is not None - assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size()) - - -def check_grad_equal(model, torch_model, pg: ProcessGroup): - for p, torch_p in zip(model.parameters(), torch_model.parameters()): - assert tensor_shard_equal(torch_p.grad, p.grad, pg.tp_local_rank(), pg.tp_world_size()) - - -def run_gpt(init_spec_func, use_ddp): - world_size = torch.distributed.get_world_size() - - # build a PG with TP and DP hybrid - pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1)) - - # set seed make processes of the same tp group use the same seed - # set_seed(pg.tp_local_rank()) - - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - # make sure torch_model and model has the same parameter values - with ColoInitContext(device=get_current_device()): - model = model_builder() - model = model.cuda() - torch_model = model_builder().cuda() - - if use_ddp: - torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) - model = ColoDDP(model, process_group=pg) - - for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p) - - init_spec_func(model, pg) - - check_param_equal(model, torch_model, pg) - - # close the dropout in eval mode - model.eval() - torch_model.eval() - set_seed(pg.dp_local_rank()) - torch.distributed.barrier() - for i, (input_ids, label) in enumerate(train_dataloader): - colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) - logits = model(colo_input) - torch_logits = torch_model(input_ids) - assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}" - loss = criterion(logits, input_ids) - torch_loss = criterion(torch_logits, input_ids) - if use_ddp: - model.backward(loss) - else: - loss.backward() - torch_loss.backward() - check_grad_equal(model, torch_model, pg) - if i > 0: - break - set_seed(313) - - -def run_dist(rank, world_size, port, use_ddp): - if use_ddp and world_size == 1: - return - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - # Comments below tests for speed concern - # run_gpt(init_1d_row_spec, use_ddp) - # run_gpt(init_1d_col_spec, use_ddp) - run_gpt(init_megatron_spec, use_ddp) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@pytest.mark.parametrize('use_ddp', [False, True]) -@rerun_if_address_is_in_use() -def test_gpt(world_size, use_ddp): - spawn(run_dist, world_size, use_ddp=use_ddp) - - -if __name__ == '__main__': - test_gpt(4, use_ddp=False) diff --git a/tests/test_tensor/model/test_model.py b/tests/test_tensor/model/test_model.py deleted file mode 100644 index 79d70e53c5cb133e31c1741799b2eb3b0dc424b8..0000000000000000000000000000000000000000 --- a/tests/test_tensor/model/test_model.py +++ /dev/null @@ -1,334 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.tensor import ColoTensor, ProcessGroup -from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import ( - check_equal, - set_seed, - split_param_col_tp1d, - split_param_row_tp1d, - tensor_shard_equal, -) - - -def run_1d_hybrid_tp(model_name): - # A simple net with two stacked nn.Linear - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - - set_seed(1) - with ColoInitContext(device=get_current_device()): - model = model_builder(checkpoint=True) - - if rank == 0: - model_torch = model_builder(checkpoint=True) - model_torch = model_torch.cuda() - - optimizer_torch = ColossalaiOptimizer(torch.optim.SGD(model_torch.parameters(), lr=0.1)) - - # Make two models have the same init params - for p1, p2 in zip(model.parameters(), model_torch.parameters()): - p2.data.copy_(p1.data) - else: - model_torch = None - optimizer_torch = None - - pg = ProcessGroup(tp_degree=world_size) - if 'bert' == model_name: - for name, p in model.named_parameters(): - if not isinstance(p, ColoTensor): - continue - - # num_class = type_vocab_size = 2 | (8, 2) - if 'classifier' in name and 'weight' in name: - split_param_col_tp1d(p, pg) - # num_class = vocab_size = 30524 | (30524, 8) - elif 'word_embeddings' in name and 'weight' in name: - split_param_row_tp1d(p, pg) - # num_class = seq_len = 512 | (512, 8) - elif 'position_embeddings' in name and 'weight' in name: - split_param_row_tp1d(p, pg) - # num_class = type_vocab_size = 2 | (2, 8) - elif 'token_type_embeddings' in name and 'weight' in name: - split_param_col_tp1d(p, pg) - - elif "simple_net" == model_name: - # A naive way to set spec for all weights in Linear - for name, p in model.named_parameters(): - if not isinstance(p, ColoTensor): - continue - if 'embed' in name and 'weight' in name: - split_param_col_tp1d(p, pg) - if 'proj1' in name and ('weight' in name or 'bias' in name): - split_param_row_tp1d(p, pg) - if 'proj2' in name and 'weight' in name: - split_param_col_tp1d(p, pg) - if 'classifier' in name and ('weight' in name or 'bias' in name): - split_param_row_tp1d(p, pg) - - model = model.cuda() - model.eval() - if rank == 0: - model_torch.eval() - - colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1)) - - for i, (data, label) in enumerate(train_dataloader): - - # Zero grad - colo_optimizer.zero_grad() - if rank == 0: - optimizer_torch.zero_grad() - torch.distributed.barrier() - - data = data.to(get_current_device()) - label = label.to(get_current_device()) - - torch.distributed.broadcast(data, 0, group=pg.tp_process_group()) - torch.distributed.broadcast(label, 0, group=pg.tp_process_group()) - - # Bcast rank0 data to all processes - if criterion: - output = model(data) - loss = criterion(output, label) - else: - output = model(data, label) - loss = output - - # Test output - if rank == 0: - if criterion: - output_torch = model_torch(data) - loss_torch = criterion(output_torch, label) - else: - output_torch = model_torch(data, label) - loss_torch = output_torch - assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed" - torch.distributed.barrier() - - loss.backward() - colo_optimizer.step() - - if rank == 0: - loss_torch.backward() - optimizer_torch.step() - - with torch.no_grad(): - # check param - for p, torch_p in zip(model.parameters(), model_torch.parameters()): - assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size()) - torch.distributed.barrier() - if i > 5: - break - - -# Test the overrided parameters() and named_parameters() member functions -def test_model_parameters(): - colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') - - # build a module with 2 Linear, 4 parameters in total. - class Net(torch.nn.Module): - - def __init__(self): - super().__init__() - self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2)) - self.extra_param = torch.nn.Parameter(torch.randn(2)) - - with ColoInitContext(device=get_current_device()): - model = Net() - - param_cnt = 0 - for name, p in model.named_parameters(): - param_cnt += 1 - assert param_cnt == 5 - - for name, colo_p in model.named_parameters(): - assert colo_p.is_model_data() - - param_cnt = 0 - for name, p in model.named_parameters(recurse=False): - param_cnt += 1 - assert param_cnt == 1 - - param_cnt = 0 - for p in model.fcs[0].parameters(recurse=False): - param_cnt += 1 - assert param_cnt == 2 - - -def test_colo_optimizer(): - get_components_func = non_distributed_component_funcs.get_callable('simple_net') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - set_seed(1) - with ColoInitContext(device=get_current_device()): - model = model_builder(checkpoint=True) - - colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1)) - for i, (data, label) in enumerate(train_dataloader): - colo_optimizer.zero_grad() - data = data.to(get_current_device()) - label = label.to(get_current_device()) - - # Bcast rank0 data to all processes - if criterion: - output = model(data) - loss = criterion(output, label) - else: - output = model(data, label) - loss = output - - loss.backward() - colo_optimizer.step() - - if i > 5: - break - - -def run_1d_row_tp(model_name: str): - # A simple net with two stacked nn.Linear - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - rank = torch.distributed.get_rank() - - set_seed(1) - with ColoInitContext(device=get_current_device()): - model = model_builder(checkpoint=True) - - world_size = torch.distributed.get_world_size() - pg = ProcessGroup(tp_degree=world_size) - - set_seed(1) - if rank == 0: - model_torch = model_builder(checkpoint=True) - model_torch = model_torch.cuda() - - # A naive way to set spec for all weights in Linear - for mo_name, module in model.named_modules(): - # print(mo_name) - for pa_name, param in module.named_parameters(recurse=False): - # print('\t', pa_name, param.shape) - if not isinstance(param, ColoTensor): - continue - if 'weight' in pa_name: - if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name: - split_param_row_tp1d(param, pg) - elif 'LayerNorm' not in mo_name and 'ln' not in mo_name: - split_param_col_tp1d(param, pg) - - model = model.cuda() - - for i, (data, label) in enumerate(train_dataloader): - data = data.to(get_current_device()) - label = label.to(get_current_device()) - - torch.distributed.broadcast(data, 0, group=pg.tp_process_group()) - torch.distributed.broadcast(label, 0, group=pg.tp_process_group()) - - # Bcast rank0 data to all processes - if criterion: - output = model(data) - loss = criterion(output, label) - else: - output = model(data, label) - loss = output - - # For reference - if rank == 0: - if criterion: - output_torch = model_torch(data) - loss_torch = criterion(output_torch, label) - else: - output_torch = model_torch(data, label) - loss_torch = output_torch - assert torch.allclose(loss, loss_torch, rtol=1e-2) - torch.distributed.barrier() - - loss.backward() - - if rank == 0: - loss_torch.backward() - torch.distributed.barrier() - - if i > 5: - break - - -def _run_pretrain_load(): - from transformers import BertForMaskedLM - set_seed(1) - model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased') - with ColoInitContext(device=get_current_device()): - model = BertForMaskedLM.from_pretrained('bert-base-uncased') - - model_pretrained = model_pretrained.cuda() - model = model.cuda() - - dict_pretrained = {} - dict_col = {} - c_ref = 0 - for name, param in model_pretrained.named_parameters(): - dict_pretrained[name] = param - c_ref += 1 - c1 = 0 - c2 = 0 - for name, param in model.named_parameters(): - if isinstance(param, ColoParameter): - c1 += 1 - else: - c2 += 1 - dict_col[name] = param - assert c_ref == c1 - assert c2 == 0 - if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias: - assert model.cls.predictions.decoder.bias is model.cls.predictions.bias - - for name, param in dict_pretrained.items(): - check_equal(param, dict_col[name]) - - -def run_model_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - # Comment below test for speed consideration - # for name in ['bert', 'simple_net']: - # run_1d_row_tp(name) - for name in ['bert', 'simple_net']: - run_1d_hybrid_tp(name) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_model(world_size): - spawn(run_model_dist, world_size) - - -def run_pretrain_load_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_pretrain_load() - - -# The test case has to download huggingface pretrained models from the internet -# So we manually trigger the test. -@pytest.mark.skip -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_pretrain_load(world_size): - spawn(run_pretrain_load_dist, world_size) - - -if __name__ == '__main__': - # test_model_parameters() - # test_colo_optgimizer() - test_model(4) - # test_pretrain_load(4) diff --git a/tests/test_tensor/model/test_module_spec.py b/tests/test_tensor/model/test_module_spec.py deleted file mode 100644 index b50851e5eaf2f04cd5c7480a0fa2dc4ff02c4876..0000000000000000000000000000000000000000 --- a/tests/test_tensor/model/test_module_spec.py +++ /dev/null @@ -1,227 +0,0 @@ -from copy import deepcopy - -import pytest -import torch - -import colossalai -from colossalai.nn.parallel.layers import check_colo_module, init_colo_module -from colossalai.tensor import ( - ColoTensor, - ColoTensorSpec, - ComputePattern, - ComputeSpec, - ProcessGroup, - ReplicaSpec, - ShardSpec, - distspec, -) -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal - - -def run_model_with_spec(mode, model_name): - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - world_size = torch.distributed.get_world_size() - pg = ProcessGroup(tp_degree=world_size) - rank = pg.rank() - - set_seed(1) - with ColoInitContext(device=get_current_device()): - model = model_builder(checkpoint=False) - - if rank == 0: - model_seq = model_builder(checkpoint=False) - model_seq = model_seq.cuda() - - # Make two models have the same init params - for p1, p2 in zip(model.parameters(), model_seq.parameters()): - p2.data.copy_(p1.data) - - compute_spec = ComputeSpec(ComputePattern.TP1D) - # Not all layers in Bert can be mod by 4. - # e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2. - if 'bert' == model_name: - if 'col' == mode: - init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode) - init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode) - init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row') - elif 'row' == mode: - init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col') - init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode) - init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode) - elif 'simple_net' == model_name: - init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode) - - model = model.cuda() - for i, (data, label) in enumerate(train_dataloader): - data = data.to(get_current_device()) - label = label.to(get_current_device()) - - torch.distributed.broadcast(data, 0, group=pg.tp_process_group()) - torch.distributed.broadcast(label, 0, group=pg.tp_process_group()) - - if criterion: - output = model(data) - loss = criterion(output, label) - else: - output = model(data, label) - loss = output - - # For reference - if rank == 0: - if criterion: - output_seq = model_seq(data) - loss_seq = criterion(output_seq, label) - else: - output_seq = model_seq(data, label) - loss_seq = output_seq - - if rank == 0: - with torch.no_grad(): - assert torch.allclose(loss, loss_seq, rtol=1e-2) - - loss.backward() - - if rank == 0: - loss_seq.backward() - - with torch.no_grad(): - # check param - for p1, p2 in zip(model.parameters(), model_seq.parameters()): - if p1.size() == p2.size(): - assert torch.allclose(p1, p2) - else: - if p1.size(-1) < p2.size(-1): # col - world_size = p2.size(-1) // p1.size(-1) - split_p2 = torch.chunk(p2, world_size, dim=-1)[0] - - elif p1.size(0) < p2.size(0): # row - world_size = p2.size(0) // p1.size(0) - split_p2 = torch.chunk(p2, world_size, dim=0)[0] - - assert torch.allclose(p1, split_p2) - - if i > 3: - break - - -def run_linear_with_spec(mode): - with ColoInitContext(device=get_current_device()): - model = torch.nn.Linear(4, 8) - - model_handy = deepcopy(model) - world_size = torch.distributed.get_world_size() - pg = ProcessGroup(tp_degree=world_size) - compute_spec = ComputeSpec(ComputePattern.TP1D) - init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode) - - x = torch.rand(2, 4).cuda() - colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg)) - - out = model(x) - colo_out = model_handy(colo_x) - assert tensor_equal(out, colo_out) - - grad = torch.rand_like(out) - out.backward(grad) - colo_out.backward(grad) - - assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size()) - assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size()) - - -def run_check_shared_param(): - from transformers import BertConfig, BertForMaskedLM - hidden_dim = 8 - num_head = 4 - sequence_length = 12 - num_layer = 2 - vocab_size = 24 - - world_size = torch.distributed.get_world_size() - pg = ProcessGroup(tp_degree=world_size) - rank = pg.rank() - - config = BertConfig(vocab_size=vocab_size, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - hidden_dropout_prob=0., - attention_probs_dropout_prob=0.) - with ColoInitContext(device=get_current_device()): - model = BertForMaskedLM(config) - - model = model.cuda() - compute_spec = ComputeSpec(ComputePattern.TP1D) - # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec - assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2 - # They are all Linear, so both row is allowed. This should pass check. - init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row') - # This should be detected by check because you can not set weight as row while set bias as col. - col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - - # TODO(jiaruifang) optimize this line - if not model.cls.predictions.bias.has_initialized: - model.cls.predictions.bias.pg = pg - model.cls.predictions.bias.dist_spec = ReplicaSpec() - model.cls.predictions.bias.has_initialized = True - model.cls.predictions.bias.set_tensor_spec(*col_spec) - try: - check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False) - except Exception as e: - assert 'incorrectly sharded' in str(e) - - -def run_dist(rank, world_size, port): - config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_linear_with_spec('col') - run_linear_with_spec('row') - - -def run_dist_model(rank, world_size, port): - config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - for model_name in ['simple_net', 'bert']: - run_model_with_spec('col', model_name) - run_model_with_spec('row', model_name) - - -def run_dist_check(rank, world_size, port): - config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_check_shared_param() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@pytest.mark.skip("for higher testing speed") -@rerun_if_address_is_in_use() -def test_module_linear_1d(world_size): - spawn(run_dist, world_size) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@pytest.mark.skip("for higher testing speed") -@rerun_if_address_is_in_use() -def test_module_model(world_size): - spawn(run_dist_model, world_size) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@pytest.mark.skip("for higher testing speed") -@rerun_if_address_is_in_use() -def test_module_check(world_size): - spawn(run_dist_check, world_size) - - -if __name__ == '__main__': - test_module_linear_1d(4) diff --git a/tests/test_tensor/test_colo_checkpoint_tools.py b/tests/test_tensor/test_colo_checkpoint_tools.py deleted file mode 100644 index a53a3f37a664134e9b0f379ff3ad7f8b5abd5b16..0000000000000000000000000000000000000000 --- a/tests/test_tensor/test_colo_checkpoint_tools.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -import torch -import torch.distributed as dist - -import colossalai -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor -from tests.test_tensor.common_utils import tensor_shard_equal - - -def run_dist(rank, world_size, port, dp_degree, tp_degree): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree) - x = torch.randn(4, 4) - param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg)) - spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D) - param.set_tensor_spec(*spec) - - gather_tensor(param) - if dist.get_rank() == 0: - assert torch.all(x == param) - else: - assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) - dist.barrier() - - scatter_tensor(param, spec[0]) - assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) - assert param.requires_grad is True - dist.barrier() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [4]) -@rerun_if_address_is_in_use() -def test_checkpoint(world_size): - spawn(run_dist, world_size, dp_degree=2, tp_degree=world_size // 2) - - -if __name__ == '__main__': - test_checkpoint(world_size=4) diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py index 2c68633aabc819706e35f69dc4ebfbb521dc1972..5e969b1aaf9802bf8276835bd8d142ff0cbdb7b4 100644 --- a/tests/test_tensor/test_comm_spec_apply.py +++ b/tests/test_tensor/test_comm_spec_apply.py @@ -1,7 +1,7 @@ import pytest import torch +import torch.distributed as dist -from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -29,10 +29,9 @@ def check_all_gather(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1 + ) sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm) assert sharded_tensor_to_comm.equal(tensor_to_check) @@ -101,11 +100,9 @@ def check_all_to_all(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, - sharding_spec, - gather_dim=0, - shard_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, sharding_spec, gather_dim=0, shard_dim=1, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -181,10 +178,10 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank): def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) - assert rank == gpc.get_global_rank() + assert rank == dist.get_rank() mesh_shape = (2, 2) # [[0, 1, @@ -205,7 +202,6 @@ def check_comm(rank, world_size, port): # test all reduce in 1D flatten device mesh check_all_reduce_in_flatten_device_mesh(device_mesh, rank) - gpc.destroy() @pytest.mark.dist @@ -215,5 +211,5 @@ def test_comm_spec(): spawn(check_comm, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_comm_spec() diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py deleted file mode 100644 index 45def034ba8e6f592dfe294d16561558d2139c3b..0000000000000000000000000000000000000000 --- a/tests/test_tensor/test_context.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.tensor import ( - ColoParameter, - ColoTensorSpec, - ComputePattern, - ComputeSpec, - ProcessGroup, - ReplicaSpec, - ShardSpec, -) -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed - - -def run_colo_init_context(rank: int, world_size: int, port: int): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - # make sure seed of each process is the same, so the params are consistent among processes and the params are exactly replicated. - set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - # keep parameters replicated during init - with ColoInitContext(device=get_current_device()): - model1 = model_builder() - - # shard the parameters during init - set_seed(42) - shard_spec = ReplicaSpec() - - # If using ShardSpec, the assertations will failed. - # But it is not a bug, the initialized values are not consist with the original one. - # shard_spec = ShardSpec(dims=[0], num_partitions=[world_size]) - default_pg = ProcessGroup(tp_degree=world_size) - with ColoInitContext(device=get_current_device(), default_pg=default_pg, default_dist_spec=shard_spec): - model2 = model_builder() - - # reshard both models - new_shard = ShardSpec(dims=[-1], num_partitions=[world_size]) - for p1, p2 in zip(model1.parameters(), model2.parameters()): - p1: ColoParameter = p1 - p1.set_process_group(ProcessGroup(tp_degree=world_size)) - p1.set_dist_spec(new_shard) - p2.set_dist_spec(new_shard) - - for p1, p2 in zip(model1.parameters(), model2.parameters()): - assert (torch.allclose(p1, p2)) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_colo_init_context(world_size): - spawn(run_colo_init_context, world_size) - - -if __name__ == '__main__': - test_colo_init_context(2) diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index d1f5b9299397f0b63e6c66e072878bdcf0059afd..6d1640b4f3dce060efbc9172a736da74664d80b8 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -1,14 +1,11 @@ import pytest import torch import torch.distributed as dist -from torch.distributed import ReduceOp -from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -23,10 +20,9 @@ def check_all_gather(process_groups_dict, rank): tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda() # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - process_groups_dict, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, process_groups_dict, gather_dim=1, logical_process_axis=1 + ) sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm) assert sharded_tensor_to_comm.equal(tensor_to_check) @@ -41,10 +37,9 @@ def check_shard(process_groups_dict, rank): tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1) # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, - process_groups_dict, - shard_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, process_groups_dict, shard_dim=1, logical_process_axis=1 + ) tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard) if rank in (0, 2): @@ -82,11 +77,13 @@ def check_all_to_all(process_groups_dict, rank): tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda() # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, - process_groups_dict, - gather_dim=0, - shard_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, + process_groups_dict, + gather_dim=0, + shard_dim=1, + logical_process_axis=0, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -125,53 +122,32 @@ def check_all_reduce_bwd(process_groups_dict, rank): assert tensor_to_comm.equal(tensor_to_check) -def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank): - # tensor to comm - tensor_to_comm = torch.ones(2, 2).cuda() * rank - - # reduce through logical process axis 0 at flatten device mesh - # tensor to check - # tensor([[6., 6.], - # [6., 6.]]) - tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda() - - # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1]) - comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0) - tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) - - assert tensor_to_comm.equal(tensor_to_check) - - def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) - assert rank == gpc.get_global_rank() + assert rank == dist.get_rank() mesh_shape = (2, 2) # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - process_groups_dict = device_mesh.process_groups_dict + + process_group_dict = device_mesh._process_group_dict[rank] # test all gather - check_all_gather(process_groups_dict, rank) + check_all_gather(process_group_dict, rank) # test shard - check_shard(process_groups_dict, rank) + check_shard(process_group_dict, rank) # test all to all - check_all_to_all(process_groups_dict, rank) + check_all_to_all(process_group_dict, rank) # test all reduce - check_all_reduce_fwd(process_groups_dict, rank) - check_all_reduce_bwd(process_groups_dict, rank) - - flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict - # test all reduce in 1D flatten device mesh - check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank) - gpc.destroy() + check_all_reduce_fwd(process_group_dict, rank) + check_all_reduce_bwd(process_group_dict, rank) @pytest.mark.dist @@ -181,5 +157,5 @@ def test_comm_spec(): spawn(check_comm, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_comm_spec() diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 3ca369acbf87c6cc9785c9304fa3ed06bdae9661..33ae59d015507f53db6de1d6d50bb0d7abd3c07f 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -3,14 +3,11 @@ import torch from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor -from colossalai.tensor.d_tensor.layout import Layout -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec +from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, get_global_shape, redistribute, to_global from colossalai.testing import rerun_if_address_is_in_use, spawn class TestModel(torch.nn.Module): - def __init__(self, in_features, out_features): super().__init__() self.linear_1 = torch.nn.Linear(in_features, out_features) @@ -24,29 +21,25 @@ class TestModel(torch.nn.Module): def check_dtensor(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - test_model = TestModel(8, 8).to('cuda') - original_tensor = torch.rand(4, 8).to('cuda') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + test_model = TestModel(8, 8).to("cuda") + original_tensor = torch.rand(4, 8).to("cuda") compare_output = test_model(original_tensor) device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]}) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=target_sharding_spec, - entire_shape=original_tensor.shape) - d_tensor = DTensor(original_tensor, layout) + d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec) - assert d_tensor.entire_shape == original_tensor.shape - assert d_tensor.data_type == original_tensor.dtype + assert get_global_shape(d_tensor) == original_tensor.shape + assert d_tensor.dtype == original_tensor.dtype if rank in (0, 1): - assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2)) + assert d_tensor.equal(original_tensor.narrow(0, 0, 2)) elif rank in (2, 3): - assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2)) + assert d_tensor.equal(original_tensor.narrow(0, 2, 2)) else: - raise ValueError(f'rank {rank} is not in the device mesh') - assert d_tensor.to_global().equal(original_tensor) + raise ValueError(f"rank {rank} is not in the device mesh") + assert to_global(d_tensor).equal(original_tensor) output = test_model(d_tensor) if rank in (0, 1): @@ -54,39 +47,34 @@ def check_dtensor(rank, world_size, port): elif rank in (2, 3): assert output.equal(compare_output.narrow(0, 2, 2)) else: - raise ValueError(f'rank {rank} is not in the device mesh') + raise ValueError(f"rank {rank} is not in the device mesh") new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]}) - new_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=new_sharding_spec, - entire_shape=original_tensor.shape) - - d_tensor.layout_convert(new_layout) + d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec) if rank == 0: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 0, 1)) elif rank == 1: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 1, 1)) elif rank == 2: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 2, 1)) elif rank == 3: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 3, 1)) else: - raise ValueError(f'rank {rank} is not in the device mesh') + raise ValueError(f"rank {rank} is not in the device mesh") - dtensor_from_local = distribute_tensor(original_tensor, new_layout) + dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec) if rank == 0: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1)) elif rank == 1: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 1, 1)) elif rank == 2: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 2, 1)) elif rank == 3: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1)) else: - raise ValueError(f'rank {rank} is not in the device mesh') + raise ValueError(f"rank {rank} is not in the device mesh") @rerun_if_address_is_in_use() @@ -95,5 +83,5 @@ def test_dtensor(): spawn(check_dtensor, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_dtensor() diff --git a/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py b/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py index 7fd1c3d90fc4c1c4810b8b177c92ec1d514e0d29..654a4438479a9d6bc131f0eac51864537ac9cb3b 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py +++ b/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py @@ -26,9 +26,10 @@ def test_dtensor_sharding_spec(): assert dim_spec_list_0[2].dim_diff(dim_spec_list_1[2]) == 0 assert dim_spec_list_0[3].dim_diff(dim_spec_list_1[3]) == 0 - assert sharding_spec_0.spec_diff(sharding_spec_1) == \ - reduce(operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0) + assert sharding_spec_0.spec_diff(sharding_spec_1) == reduce( + operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0 + ) -if __name__ == '__main__': +if __name__ == "__main__": test_dtensor_sharding_spec() diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 5f56decb5e5dd03a132c7b1ff2df9a64d63d7da9..4e65401bf7b416e9699791091e93df000524de3d 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -9,18 +9,18 @@ from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.layout_converter import LayoutConverter -from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn -entire_shape = torch.Size((64, 32, 16)) +global_shape = torch.Size((64, 32, 16)) layout_converter = LayoutConverter() -physical_mesh_id = torch.arange(0, 4).reshape(2, 2) +physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) def check_one_step_transform(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # [[0, 1], # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) @@ -30,17 +30,14 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (2, 2) sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec, - entire_shape=entire_shape) + layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) - assert '[R, S1, R]' in [ + assert "[R, S1, R]" in [ str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys() ] - assert '[S0, R, R]' in [ + assert "[S0, R, R]" in [ str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys() ] @@ -49,20 +46,17 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all) - layout_all2all = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_all2all, - entire_shape=entire_shape) + layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape) rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all) - assert '[S01, R, R]' in [ + assert "[S01, R, R]" in [ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() ] - assert '[R, S1, S0]' in [ + assert "[R, S1, S0]" in [ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() ] @@ -71,27 +65,24 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,R,R # device_mesh_shape: (4, 4) sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard) - shard_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_shard, - entire_shape=entire_shape) + shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape) rst_dict_shard = layout_converter.shard_transform_layout(shard_layout) - assert '[S01, R, R]' in [ + assert "[S01, R, R]" in [ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() ] - assert '[S0, S1, R]' in [ + assert "[S0, S1, R]" in [ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() ] def check_layout_converting(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) @@ -100,25 +91,19 @@ def check_layout_converting(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) # check transform path - transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) - assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]' + transform_path_str = "->".join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) + assert transform_path_str == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]" # check comm action sequence # all-gather(S01) -> S0 @@ -138,18 +123,18 @@ def check_layout_converting(rank, world_size, port): assert comm_action_sequence[2].logical_process_axis == 1 # checkout chached_spec_pairs_transform_path - assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path - assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence + assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][0] == transform_path + assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][1] == comm_action_sequence comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout) - assert comm_cost['forward'] == comm_cost['backward'] - assert math.floor(comm_cost['total']) == math.floor(comm_cost['forward'] + comm_cost['backward']) + assert comm_cost["forward"] == comm_cost["backward"] + assert math.floor(comm_cost["total"]) == math.floor(comm_cost["forward"] + comm_cost["backward"]) def check_layout_converting_apply(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} @@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) - original_tensor = torch.rand(entire_shape).cuda() + original_tensor = torch.rand(global_shape).cuda() # tensor_to_apply: [R, S01, R] tensor_to_apply = original_tensor.narrow(1, rank * 8, 8) @@ -194,5 +173,5 @@ def test_layout_converter(): spawn(check_layout_converting_apply, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_layout_converter() diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py index 9122808eb5a3ef841ab7af7344822abf90e7447f..7d6f8979dd0bb0e15b19d1e2472b24d3e6862fd1 100644 --- a/tests/test_tensor/test_mix_gather.py +++ b/tests/test_tensor/test_mix_gather.py @@ -1,7 +1,7 @@ import pytest import torch +import torch.distributed as dist -from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -17,12 +17,13 @@ def check_mix_gather_S0S1(device_mesh, rank): f_target_pair = (f, [0]) b_target_pair = (b, [1]) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - tensor_slice = [4, 2] # (4, 2) + tensor_slice = [4, 2] # (4, 2) rank_slice = 4 f_start = (rank // rank_slice) * tensor_slice[0] b_start = (rank % rank_slice) * tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) dim_partition_dict = {0: [0], 1: [1]} @@ -31,12 +32,14 @@ def check_mix_gather_S0S1(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -48,12 +51,13 @@ def check_two_all_gather_S0S1(device_mesh, rank): dim_partition_dict = {0: [0], 1: [1]} - tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2) + tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2) rank_slice = 4 f_start = (rank // rank_slice) * tensor_slice[0] b_start = (rank % rank_slice) * tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) # DistSpec: # shard_sequence: S0,S1 @@ -61,10 +65,9 @@ def check_two_all_gather_S0S1(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -75,10 +78,9 @@ def check_two_all_gather_S0S1(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -95,8 +97,9 @@ def check_mix_gather_S1S0(device_mesh, rank): rank_slice = 4 f_start = (rank % rank_slice) * tensor_slice[0] b_start = (rank // rank_slice) * tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) dim_partition_dict = {0: [1], 1: [0]} @@ -105,12 +108,14 @@ def check_mix_gather_S1S0(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -120,12 +125,13 @@ def check_two_all_gather_S1S0(device_mesh, rank): tensor_width = 8 tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() - tensor_slice = [tensor_width // 4, tensor_width // 2] # (4, 2) + tensor_slice = [tensor_width // 4, tensor_width // 2] # (4, 2) rank_slice = 4 f_start = (rank % rank_slice) * tensor_slice[0] b_start = (rank // rank_slice) * tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) dim_partition_dict = {0: [1], 1: [0]} @@ -135,10 +141,9 @@ def check_two_all_gather_S1S0(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -149,10 +154,9 @@ def check_two_all_gather_S1S0(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -165,7 +169,7 @@ def check_mix_gather_S01R(device_mesh, rank): f_target_pair = (f, [0, 1]) b_target_pair = (b, []) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - tensor_to_comm = tensor_to_check[rank:rank + 1, :].contiguous().cuda() + tensor_to_comm = tensor_to_check[rank : rank + 1, :].contiguous().cuda() dim_partition_dict = {0: [0, 1]} # DistSpec: @@ -173,12 +177,14 @@ def check_mix_gather_S01R(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -189,7 +195,7 @@ def check_two_all_gather_S01R(device_mesh, rank): tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() rank_stride = tensor_width // 8 - tensor_to_comm = tensor_to_check[rank:rank + rank_stride, :].contiguous().cuda() + tensor_to_comm = tensor_to_check[rank : rank + rank_stride, :].contiguous().cuda() dim_partition_dict = {0: [0, 1]} @@ -199,10 +205,9 @@ def check_two_all_gather_S01R(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -214,10 +219,9 @@ def check_two_all_gather_S01R(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -231,7 +235,7 @@ def check_mix_gather_RS01(device_mesh, rank): f_target_pair = (f, []) b_target_pair = (b, [0, 1]) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - tensor_to_comm = tensor_to_check[:, rank:rank + 1].contiguous().cuda() + tensor_to_comm = tensor_to_check[:, rank : rank + 1].contiguous().cuda() dim_partition_dict = {1: [0, 1]} # DistSpec: @@ -239,12 +243,14 @@ def check_mix_gather_RS01(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -255,7 +261,7 @@ def check_two_all_gather_RS01(device_mesh, rank): tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() rank_stride = tensor_width // 8 - tensor_to_comm = tensor_to_check[:, rank:rank + rank_stride].contiguous().cuda() + tensor_to_comm = tensor_to_check[:, rank : rank + rank_stride].contiguous().cuda() dim_partition_dict = {1: [0, 1]} @@ -265,10 +271,9 @@ def check_two_all_gather_RS01(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -280,10 +285,9 @@ def check_two_all_gather_RS01(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -292,10 +296,10 @@ def check_two_all_gather_RS01(device_mesh, rank): def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 8) - assert rank == gpc.get_global_rank() + assert rank == dist.get_rank() mesh_shape = (2, 4) # [[0, 1, 2, 3], @@ -326,5 +330,5 @@ def test_mix_gather(): spawn(check_comm, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_mix_gather() diff --git a/tests/test_tensor/test_parameter.py b/tests/test_tensor/test_parameter.py deleted file mode 100644 index 9c3f05da1ffa8c1fc62851cadf7b572ef0a6d63d..0000000000000000000000000000000000000000 --- a/tests/test_tensor/test_parameter.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest -import torch -from common_utils import tensor_equal - -import colossalai -from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup -from colossalai.testing import free_port - - -@pytest.mark.skip -def test_multiinheritance(): - colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') - colo_param = ColoParameter(None, requires_grad=True) - assert colo_param.dist_spec.placement.value == 'r' - assert isinstance(colo_param, ColoTensor) - assert isinstance(colo_param, torch.nn.Parameter) - - # __deepcopy__ overload - import copy - colo_param2 = copy.deepcopy(colo_param) - assert isinstance(colo_param2, ColoParameter) - assert tensor_equal(colo_param.data, colo_param2.data) - assert colo_param.requires_grad == colo_param2.requires_grad - - # __repr__ overload - assert 'ColoParameter' in str(colo_param) - - # __torch_function__ - clone_param = torch.clone(colo_param) - assert isinstance(clone_param, ColoTensor) - - -if __name__ == '__main__': - test_multiinheritance() diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index 6fe9ee292cd014ae37d41ec5df4f7fcd6ccb2d99..c51797912e6f1264bea3f00f72cc11e96b49425f 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -1,9 +1,10 @@ -from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern import torch -from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec + from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec -physical_mesh_id = torch.arange(0, 16).reshape(2, 8) +physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], @@ -15,7 +16,6 @@ shape_consistency_manager = ShapeConsistencyManager() def test_one_step_transform(): - dim_partition_dict = {0: [0], 1: [1]} # DistSpec: # shard_sequence: S0,S1,R @@ -27,16 +27,14 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec: # shard_sequence: S0,R,R # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)} - rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, { - "forward": 0, - "backward": 0, - "total": 0 - }) + rst_dict = shape_consistency_manager.get_all_all_gather_spec( + sharding_spec, {"forward": 0, "backward": 0, "total": 0} + ) - assert '[R, S1, R]' in [ + assert "[R, S1, R]" in [ str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys() ] - assert '[S0, R, R]' in [ + assert "[S0, R, R]" in [ str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys() ] @@ -52,19 +50,17 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec: # shard_sequence: S0,R,S1 # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)} - rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, { - "forward": 0, - "backward": 0, - "total": 0 - }) + rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec( + sharding_spec_all2all, {"forward": 0, "backward": 0, "total": 0} + ) - assert '[S01, R, R]' in [ + assert "[S01, R, R]" in [ str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() ] - assert '[R, S1, S0]' in [ + assert "[R, S1, S0]" in [ str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() ] @@ -80,19 +76,17 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec: # shard_sequence: S0,R,S1 # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)} - rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, { - "forward": 0, - "backward": 0, - "total": 0 - }) + rst_dict_shard = shape_consistency_manager.get_all_shard_spec( + sharding_spec_shard, {"forward": 0, "backward": 0, "total": 0} + ) - assert '[S01, R, R]' in [ + assert "[S01, R, R]" in [ str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() ] - assert '[S0, S1, R]' in [ + assert "[S0, S1, R]" in [ str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() ] @@ -112,10 +106,11 @@ def test_shape_consistency(): sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( - sharding_spec_source, sharding_spec_target) + sharding_spec_source, sharding_spec_target + ) - transform_path_str = '->'.join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path]) - assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]' + transform_path_str = "->".join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path]) + assert transform_path_str == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]" # all-gather(S01) -> S0 assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD @@ -133,12 +128,15 @@ def test_shape_consistency(): assert comm_action_sequence[2].shard_dim == 0 assert comm_action_sequence[2].logical_process_axis == 1 - assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', - '[S01, R, R]')][0] == transform_path - assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', - '[S01, R, R]')][1] == comm_action_sequence + assert ( + shape_consistency_manager.cached_spec_pairs_transform_path[("[R, S01, R]", "[S01, R, R]")][0] == transform_path + ) + assert ( + shape_consistency_manager.cached_spec_pairs_transform_path[("[R, S01, R]", "[S01, R, R]")][1] + == comm_action_sequence + ) -if __name__ == '__main__': +if __name__ == "__main__": test_one_step_transform() test_shape_consistency() diff --git a/tests/test_tensor/test_shape_consistency_apply.py b/tests/test_tensor/test_shape_consistency_apply.py index b57952df401fdcbe99299077e7c5618c137045e6..b2bc84edd87fb85d9e067c6396ff72edbe33a570 100644 --- a/tests/test_tensor/test_shape_consistency_apply.py +++ b/tests/test_tensor/test_shape_consistency_apply.py @@ -4,14 +4,14 @@ import torch from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn def check_apply(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -72,5 +72,5 @@ def test_apply(): spawn(check_apply, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_apply() diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py deleted file mode 100644 index d66d4fec14d11142a851535114e277c32c97c3c4..0000000000000000000000000000000000000000 --- a/tests/test_tensor/test_sharded_linear.py +++ /dev/null @@ -1,232 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F - -import colossalai -from colossalai.device.device_mesh import DeviceMesh -from colossalai.nn._ops._utils import gather_forward_split_backward -from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup -from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - # create mlp vars - x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda() - w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda() - b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda() - - # run normal forward - out = F.linear(x, w, b) - - # create mesh meta - # the mesh is in the following topo - # [[0, 1], - # [2, 3]] - physical_mesh_id = torch.arange(0, 4).reshape(2, 2) - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - row_id = rank // 2 - column_id = rank % 2 - - # create pg - row_process_group = None - col_process_group = None - row_to_ranks = {0: [0, 1], 1: [2, 3]} - col_to_ranks = {0: [0, 2], 1: [1, 3]} - - for idx in range(2): - # row ranks - row_ranks = row_to_ranks[idx] - row_pg = ProcessGroup(ranks=row_ranks, tp_degree=2) - - # col ranks - col_ranks = col_to_ranks[idx] - col_pg = ProcessGroup(ranks=col_ranks, tp_degree=2) - - if rank in row_ranks: - row_process_group = row_pg - - if rank in col_ranks: - col_process_group = col_pg - - ######################## - # RRR x RS0 -> RRS0 # - ######################## - # w will be transposed in F.linear - x_replica = x.detach().clone() - w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[row_id] - b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[row_id] - - # adding sharding spec - x_replica.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={}) - w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [0]}) - b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [0]}) - - # check sharding spec - assert str(x_replica.sharding_spec.sharding_sequence) == "[R, R, R]" - assert str(w_shard.sharding_spec.sharding_sequence) == "[S0, R]" - assert str(b_shard.sharding_spec.sharding_sequence) == "[S0]" - - w_shard.pg_axis0 = col_process_group - w_shard.pg_axis1 = row_process_group - - out_shard = F.linear(x_replica, w_shard, b_shard) - assert str(out_shard.sharding_spec.sharding_sequence) == "[R, R, S0]" - - # each row only has a mini-batch - expected_out_shard = torch.chunk(out, chunks=2, dim=2)[row_id] - assert torch.allclose(out_shard, expected_out_shard) - - ######################## - # S0RR x RS1 -> S0RS1 # - ######################## - # w will be transposed in F.linear - x_shard = torch.chunk(x.detach().clone(), chunks=2, dim=0)[row_id] - w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[column_id] - b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[column_id] - - # adding sharding spec - x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0]}) - w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1]}) - b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]}) - - # check sharding spec - assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, R]" - assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, R]" - assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]" - - w_shard.pg_axis0 = col_process_group - w_shard.pg_axis1 = row_process_group - - out_shard = F.linear(x_shard, w_shard, b_shard) - - # each row only has a mini-batch - expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id] - expected_out_shard = torch.chunk(expected_out_shard, chunks=2, dim=2)[column_id] - assert torch.allclose(out_shard, expected_out_shard) - - ######################## - # S0RS1 x S1R -> S0RR # - ######################## - # w will be transposed in F.linear - x_shard = torch.chunk(x.clone(), chunks=2, dim=0)[row_id] - x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id] - w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id] - b_replica = b.clone() - - # adding sharding spec - x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0], 2: [1]}) - w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]}) - b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={}) - - # check sharding spec - assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, S1]" - assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]" - assert str(b_replica.sharding_spec.sharding_sequence) == "[R]" - - w_shard.pg_axis0 = col_process_group - w_shard.pg_axis1 = row_process_group - - out_shard = F.linear(x_shard, w_shard, b_replica) - - # each row only has a mini-batch - expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id] - assert torch.allclose(out_shard, expected_out_shard) - - ######################## - # RRS0 x S0R -> RRR # - ######################## - # w will be transposed in F.linear - x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id] - w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id] - b_replica = b.clone() - - # adding sharding spec - x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]}) - w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [0]}) - b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={}) - - # check sharding spec - assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]" - assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S0]" - assert str(b_replica.sharding_spec.sharding_sequence) == "[R]" - - w_shard.pg_axis0 = col_process_group - w_shard.pg_axis1 = row_process_group - - out_shard = F.linear(x_shard, w_shard, b_replica) - - # each row only has a mini-batch - expected_out_shard = out - assert torch.allclose(out_shard, expected_out_shard) - - ######################## - # RS0S1 x S1R -> RS0R # - ######################## - # w will be transposed in F.linear - x_shard = torch.chunk(x.clone(), chunks=2, dim=1)[row_id] - x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id] - w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id] - b_replica = b.clone() - - # adding sharding spec - x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={1: [0], 2: [1]}) - w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]}) - b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={}) - - # check sharding spec - assert str(x_shard.sharding_spec.sharding_sequence) == "[R, S0, S1]" - assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]" - assert str(b_replica.sharding_spec.sharding_sequence) == "[R]" - - w_shard.pg_axis0 = col_process_group - w_shard.pg_axis1 = row_process_group - - out_shard = F.linear(x_shard, w_shard, b_replica) - - # each row only has a mini-batch - expected_out_shard = torch.chunk(out, chunks=2, dim=1)[row_id] - assert torch.allclose(out_shard, expected_out_shard) - - ######################## - # RRS0 x S0S1 -> RRS1 # - ######################## - # w will be transposed in F.linear - x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id] - w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id] - w_shard = torch.chunk(w_shard, chunks=2, dim=0)[column_id] - b_shard = torch.chunk(b.clone(), chunks=2, dim=0)[column_id] - - # adding sharding spec - x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]}) - w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1], 1: [0]}) - b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]}) - - # check sharding spec - assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]" - assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, S0]" - assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]" - - w_shard.pg_axis0 = col_process_group - w_shard.pg_axis1 = row_process_group - - out_shard = F.linear(x_shard, w_shard, b_shard) - - # each row only has a mini-batch - expected_out_shard = torch.chunk(out, chunks=2, dim=2)[column_id] - assert torch.allclose(out_shard, expected_out_shard) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [4]) -@rerun_if_address_is_in_use() -def test_sharded_mlp(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_sharded_mlp(4) diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py index 909c84ef0f0ebeadcaab09c8fa23dd0a90199843..7730683bf52539fd1689b43adae3b08104c9b9a2 100644 --- a/tests/test_tensor/test_sharding_spec.py +++ b/tests/test_tensor/test_sharding_spec.py @@ -1,11 +1,11 @@ import torch from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec +from colossalai.tensor.sharding_spec import ShardingSpec def test_sharding_spec(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], @@ -21,5 +21,5 @@ def test_sharding_spec(): assert str(sharding_spec.sharding_sequence) == "[S01, R, R]" -if __name__ == '__main__': +if __name__ == "__main__": test_sharding_spec() diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py deleted file mode 100644 index c636d9442902ebd2a85931fb3eaa8b3cc71231f5..0000000000000000000000000000000000000000 --- a/tests/test_tensor/test_tp_with_zero.py +++ /dev/null @@ -1,143 +0,0 @@ -import pytest -import torch -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.amp import convert_to_apex_amp -from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP -from colossalai.zero.gemini import search_chunk_configuration -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed, tensor_shard_equal -from tests.test_tensor.model.test_gpt2 import init_megatron_spec - - -def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup): - zero_dict = model.state_dict(only_rank_0=False) - torch_dict = torch_model.state_dict() - - for key, value in torch_dict.items(): - # key is 'module.model.PARAMETER', so we truncate it - key = key[7:] - assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) - temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) - # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) - assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \ - "parameter '{}' has problem.".format(key) - - -def run_fwd_bwd(model, criterion, optimizer, input_ids): - optimizer.zero_grad() - logits = model(input_ids) - logits = logits.float() - loss = criterion(logits, input_ids) - optimizer.backward(loss) - return logits - - -def init_1d_row_spec(model, pg: ProcessGroup): - spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - for n, p in model.named_parameters(): - p.set_process_group(pg) - if 'weight' in n and 'ln' not in n: - p.set_tensor_spec(*spec) - - -def init_1d_col_spec(model, pg: ProcessGroup): - spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - for n, p in model.named_parameters(): - p.set_process_group(pg) - if 'ln' not in n and ('weight' in n or 'bias' in n): - p.set_tensor_spec(*spec) - - -@parameterize('placement_policy', ['cuda', 'cpu']) -def run_gpt(placement_policy, tp_init_spec_func=None): - set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - model = model.cuda() - torch_model = model_builder().cuda() - - for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p.data) - - world_size = torch.distributed.get_world_size() - - # world size, dp = 2, tp =2, construct a hybrid parallelism. - if world_size == 4: - pg = ProcessGroup(tp_degree=2) - else: - pg = ProcessGroup(tp_degree=world_size) - - if tp_init_spec_func: - tp_init_spec_func(model, pg) - - dp_world_size = pg.dp_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[dp_world_size]['chunk_size'] = 5000 - config_dict[dp_world_size]['keep_gathered'] = False - if placement_policy != 'cuda': - init_device = torch.device('cpu') - else: - init_device = None - - model = GeminiDDP(model, init_device, placement_policy, True, False) - # The same as the following 3 lines - # chunk_manager = ChunkManager(config_dict, init_device=init_device) - # gemini_manager = GeminiManager(placement_policy, chunk_manager) - # model = ZeroDDP(model, gemini_manager, pin_memory=True) - - zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1) - # The same as the following 2 lines - # optimizer = HybridAdam(model.parameters(), lr=1e-3) - # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) - - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) - torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) - torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) - torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) - - check_param(model, torch_model, pg) - - model.eval() - torch_model.eval() - - set_seed(pg.dp_local_rank()) - for i, (input_ids, label) in enumerate(train_dataloader): - if i > 2: - break - input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) - zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo) - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids) - assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) - - zero_optim.step() - torch_optim.step() - check_param(model, torch_model, pg) - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - if world_size == 4: - run_gpt(tp_init_spec_func=init_megatron_spec) - else: - run_gpt(tp_init_spec_func=init_1d_col_spec) - run_gpt(tp_init_spec_func=init_1d_row_spec) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_gpt(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_gpt(4) diff --git a/tests/test_trainer/test_pipeline/test_p2p.py b/tests/test_trainer/test_pipeline/test_p2p.py deleted file mode 100644 index cb7a193d2bfa3f0ebcb1c8f58ff097161503993f..0000000000000000000000000000000000000000 --- a/tests/test_trainer/test_pipeline/test_p2p.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -import torch.distributed as dist - -from colossalai.communication import ( - recv_backward, - recv_forward, - recv_obj_meta, - send_backward, - send_backward_recv_forward, - send_forward, - send_forward_recv_backward, - send_obj_meta, -) -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import get_dist_logger -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device - -BATCH_SIZE = 4 -SEQ_LENGTH = 2 -HIDDEN_SIZE = 16 - -CONFIG = dict(parallel=dict(pipeline=dict(size=4), tensor=dict(size=1, mode=None)), seed=1024) - - -def check_equal(A, B): - return torch.allclose(A, B, rtol=1e-5, atol=1e-3) - - -def check_forward(output_tensor, rank, logger): - dist.barrier() - if gpc.is_first_rank(ParallelMode.PIPELINE): - tensor = output_tensor.clone() - else: - tensor = recv_forward(output_tensor.shape) - logger.info('Rank {} received forward. Correct tensor: {}'.format(rank, check_equal(tensor, output_tensor))) - if not gpc.is_last_rank(ParallelMode.PIPELINE): - send_forward(tensor) - logger.info('Rank {} sent forward.'.format(rank)) - - -def check_backward(output_grad, rank, logger): - dist.barrier() - if gpc.is_last_rank(ParallelMode.PIPELINE): - grad = output_grad.clone() - else: - grad = recv_backward(output_grad.shape) - logger.info('Rank {} received backward. Correct grad: {}'.format(rank, check_equal(grad, output_grad))) - if not gpc.is_first_rank(ParallelMode.PIPELINE): - send_backward(grad) - logger.info('Rank {} sent backward.'.format(rank)) - - -def check_forward_backward(output_tensor, output_grad, rank, logger): - dist.barrier() - if not gpc.is_first_rank(ParallelMode.PIPELINE): - tensor = send_backward_recv_forward(output_grad, output_tensor.shape) - logger.info('Rank {} sent backward received forward. Correct tensor: {}'.format( - rank, check_equal(tensor, output_tensor))) - if not gpc.is_last_rank(ParallelMode.PIPELINE): - grad = send_forward_recv_backward(output_tensor, output_grad.shape) - logger.info('Rank {} sent forward received backward. Correct grad: {}'.format( - rank, check_equal(grad, output_grad))) - - -def check_comm(size, rank, prev_rank, next_rank, logger): - dtype = torch.float32 - device = get_current_device() - tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - tensor = torch.randn(tensor_shape, dtype=dtype, device=device) - dist.all_reduce(tensor) - grad = torch.randn(grad_shape, dtype=dtype, device=device) - dist.all_reduce(grad) - check_forward(tensor, rank, logger) - check_backward(grad, rank, logger) - check_forward_backward(tensor, grad, rank, logger) - - -def run_check(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - logger = get_dist_logger() - rank = gpc.get_global_rank() - prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - logger.info('Rank {0}: prev rank {1}, next rank {2}'.format(rank, prev_rank, next_rank)) - logger.info('Distributed environment is initialzied.') - - check_comm(world_size, rank, prev_rank, next_rank, logger) - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_p2p(): - world_size = 4 - spawn(run_check, world_size) - - -if __name__ == '__main__': - test_p2p() diff --git a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py deleted file mode 100644 index 6d7bf6b3d89f54b6704096cb0394c56db51aee5f..0000000000000000000000000000000000000000 --- a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py +++ /dev/null @@ -1,87 +0,0 @@ -# referenced from Megatron and used to testify communication - -import os -from pathlib import Path - -import pytest -import torch -import torch.nn as nn -from torchvision import transforms -from torchvision.datasets import CIFAR10 -from torchvision.models import resnet18 - -import colossalai -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_dataloader, print_rank_0 - -BATCH_SIZE = 8 - -CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode=None))) - - -def run_schedule(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - # build model - model = resnet18(num_classes=10) - - if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2) - elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: - - class Flatten(nn.Module): - - def forward(self, x): - return torch.flatten(x, 1) - - model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) - - print_rank_0('model is created') - - train_dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), - ])) - - train_dataloader = get_dataloader( - dataset=train_dataset, - shuffle=True, - add_sampler=True, - batch_size=BATCH_SIZE, - pin_memory=True, - ) - - # build criterion - criterion = torch.nn.CrossEntropyLoss() - - # optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0) - - # initialize - engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader) - - # build pipeline schedule - schedule = engine.schedule - - # run schedule - data_iter = iter(train_dataloader) - schedule.forward_backward_step(engine, data_iter) - - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_pipeline_schedule(): - world_size = 2 - spawn(run_schedule, world_size) - - -if __name__ == '__main__': - test_pipeline_schedule() diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py deleted file mode 100644 index 753f82222f9d81679f4a720d258abc1659bc8fc3..0000000000000000000000000000000000000000 --- a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py +++ /dev/null @@ -1,59 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.amp.amp_type import AMP_TYPE -from colossalai.logging import get_dist_logger -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.trainer import Trainer -from colossalai.utils import MultiTimer -from tests.components_to_test.registry import non_distributed_component_funcs - -BATCH_SIZE = 4 -IMG_SIZE = 32 -NUM_EPOCHS = 200 - -CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH)) - - -@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'nested_model']) -def run_trainer(model_name): - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - model = model_builder() - optimizer = optimizer_class(model.parameters(), lr=1e-3) - engine, train_dataloader, *_ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - - logger = get_dist_logger() - logger.info("engine is built", ranks=[0]) - - timer = MultiTimer() - trainer = Trainer(engine=engine, logger=logger, timer=timer) - logger.info("trainer is built", ranks=[0]) - - logger.info("start training", ranks=[0]) - trainer.fit(train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - epochs=NUM_EPOCHS, - max_steps=3, - display_progress=True, - test_interval=5) - torch.cuda.empty_cache() - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_trainer_no_pipeline(): - world_size = 4 - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_trainer_no_pipeline() diff --git a/tests/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_trainer/test_trainer_with_pipe_schedule.py deleted file mode 100644 index bb63d51a0b656183c20526bf0389727219baf2b1..0000000000000000000000000000000000000000 --- a/tests/test_trainer/test_trainer_with_pipe_schedule.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -from pathlib import Path - -import pytest -import torch -import torch.nn as nn -from torch.optim import Adam -from torchvision import transforms -from torchvision.datasets import CIFAR10 -from torchvision.models import resnet18 - -import colossalai -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.trainer import Trainer -from colossalai.utils import MultiTimer, get_dataloader - -BATCH_SIZE = 4 -IMG_SIZE = 32 -NUM_EPOCHS = 200 - -CONFIG = dict( - NUM_MICRO_BATCHES=2, - parallel=dict(pipeline=2), -) - - -def run_trainer_with_pipeline(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - # build model - model = resnet18(num_classes=10) - - if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2) - elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: - - class Flatten(nn.Module): - - def forward(self, x): - return torch.flatten(x, 1) - - model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) - - # build dataloaders - train_dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.Resize(size=(IMG_SIZE, IMG_SIZE)), - transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ])) - - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=BATCH_SIZE, - pin_memory=True, - drop_last=True) - - # build optimizer - optimizer = Adam(model.parameters(), lr=0.001) - criterion = nn.CrossEntropyLoss() - - engine, train_dataloader, *args = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - - logger = get_dist_logger() - logger.info("engine is built", ranks=[0]) - timer = MultiTimer() - trainer = Trainer(engine=engine, logger=logger, timer=timer) - logger.info("trainer is built", ranks=[0]) - - logger.info("start training", ranks=[0]) - - trainer.fit(train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - max_steps=3, - display_progress=True, - test_interval=5) - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_trainer_with_pipeline(): - world_size = 4 - spawn(run_trainer_with_pipeline, world_size) - - -if __name__ == '__main__': - test_trainer_with_pipeline() diff --git a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py deleted file mode 100644 index 6d89fb90c574e9b06de571760b59441e75bf2b33..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -import torch.nn as nn -from colossalai.utils.checkpoint_io.meta import ParamDistMeta -from colossalai.utils.checkpoint_io.utils import build_checkpoints -from torch.optim import Adam - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def test_global_model(): - model = DummyModel() - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model) - assert len(model_checkpoints) == 1 - assert len(optimizer_checkpoints) == 0 - assert meta['dist_meta'] is None - orig_state_dict = model.state_dict() - global_state_dict = model_checkpoints[0] - assert set(orig_state_dict.keys()) == set(global_state_dict.keys()) - for k, v in orig_state_dict.items(): - assert torch.equal(v, global_state_dict[k]) - - -def test_global_model_shard(): - model = DummyModel() - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model) - assert len(model_checkpoints) == 2 - assert len(optimizer_checkpoints) == 0 - assert meta['dist_meta'] is None - orig_state_dict = model.state_dict() - assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys()) - assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0 - for k, v in orig_state_dict.items(): - for state_dict in model_checkpoints: - if k in state_dict: - assert torch.equal(v, state_dict[k]) - - -def test_global_optimizer(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer) - assert len(optimizer_checkpoints) == 1 - assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1} - for state in meta['paired_os'].values(): - for k, is_paired in state.items(): - if k == 'step': - assert not is_paired - else: - assert is_paired - orig_state_dict = optimizer.state_dict() - state_dict = optimizer_checkpoints[0] - for k, orig_state in orig_state_dict['state'].items(): - state = state_dict['state'][k] - for v1, v2 in zip(orig_state.values(), state.values()): - if isinstance(v2, torch.Tensor): - assert torch.equal(v1, v2) - else: - assert v2 == v2 - assert orig_state_dict['param_groups'] == state_dict['param_groups'] - - -def test_global_optimizer_shard(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer) - assert len(optimizer_checkpoints) == 2 - assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1] - orig_state_dict = optimizer.state_dict() - assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set( - optimizer_checkpoints[1]['state'].keys()) - assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0 - for k, orig_state in orig_state_dict['state'].items(): - state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][ - 'state'] else optimizer_checkpoints[1]['state'][k] - for v1, v2 in zip(orig_state.values(), state.values()): - if isinstance(v2, torch.Tensor): - assert torch.equal(v1, v2) - else: - assert v1 == v2 - - assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups'] - - -def test_dist_model_optimizer(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) - assert dist_meta == meta['dist_meta'] - assert len(model_checkpoints) == 1 - assert len(optimizer_checkpoints) == 1 - assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0] - assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state'] - dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) - assert dist_meta == meta['dist_meta'] - assert len(model_checkpoints) == 1 - assert len(optimizer_checkpoints) == 1 - - -if __name__ == '__main__': - test_global_model() - test_global_model_shard() - test_global_optimizer() - test_global_optimizer_shard() - test_dist_model_optimizer() diff --git a/tests/test_utils/test_checkpoint_io/test_load.py b/tests/test_utils/test_checkpoint_io/test_load.py deleted file mode 100644 index b1a741515728c2e6f697bdbd8f7e017924388b90..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_checkpoint_io/test_load.py +++ /dev/null @@ -1,186 +0,0 @@ -from copy import deepcopy -from functools import partial -from tempfile import TemporaryDirectory -from typing import Dict - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch import Tensor -from torch.nn import Module -from torch.optim import Adam, Optimizer - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.io import load, save -from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta - - -def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: - assert set(a.keys()) == set(b.keys()) - for k, v in a.items(): - assert torch.equal(v, b[k]) - - -def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None: - assert set(a['state'].keys()) == set(b['state'].keys()) - for k, state in a['state'].items(): - b_state = b['state'][k] - for v1, v2 in zip(state.values(), b_state.values()): - if isinstance(v1, Tensor): - assert torch.equal(v1, v2) - else: - assert v1 == v2 - if not ignore_param_gruops: - assert a['param_groups'] == b['param_groups'] - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(shard: bool = False, zero: bool = False): - model = DummyModel() - if shard: - model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] - if zero: - dp_rank = dist.get_rank() // 2 - model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] - if dp_rank != 0: - model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0): - with torch.no_grad(): - for p in model.parameters(): - p.fill_(scalar) - for state in optimizer.state.values(): - for v in state.values(): - if isinstance(v, Tensor): - v.fill_(scalar) - - -def get_dist_metas(nprocs: int, zero: bool = False): - dp_world_size = nprocs // 2 - dist_metas = [] - for rank in range(nprocs): - if zero: - dist_metas.append({ - 'fc.weight': - ParamDistMeta(rank // 2, - dp_world_size, - rank % 2, - 2, - tp_shard_dims=[1], - tp_num_parts=[2], - zero_numel=10, - zero_orig_shape=[1, 10]), - 'fc.bias': - ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) - }) - else: - dist_metas.append({ - 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) - }) - return dist_metas - - -def get_redist_meta(nprocs: int): - dp_world_size = nprocs // 2 - rank_meta = { - 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, - 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} - } - param_meta = { - 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamRedistMeta(dp_world_size, 1) - } - return RedistMeta(rank_meta, [], param_meta) - - -@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0]) -def test_save_global_load_global(max_shard_size_gb: float): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb) - new_model, new_optimizer = prepare_model_optim() - load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb) - check_model_state_dict(model.state_dict(), new_model.state_dict()) - check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict()) - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - test_fn() - - -def launch_dist(fn, world_size: int): - spawn(run_dist, world_size, test_fn=fn) - - -def save_dist(dir_name: str, zero: bool): - model, optmizer = prepare_model_optim(shard=True, zero=zero) - reset_model_optim(model, optmizer) - world_size = dist.get_world_size() - rank = dist.get_rank() - save(dir_name, model, optmizer, dist_meta=get_dist_metas(world_size, zero)[rank]) - - -def load_and_check_dist(dir_name: str): - world_size = dist.get_world_size() - model, optmizer = prepare_model_optim(shard=True) - reset_model_optim(model, optmizer) - model_state_dict = deepcopy(model.state_dict()) - optimizer_state_dict = deepcopy(optmizer.state_dict()) - reset_model_optim(model, optmizer, 1) - load(dir_name, model, optmizer, get_redist_meta(world_size), get_dist_metas(world_size)) - check_model_state_dict(model_state_dict, model.state_dict()) - check_optim_state_dict(optimizer_state_dict, optmizer.state_dict()) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_save_global_load_dist(): - model, optimizer = prepare_model_optim() - reset_model_optim(model, optimizer) - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - fn = partial(load_and_check_dist, dir_name) - launch_dist(fn, 4) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_save_dist_load_dist(): - with TemporaryDirectory() as dir_name: - # save tp + dp - fn = partial(save_dist, dir_name, False) - launch_dist(fn, 2) - # load tp + dp - fn = partial(load_and_check_dist, dir_name) - launch_dist(fn, 2) - with TemporaryDirectory() as dir_name: - # save tp + zero - fn = partial(save_dist, dir_name, True) - launch_dist(fn, 4) - # load tp + dp - fn = partial(load_and_check_dist, dir_name) - launch_dist(fn, 2) - launch_dist(fn, 4) - - -if __name__ == '__main__': - test_save_global_load_global(80 / 1024**3) - test_save_global_load_global(0) - test_save_global_load_dist() - test_save_dist_load_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_merge.py b/tests/test_utils/test_checkpoint_io/test_merge.py deleted file mode 100644 index 255c74adf0a2c9c1111a60c04930a100f187c227..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_checkpoint_io/test_merge.py +++ /dev/null @@ -1,126 +0,0 @@ -import os -from functools import partial -from tempfile import TemporaryDirectory - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.optim import Adam - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME -from colossalai.utils.checkpoint_io.io import merge, save -from colossalai.utils.checkpoint_io.meta import ParamDistMeta - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(shard: bool = False, zero: bool = False): - model = DummyModel() - if shard: - model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] - if zero: - dp_rank = dist.get_rank() // 2 - model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] - if dp_rank != 0: - model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) - for p in model.parameters(): - p.grad = torch.ones_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def test_merge_global(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - with TemporaryDirectory() as output_dir: - merge(dir_name, output_dir) - assert len(os.listdir(output_dir)) == 0 - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) - with TemporaryDirectory() as output_dir: - merge(dir_name, output_dir) - assert len(os.listdir(output_dir)) == 0 - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={'parallel': { - 'tensor': { - 'mode': '1d', - 'size': 2 - } - }}, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - test_fn() - - -def run_save_dist(dir_name: str, zero: bool): - model, optmizer = prepare_model_optim(shard=True, zero=zero) - rank = dist.get_rank() - dp_world_size = dist.get_world_size() // 2 - if not zero: - dist_metas = { - 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) - } - else: - dist_metas = { - 'fc.weight': - ParamDistMeta(rank // 2, - dp_world_size, - rank % 2, - 2, - tp_shard_dims=[1], - tp_num_parts=[2], - zero_numel=10, - zero_orig_shape=[1, 10]), - 'fc.bias': - ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) - } - save(dir_name, model, optmizer, dist_meta=dist_metas) - - -@pytest.mark.dist -@pytest.mark.parametrize("zero", [False, True]) -@rerun_if_address_is_in_use() -def test_merge_tp_dp(zero: bool): - with TemporaryDirectory() as dir_name: - fn = partial(run_save_dist, dir_name, zero) - world_size = 4 - spawn(run_dist, world_size, test_fn=fn) - with TemporaryDirectory() as output_dir: - merge(dir_name, output_dir) - assert len(os.listdir(output_dir)) == 5 - global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME)) - assert len(global_meta['meta']) == 1 - meta = torch.load(os.path.join(output_dir, global_meta['meta'][0])) - assert meta['dist_meta'] is None - assert len(meta['params']) == 2 - assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0])) - assert len(model_state_dict) == 2 - assert model_state_dict['fc.weight'].size(1) == 20 - optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0])) - assert len(optimizer_state_dict['state']) == 2 - assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict - assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20 - assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20 - - -if __name__ == '__main__': - test_merge_global() - test_merge_tp_dp(False) - test_merge_tp_dp(True) diff --git a/tests/test_utils/test_checkpoint_io/test_merge_param.py b/tests/test_utils/test_checkpoint_io/test_merge_param.py deleted file mode 100644 index 5da2ae4fe1f8f0e0797d2b733b4fe4ecf1778a67..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_checkpoint_io/test_merge_param.py +++ /dev/null @@ -1,101 +0,0 @@ -import torch -from colossalai.utils.checkpoint_io.meta import ParamDistMeta -from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param - - -def test_unflatten_zero_param_even() -> None: - dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)] - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).chunk(4)) - unflattened_tensor = unflatten_zero_param(tensors, dist_metas) - assert torch.equal(orig_tensor, unflattened_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_unflatten_zero_param_uneven() -> None: - dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)] - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).split([13, 3])) - unflattened_tensor = unflatten_zero_param(tensors, dist_metas) - assert torch.equal(orig_tensor, unflattened_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_1d_row() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)] - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_1d_col() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)] - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_2d() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)] - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_2d_reverse() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)] - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_merge_param_hybrid() -> None: - dist_metas = [ - ParamDistMeta(i % 2, - 2, - i // 2, - 6, - tp_shard_dims=[1, 0], - tp_num_parts=[3, 2], - zero_numel=4, - zero_orig_shape=[2, 2]) for i in range(12) - ] - orig_tensor = torch.rand(4, 6) - tensors = [ - chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) - for chunk in t.contiguous().reshape(-1).split([1, 3]) - ] - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_merge_param_dummy() -> None: - dist_metas = [ParamDistMeta(0, 1, 0, 1)] - orig_tensor = torch.rand(4, 6) - merged_tensor = merge_param([orig_tensor], dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -if __name__ == '__main__': - test_unflatten_zero_param_even() - test_unflatten_zero_param_uneven() - test_gather_tp_param_1d_row() - test_gather_tp_param_1d_col() - test_gather_tp_param_2d() - test_gather_tp_param_2d_reverse() - test_merge_param_hybrid() - test_merge_param_dummy() diff --git a/tests/test_utils/test_checkpoint_io/test_redist.py b/tests/test_utils/test_checkpoint_io/test_redist.py deleted file mode 100644 index 144715bdfcca3db119eb96deb16091b10c537598..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_checkpoint_io/test_redist.py +++ /dev/null @@ -1,152 +0,0 @@ -import os -from functools import partial -from tempfile import TemporaryDirectory - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.optim import Adam - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME -from colossalai.utils.checkpoint_io.io import redist, save -from colossalai.utils.checkpoint_io.meta import ( - ParamDistMeta, - ParamRedistMeta, - PipelineRedistMeta, - RankRedistMeta, - RedistMeta, -) - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(shard: bool = False, zero: bool = False): - model = DummyModel() - if shard: - model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] - if zero: - dp_rank = dist.get_rank() // 2 - model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] - if dp_rank != 0: - model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) - for p in model.parameters(): - p.grad = torch.ones_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def get_dist_metas(nprocs: int, zero: bool = False): - dp_world_size = nprocs // 2 - dist_metas = [] - for rank in range(nprocs): - if zero: - dist_metas.append({ - 'fc.weight': - ParamDistMeta(rank // 2, - dp_world_size, - rank % 2, - 2, - tp_shard_dims=[1], - tp_num_parts=[2], - zero_numel=10, - zero_orig_shape=[1, 10]), - 'fc.bias': - ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) - }) - else: - dist_metas.append({ - 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) - }) - return dist_metas - - -def get_redist_meta(nprocs: int): - dp_world_size = nprocs // 2 - rank_meta = { - 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, - 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} - } - param_meta = { - 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamRedistMeta(dp_world_size, 1) - } - return RedistMeta(rank_meta, [], param_meta) - - -def check_checkpoint_shape(dir_name: str): - global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) - for meta_name in global_meta['meta']: - meta = torch.load(os.path.join(dir_name, meta_name)) - assert meta['dist_meta'] is not None - assert len(meta['params']) == 2 - assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) - assert len(model_state_dict) == 2 - assert model_state_dict['fc.weight'].size(1) == 10 - optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) - assert len(optimizer_state_dict['state']) == 2 - assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict - assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10 - assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10 - - -def test_global_to_dist(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - with TemporaryDirectory() as output_dir: - redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) - check_checkpoint_shape(output_dir) - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={'parallel': { - 'tensor': { - 'mode': '1d', - 'size': 2 - } - }}, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - test_fn() - - -def run_save_dist(dir_name: str, zero: bool): - model, optmizer = prepare_model_optim(shard=True, zero=zero) - rank = dist.get_rank() - save(dir_name, model, optmizer, dist_meta=get_dist_metas(4, zero)[rank]) - - -@pytest.mark.dist -@pytest.mark.parametrize("zero", [False, True]) -@rerun_if_address_is_in_use() -def test_dist_to_dist(zero: bool): - with TemporaryDirectory() as dir_name: - fn = partial(run_save_dist, dir_name, zero) - world_size = 4 - spawn(run_dist, world_size, test_fn=fn) - with TemporaryDirectory() as output_dir: - redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) - if not zero: - assert len(os.listdir(output_dir)) == 0 - else: - check_checkpoint_shape(output_dir) - - -if __name__ == '__main__': - test_global_to_dist() - test_dist_to_dist(False) - test_dist_to_dist(True) diff --git a/tests/test_utils/test_checkpoint_io/test_save.py b/tests/test_utils/test_checkpoint_io/test_save.py deleted file mode 100644 index e35e566f6ff83c91ce5f8a367e478e3f7eecc579..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_checkpoint_io/test_save.py +++ /dev/null @@ -1,149 +0,0 @@ -import os -from functools import partial -from tempfile import TemporaryDirectory -from typing import Dict - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch import Tensor -from torch.optim import Adam - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.constant import ( - GLOBAL_META_FILE_NAME, - META_CKPT_FILE_NAME, - MODEL_CKPT_FILE_NAME, - OTHER_CKPT_FILE_NAME, -) -from colossalai.utils.checkpoint_io.io import save -from colossalai.utils.checkpoint_io.meta import ParamDistMeta - - -def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: - assert set(a.keys()) == set(b.keys()) - for k, v in a.items(): - assert torch.equal(v, b[k]) - - -def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None: - assert set(a['state'].keys()) == set(b['state'].keys()) - for k, state in a['state'].items(): - b_state = b['state'][k] - for v1, v2 in zip(state.values(), b_state.values()): - if isinstance(v1, Tensor): - assert torch.equal(v1, v2) - else: - assert v1 == v2 - if not ignore_param_gruops: - assert a['param_groups'] == b['param_groups'] - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.ones_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def test_overwrite(): - model = DummyModel() - with TemporaryDirectory() as dir_name: - with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f: - pass - with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'): - save(dir_name, model) - - -def test_save_global(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - assert len(os.listdir(dir_name)) == 5 - global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) - assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME - meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) - assert len(meta['model']) == 1 - assert len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) - check_model_state_dict(model.state_dict(), model_state_dict) - optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) - check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict) - other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME)) - assert len(other_state_dict) == 0 - - -def test_save_global_shard(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) - assert len(os.listdir(dir_name)) == 7 - meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) - assert len(meta['model']) == 2 and len(meta['optimizer']) == 2 - model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']] - assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0 - check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]}) - optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']] - assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0 - assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1] - check_optim_state_dict( - optimizer.state_dict(), { - 'state': { - **optimizer_state_dicts[0]['state'], - **optimizer_state_dicts[1]['state'] - }, - 'param_groups': optimizer_state_dicts[0]['param_groups'] - }) - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - test_fn() - - -def run_save_dist(dir_name): - model, optmizer = prepare_model_optim() - dist_metas = { - 'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1), - 'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1) - } - save(dir_name, model, optmizer, dist_meta=dist_metas) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_save_dist(): - with TemporaryDirectory() as dir_name: - fn = partial(run_save_dist, dir_name) - world_size = 2 - spawn(run_dist, world_size, test_fn=fn) - assert len(os.listdir(dir_name)) == 8 - global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) - assert len(global_meta['meta']) == 2 - for rank, meta_name in enumerate(global_meta['meta']): - meta = torch.load(os.path.join(dir_name, meta_name)) - assert meta.get('dist_meta', None) is not None - assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) - assert len(model_state_dict) == 2 - optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) - assert len(optimizer_state_dict['state']) == 2 - assert 'param_groups' in optimizer_state_dict - - -if __name__ == '__main__': - test_overwrite() - test_save_global() - test_save_global_shard() - test_save_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py deleted file mode 100644 index 8b83caa12359ff8c404e65f96701a6c390f9792d..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch -from colossalai.utils.checkpoint_io.meta import ParamRedistMeta -from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param - - -def test_flatten_zero_param_even() -> None: - redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12]) - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).chunk(4)) - flat_tensors = flatten_zero_param(orig_tensor, redist_meta) - assert len(tensors) == len(flat_tensors) - for t, st in zip(tensors, flat_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 1 - unmerged_tensors = unmerged_tensors[0] - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert torch.equal(t, tl) - - -def test_flatten_zero_param_uneven() -> None: - redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13]) - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).split([13, 3])) - flat_tensors = flatten_zero_param(orig_tensor, redist_meta) - assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0 - flat_tensors = flat_tensors[1:-1] - assert len(tensors) == len(flat_tensors) - for t, st in zip(tensors, flat_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 1 - unmerged_tensors = unmerged_tensors[0] - assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0 - unmerged_tensors = unmerged_tensors[1:-1] - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert torch.equal(t, tl) - - -def test_split_tp_param_1d_row() -> None: - redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4]) - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] - split_tensors = split_tp_param(orig_tensor, redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_split_tp_param_1d_col() -> None: - redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4]) - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] - split_tensors = split_tp_param(orig_tensor, redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_split_tp_param_2d() -> None: - redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - split_tensors = split_tp_param(orig_tensor, redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_split_tp_param_2d_reverse() -> None: - redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - split_tensors = split_tp_param(orig_tensor, redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_unmerge_param_hybrid() -> None: - redist_meta = ParamRedistMeta(2, - 6, - tp_shard_dims=[1, 0], - tp_num_parts=[3, 2], - zero_start_dp_rank=0, - zero_offsets=[0, 1]) - orig_tensor = torch.rand(4, 6) - tensors = [ - chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) - for chunk in t.contiguous().reshape(-1).split([1, 3]) - ] - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2 - for tp_rank in range(6): - for dp_rank in range(2): - assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank]) - - -def test_unmerge_param_dummy() -> None: - redist_meta = ParamRedistMeta(1, 1) - orig_tensor = torch.rand(4, 6) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1 - assert torch.equal(orig_tensor, unmerged_tensors[0][0]) - - -if __name__ == '__main__': - test_flatten_zero_param_even() - test_flatten_zero_param_uneven() - test_split_tp_param_1d_row() - test_split_tp_param_1d_col() - test_split_tp_param_2d() - test_split_tp_param_2d_reverse() - test_unmerge_param_hybrid() - test_unmerge_param_dummy() diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py deleted file mode 100644 index 89760a5456e774362204d09d049bcdd603beaae5..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_colo_checkpoint.py +++ /dev/null @@ -1,206 +0,0 @@ -import os -import shutil -from copy import deepcopy - -import pytest -import torch -import torch.distributed as dist -from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR - -import colossalai -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext -from tests.components_to_test.registry import non_distributed_component_funcs - - -def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup): - spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - weight.set_process_group(pg) - weight.set_tensor_spec(*spec) - - -def init_1d_col_linear(weight, pg): - spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - weight.set_process_group(pg) - weight.set_tensor_spec(*spec) - - -def init_1d_row_embedding(weight, pg): - spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - weight.set_process_group(pg) - weight.set_tensor_spec(*spec) - - -def init_1d_col_embedding(weight, pg): - spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - weight.set_process_group(pg) - weight.set_tensor_spec(*spec) - - -def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup): - spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - for name, p in model.named_parameters(): - if not isinstance(p, ColoTensor): - continue - if 'embed' in name and 'weight' in name: - init_1d_col_embedding(p, pg) - if 'proj1' in name and ('weight' in name or 'bias' in name): - init_1d_col_linear(p, pg) - if 'proj2' in name and 'weight' in name: - init_1d_row_linear(p, pg) - if 'classifier' in name and ('weight' in name or 'bias' in name): - init_1d_col_linear(p, pg) - - -def check_param_equal(model, torch_model): - for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()): - assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape) - - -def remove(path): - """ param could either be relative or absolute. """ - if os.path.isfile(path) or os.path.islink(path): - os.remove(path) - elif os.path.isdir(path): - shutil.rmtree(path) - else: - raise ValueError("file {} is not a file or dir.".format(path)) - - -def compare_optims(optim1, optim2): - state1 = optim1.state_dict()['state'] - state2 = optim2.state_dict()['state'] - for k, p1 in state1.items(): - if k not in state2: - continue - p2 = state2[k] - for n, t1 in p1.items(): - if n not in p2: - continue - t2 = p2[n] - if isinstance(t1, ColoTensor): - assert isinstance(t2, ColoTensor) - assert torch.allclose(t1, t2, rtol=0, atol=0) - - -def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg): - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - - # set_seed(1) - with ColoInitContext(device=get_current_device()): - model = model_builder(checkpoint=True) - - if use_mp_reload: - if 'bert' == model_name: - for name, p in model.named_parameters(): - if not isinstance(p, ColoTensor): - continue - # num_class = type_vocab_size = 2 | (8, 2) - if 'classifier' in name and 'weight' in name: - init_1d_row_linear(p, pg) - # num_class = vocab_size = 30524 | (30524, 8) - elif 'word_embeddings' in name and 'weight' in name: - init_1d_row_embedding(p, pg) - # num_class = seq_len = 512 | (512, 8) - elif 'position_embeddings' in name and 'weight' in name: - init_1d_row_embedding(p, pg) - # num_class = type_vocab_size = 2 | (2, 8) - elif 'token_type_embeddings' in name and 'weight' in name: - init_1d_col_embedding(p, pg) - elif p.process_group.tp_world_size() == 1: - p.set_process_group(pg) - elif "simple_net" == model_name: - init_spec_func(model, pg) - - model_reload = deepcopy(model) - model = model.cuda() - model.eval() - - model_reload = model_reload.cuda() - model_reload.eval() - - opt_class = torch.optim.Adam - colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1)) - colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1)) - - for i, (data, label) in enumerate(train_dataloader): - - # Zero grad - colo_optimizer.zero_grad() - colo_optimizer_reload.zero_grad() - - data = data.to(get_current_device()) - label = label.to(get_current_device()) - - dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group()) - dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group()) - - # Bcast rank0 data to all processes - if criterion: - output = model(data) - output_reload = model_reload(data) - loss = criterion(output, label) - loss_reload = criterion(output_reload, label) - else: - loss = model(data, label) - loss_reload = model_reload(data, label) - - loss.backward() - loss_reload.backward() - - colo_optimizer.step() - colo_optimizer_reload.step() - - if i > 2: - break - - if not os.path.isdir('./checkpoint') and rank == 0: - os.mkdir('./checkpoint') - dist.barrier() - - save_checkpoint('./checkpoint', 0, model, colo_optimizer, None) - load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None) - - check_param_equal(model, model_reload) - compare_optims(colo_optimizer, colo_optimizer_reload) - - if rank == 0: - remove('./checkpoint') - dist.barrier() - - -def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(tp_degree=world_size) - - # the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context - for model_name in ['bert']: - _run_checkpoint(model_name, - init_1d_row_for_linear_weight_spec, - use_ddp, - use_mp_reload, - test_scheduler=test_scheduler, - pg=pg) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@pytest.mark.parametrize('use_ddp', [False]) -@pytest.mark.parametrize('use_mp_reload', [True, False]) -# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda']) -@rerun_if_address_is_in_use() -def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None): - spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler) - - -if __name__ == '__main__': - test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine") diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py deleted file mode 100644 index 2633d7da21aa3e71c3764b83515901eaff41c24a..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_commons.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline -from colossalai.zero.legacy.sharded_param import ShardedTensor - - -def run_tensor_move(rank, world_size, port): - colossalai.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl') - - src_t = torch.ones(2, 3).cuda() - tgt_t = torch.zeros(2, 3) - - colo_model_data_tensor_move(src_t, tgt_t) - assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" - - src_t = torch.ones(2, 3) - tgt_t = torch.zeros(2, 3).cuda().half() - colo_model_data_tensor_move(src_t, tgt_t) - # the src_t has been removed - assert (src_t.numel() == 0) - assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" - - src_t = ShardedTensor(torch.ones(2, 3)) - tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half()) - colo_model_data_tensor_move(src_t, tgt_t) - assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" - - assert (tgt_t.device.type == 'cuda') - colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu')) - assert (tgt_t.device.type == 'cpu') - - -@rerun_if_address_is_in_use() -def test_tensor_move(): - spawn(run_tensor_move, 1) - - -if __name__ == '__main__': - test_tensor_move() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index 7a28b0157384f2a664fe777d4743dd41462a3a16..a5c465ba0b07d1aa2a53292f2caad6f3f41fc2f2 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -1,113 +1,169 @@ -import random +import math import pytest import torch from einops import rearrange -from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN +from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN +from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN from colossalai.testing import clear_cache_before_run, parameterize -if HAS_MEM_EFF_ATTN: - from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention +if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: + from colossalai.kernel.cuda_native import ColoAttention + from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType +DTYPE = [torch.float16, torch.bfloat16, torch.float32] -def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): - M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - for z in range(Z): - for h in range(H): - p[:, :, M == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1).half() - ref_out = torch.matmul(p, v) - return ref_out +def attention_ref(q, k, v, attn_mask=None, causal=False): + """ + attention output of the control group + """ + dtype_og = q.dtype + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + d = q.shape[-1] + scale = 1.0 / math.sqrt(d) + scores = torch.einsum("bthd,bshd->bhts", q * scale, k) -@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") + if attn_mask is not None: + scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf")) + if causal: + causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) + scores.masked_fill_(causal_mask, float("-inf")) + attention = torch.softmax(scores, dim=-1) + + output = torch.einsum("bhts,bshd->bthd", attention, v) + output = rearrange(output, "b s h d -> b s (h d)") + + # Modify the data at the positions of the mask to 0 + if attn_mask is not None: + output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0) + + return output.to(dtype=dtype_og) + + +@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) -def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): +@parameterize("proj_shape", [(6, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) +def test_attention_gpt(proj_shape, dtype, dropout): + (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD - c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") - attn = ColoAttention(D, H, dropout=0.1) + q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - x = torch.randn((B, S, D), dtype=dtype, device="cuda") + mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)] + mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True) - qkv = c_attn(x) - q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H) - y = attn(q, k, v, attn_mask_type=AttnMaskType.causal) + attn = ColoAttention(D, H, dropout=dropout) + y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal) assert list(y.shape) == [B, S, D] + out_ref = attention_ref(q, k, v, mask, causal=True) + + # check gradients dy = torch.rand_like(y) - y.backward(dy) + grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) + grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) + + torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" + torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" + torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" + torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" -@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) -def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): +@parameterize("proj_shape", [(6, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) +def test_attention_bert(proj_shape, dtype, dropout): + (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD - c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") - attn = ColoAttention(D, H, dropout=0.1) + q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - x = torch.randn((B, S, D), dtype=dtype, device="cuda") # attention mask of shape [B, S] with zero padding to max length S - mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)] - mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True) + mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda") - qkv = c_attn(x) - q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2) + attn = ColoAttention(D, H, dropout=dropout) y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding) assert list(y.shape) == [B, S, D] + out_ref = attention_ref(q, k, v, mask, causal=False) + dy = torch.rand_like(y) - y.backward(dy) + grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) + grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) + + torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" + torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" + torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" + torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" -@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) -def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): +@parameterize("proj_shape", [(6, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) +def test_attention_no_mask(proj_shape, dtype, dropout): + (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD - c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") - attn = ColoAttention(D, H, dropout=0.1) + q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - x = torch.randn((B, S, D), dtype=dtype, device="cuda") - qkv = c_attn(x) - q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2) + attn = ColoAttention(D, H, dropout=dropout) y = attn(q, k, v) assert list(y.shape) == [B, S, D] + out_ref = attention_ref(q, k, v, None, causal=False) + dy = torch.rand_like(y) - y.backward(dy) + grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) + grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) + + torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" + torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" + torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" + torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" -@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)]) -def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16): +@parameterize("proj_shape", [(6, 24, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) +def test_cross_attention(proj_shape, dtype, dropout): + (B, S, T, H, D_HEAD) = proj_shape D = H * D_HEAD - q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda") - kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda") + q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - attn = ColoAttention(D, H, dropout=0.1) - - src = torch.randn((B, S, D), dtype=dtype, device="cuda") - tgt = torch.randn((B, T, D), dtype=dtype, device="cuda") - - q = q_attn(tgt) - kv = kv_attn(src) - q = rearrange(q, 'b s (h d) -> b s h d', h=H) - k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2) + attn = ColoAttention(D, H, dropout=dropout) y = attn(q, k, v, attn_mask_type=AttnMaskType.causal) assert list(y.shape) == [B, T, D] + out_ref = attention_ref(q, k, v, None, causal=True) + dy = torch.rand_like(y) - y.backward(dy) + grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) + grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) + + torch.allclose(y, out_ref, atol=1e-18), f"{(y - out_ref).abs().max()}" + torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" + torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" + torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" diff --git a/tests/test_utils/test_lazy_init/test_distribute.py b/tests/test_utils/test_lazy_init/test_distribute.py deleted file mode 100644 index 2c15ca84efaad42a112207e0e5127c5c513e6985..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_lazy_init/test_distribute.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import Optional - -import pytest -import torch -import torch.nn as nn - -import colossalai -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.d_tensor.layout import Layout -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.common import print_rank_0 - -try: - from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor -except: - pass -from tests.kit.model_zoo import model_zoo - -# from utils import assert_dist_model_equal, set_seed - - -def find_shard_dim(shape: torch.Size) -> Optional[int]: - for dim, size in enumerate(shape): - if size % 2 == 0: - return dim - - -def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout: - shard_dim = find_shard_dim(original_tensor.shape) - dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {} - target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=target_sharding_spec, - entire_shape=original_tensor.shape) - return layout - - -def _get_current_name(prefix: str, name: str) -> str: - return f'{prefix}.{name}'.lstrip('.') - - -def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict: - layout_dict = {} - - @torch.no_grad() - def generate_recursively(module: nn.Module, prefix: str = ''): - # recursively initialize the module - for name, mod in module.named_children(): - generate_recursively(mod, prefix=_get_current_name(prefix, name)) - - # initialize tensors directly attached to the current module - for name, param in module.named_parameters(recurse=False): - if isinstance(param, LazyTensor): - layout = make_layout(device_mesh, param) - layout_dict[_get_current_name(prefix, name)] = layout - - for name, buf in module.named_buffers(recurse=False): - if isinstance(buf, LazyTensor): - layout = make_layout(device_mesh, buf) - layout_dict[_get_current_name(prefix, name)] = layout - - generate_recursively(model) - - return layout_dict - - -@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) -def run_dist_lazy_init(subset, seed: int = 42): - sub_model_zoo = model_zoo.get_sub_registry(subset) - device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) - # FIXME(ver217): uncomment this line - # _MyTensor._pre_op_fn = lambda *args: set_seed(seed) - # LazyTensor._pre_op_fn = lambda *args: set_seed(seed) - - for name, entry in sub_model_zoo.items(): - # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): - continue - print_rank_0(name) - model_fn, data_gen_fn, output_transform_fn, model_attr = entry - ctx = LazyInitContext(tensor_cls=_MyTensor) - with ctx: - model = model_fn() - ctx = LazyInitContext() - with ctx: - deferred_model = model_fn() - layout_dict = generate_layout_dict(deferred_model, device_mesh) - ctx.distribute(deferred_model, layout_dict, verbose=True) - # FIXME(ver217): uncomment this line - # assert_dist_model_equal(model, deferred_model, layout_dict) - - -def run_dist(rank, world_size, port) -> None: - colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port) - run_dist_lazy_init() - - -# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor -@pytest.mark.skip -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_dist_lazy_init(): - spawn(run_dist, 4) - - -if __name__ == '__main__': - test_dist_lazy_init() diff --git a/tests/test_utils/test_lazy_init/test_models.py b/tests/test_utils/test_lazy_init/test_models.py deleted file mode 100644 index 9faddecbaca4dfa57b191b3ed7e351a847f235f3..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_lazy_init/test_models.py +++ /dev/null @@ -1,23 +0,0 @@ -import pytest - -from tests.kit.model_zoo import model_zoo - -# FIXME(ver217): uncomment this line -# from utils import check_lazy_init - - -# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor -@pytest.mark.skip -@pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) -def test_torchvision_models_lazy_init(subset): - sub_model_zoo = model_zoo.get_sub_registry(subset) - for name, entry in sub_model_zoo.items(): - # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): - continue - # FIXME(ver217): uncomment this line - # check_lazy_init(entry, verbose=True) - - -if __name__ == '__main__': - test_torchvision_models_lazy_init('torchvision') diff --git a/tests/test_utils/test_lazy_init/utils.py b/tests/test_utils/test_lazy_init/utils.py deleted file mode 100644 index a8aeb4c8930c396bd2e3d9c72497f9d4e426a464..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_lazy_init/utils.py +++ /dev/null @@ -1,85 +0,0 @@ -import random -from typing import Any, Callable, Optional, Tuple - -import numpy as np -import torch - -from colossalai.tensor.d_tensor.layout_converter import to_global -from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor -from tests.kit.model_zoo.registry import ModelAttribute - -# model_fn, data_gen_fn, output_transform_fn, model_attr -TestingEntry = Tuple[Callable[[], torch.nn.Module], Callable[[], dict], Callable[[], dict], Optional[ModelAttribute]] - - -def set_seed(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - - -def assert_model_eqaual(m1: torch.nn.Module, m2: torch.nn.Module) -> None: - s1 = m1.state_dict() - s2 = m2.state_dict() - - assert len(s1) == len(s2), f'len {len(s1)} vs {len(s2)}' - - for (n1, t1), (n2, t2) in zip(s1.items(), s2.items()): - assert n1 == n2 - assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' - - -def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict], - output_transform_fn: Callable[[Any], dict]) -> None: - data = data_gen_fn() - - m1.eval() - m2.eval() - # run forward - with torch.no_grad(): - outputs1 = m1(**data) - outputs2 = m2(**data) - - # compare output - transformed_out1 = output_transform_fn(outputs1) - transformed_out2 = output_transform_fn(outputs2) - - assert len(transformed_out1) == len(transformed_out2) - - for key, out1 in transformed_out1.items(): - out2 = transformed_out2[key] - assert torch.allclose(out1, out2, atol=1e-5), \ - f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}' - - -def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None: - model_fn, data_gen_fn, output_transform_fn, model_attr = entry - _MyTensor._pre_op_fn = lambda *args: set_seed(seed) - LazyTensor._pre_op_fn = lambda *args: set_seed(seed) - ctx = LazyInitContext(tensor_cls=_MyTensor) - with ctx: - model = model_fn() - ctx = LazyInitContext() - with ctx: - deferred_model = model_fn() - deferred_model = ctx.materialize(deferred_model, verbose=verbose) - assert_model_eqaual(model, deferred_model) - if check_forward: - assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn) - if verbose: - print(f'{model.__class__.__name__} pass') - - -def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None: - state = model.state_dict() - distributed_state = distributed_model.state_dict() - - assert len(state) == len(distributed_state), f'len {len(state)} vs {len(distributed_state)}' - - for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()): - assert n1 == n2 - t1 = t1.cuda() - t2 = t2.cuda() - if n2 in layout_dict: - t2 = to_global(t2, layout_dict[n2]) - assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' diff --git a/tests/test_utils/test_lazy_init_ctx.py b/tests/test_utils/test_lazy_init_ctx.py deleted file mode 100644 index 97efb3367490e772f351939bec7949fd86ad4da3..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_lazy_init_ctx.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -from colossalai.utils.model.lazy_init_context import LazyInitContext -from torchvision.models import resnet34 -import random -import numpy as np - -MANUAL_SEED = 0 -random.seed(MANUAL_SEED) -np.random.seed(MANUAL_SEED) -torch.manual_seed(MANUAL_SEED) - - -def test_lazy_init_with_meta(): - ctx = LazyInitContext(to_meta=True) - with ctx: - model = resnet34(num_classes=10) - - for param in model.parameters(): - assert param.is_meta - for buffer in model.buffers(): - assert buffer.is_meta - - ctx.lazy_init_parameters(model) - - for name, param in model.named_parameters(): - assert not param.is_meta, name - - for buffer in model.buffers(): - assert not buffer.is_meta - - -def test_lazy_init_without_meta(): - ctx = LazyInitContext(to_meta=False) - with ctx: - model = resnet34(num_classes=10) - - for param in model.parameters(): - assert not param.is_meta - for buffer in model.buffers(): - assert not buffer.is_meta - - conv1_weight_before_init = model.conv1.weight.clone() - ctx.lazy_init_parameters(model) - conv1_weight_after_init = model.conv1.weight.clone() - - assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init) - - -if __name__ == '__main__': - test_lazy_init_with_meta() - test_lazy_init_without_meta() diff --git a/tests/test_utils/test_memory.py b/tests/test_utils/test_memory.py deleted file mode 100644 index c88c2f8ec3c5abbba033abb682fd788435212e72..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_memory.py +++ /dev/null @@ -1,28 +0,0 @@ -import pytest - -import colossalai -from colossalai.testing import spawn -from colossalai.utils.cuda import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction - - -def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): - frac1 = colo_device_memory_capacity(get_current_device()) - colo_set_process_memory_fraction(0.5) - frac2 = colo_device_memory_capacity(get_current_device()) - assert frac2 * 2 == frac1 - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [3, 4]) -def test_memory_utils(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_memory_utils(world_size=2) diff --git a/tests/test_utils/test_norm_gradient_clipping.py b/tests/test_utils/test_norm_gradient_clipping.py deleted file mode 100644 index c0d678026c5fe184264ba03b9a7cdb0dbc5b1a06..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_norm_gradient_clipping.py +++ /dev/null @@ -1,77 +0,0 @@ -import pytest -import torch -from torch.nn.parameter import Parameter -from torch.nn.utils import clip_grad_norm_ - -import colossalai -from colossalai.logging import disable_existing_loggers -from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec -from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from colossalai.utils.common import clip_grad_norm - - -def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): - return abs(num - other) <= atol + rtol * other - - -def shard_param(p: ColoParameter) -> None: - pg = p.get_process_group() - p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()])) - p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach() - - -def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None: - pg = colo_p.get_process_group() - if p.shape != colo_p.shape: - grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()] - else: - grad = p.grad - assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}' - - -@parameterize('dtype', [torch.float]) -@parameterize('device', ['mixed', 'cuda', 'cpu']) -@parameterize('norm_type', [2.0, 3.0, float('inf')]) -def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float): - print(f'{world_size}, {dtype}, {device}, {norm_type}') - cuda_device = get_current_device() - devices = [cuda_device] * 4 - if device == 'cpu': - devices = [torch.device('cpu')] * 4 - elif device == 'mixed': - devices = [cuda_device] * 2 + [torch.device('cpu')] * 2 - pg = ProcessGroup(tp_degree=world_size) - params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)] - colo_params = [ - ColoParameter(torch.empty(4, 4, dtype=dtype, device=devices[i]), spec=ColoTensorSpec(pg)) for i in range(4) - ] - for p, colo_p in zip(params, colo_params): - grad = torch.rand_like(p) - p.grad = grad - colo_p.grad = grad.clone().detach() - shard_param(colo_params[0]) - shard_param(colo_params[2]) - torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type) - colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type) - assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm-colo_norm)}' - for p, colo_p in zip(params, colo_params): - check_grad_equal(p, colo_p) - - -def run_dist(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_grad_clip_norm(world_size=world_size) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@rerun_if_address_is_in_use() -def test_zero_clip_grad(world_size: int): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_clip_grad(2) diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py deleted file mode 100644 index e99cf388e929df3f34616ced50c4da851205c369..0000000000000000000000000000000000000000 --- a/tests/test_utils/test_zero_gradient_clippling.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from functools import partial - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ - -import colossalai -from colossalai.logging import disable_existing_loggers -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import checkpoint, clip_grad_norm_fp32 -from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy -from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 - - -def checkpoint_wrapper(module, enable=True): - if enable: - module.forward = partial(checkpoint, module.forward, False) - return module - - -class Net(nn.Module): - - def __init__(self, checkpoint=False) -> None: - super().__init__() - self.fc1 = nn.Linear(5, 5) - self.fc2 = nn.Linear(5, 5) - self.fc3 = nn.Linear(5, 1) - if checkpoint: - self.fc1 = checkpoint_wrapper(self.fc1) - self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3] - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x - - -def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0): - model.train() - optimizer.zero_grad() - with torch.cuda.amp.autocast(enabled=enable_autocast): - y = model(x) - loss = y.sum() - loss = loss.float() - loss.backward() - clip_grad(model, norm_type) - optimizer.step() - - -def clip_grad(model, norm_type): - if isinstance(model, DDP): - clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type) - else: - clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type) - - -def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: - if loose: - return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3) - return torch.allclose(tensor_a, tensor_b) - - -def check_grads(model, zero_model, loose=False): - rank = dist.get_rank() - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_grad = zero_p.grad.clone().to(p.device) - chunks = torch.flatten(p.grad).chunk(4) - if rank >= len(chunks): - continue - grad = chunks[rank] - if zero_p.zero_shard_padding > 0: - zero_grad = zero_grad[:-zero_p.zero_shard_padding] - assert grad.dtype == zero_grad.dtype - assert allclose(grad, zero_grad, loose=loose) - - -def check_params(model, zero_model, loose=False): - rank = dist.get_rank() - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_shard_padding = zero_p.zero_shard_padding - zero_p = zero_p.clone().to(p.device) - chunks = torch.flatten(p).chunk(4) - if rank >= len(chunks): - continue - p = chunks[rank] - if zero_shard_padding > 0: - zero_p = zero_p[:-zero_shard_padding] - assert p.dtype == zero_p.dtype - assert allclose(p, zero_p, loose=loose) - - -def run_dist(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_zero_clip_grad(): - world_size = 4 - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_clip_grad() diff --git a/tests/test_zero/test_gemini/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py index 7ea063877b5c483017c66c52ab5cafceff56d0fd..879eeccde3b41622c1a83074f2946b4772c350c3 100644 --- a/tests/test_zero/test_gemini/test_chunk_mgrv2.py +++ b/tests/test_zero/test_gemini/test_chunk_mgrv2.py @@ -1,67 +1,64 @@ import pytest import torch +from torch.distributed.distributed_c10d import _get_default_group import colossalai -from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.tensor import ColoTensor from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.gemini.chunk import ChunkManager -from tests.test_tensor.common_utils import debug_print CUDA_MEM_0 = {False: 512, True: 1024} CUDA_MEM_1 = {False: 0, True: 1024} CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}} -@parameterize('keep_gathered', [True, False]) -@parameterize('pin_memory', [True, False]) +@parameterize("keep_gathered", [True, False]) +@parameterize("pin_memory", [True, False]) def exam_chunk_memory(keep_gathered, pin_memory): - pg = ProcessGroup() - - debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory)) - - params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)] + params = [ColoTensor(torch.rand(8, 8)) for _ in range(3)] config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)} chunk_manager = ChunkManager(config) - assert chunk_manager.total_mem['cpu'] == 0 - assert chunk_manager.total_mem['cuda'] == 0 + assert chunk_manager.total_mem["cpu"] == 0 + assert chunk_manager.total_mem["cuda"] == 0 + process_group = _get_default_group() for p in params: - chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory) + chunk_manager.register_tensor(p, "param", 2, process_group, pin_memory=pin_memory) chunk_manager.close_all_groups() - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + assert chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_0[keep_gathered] chunks = chunk_manager.get_chunks(params) for chunk in chunks: chunk_manager.access_chunk(chunk) - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True] + assert chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_0[True] for chunk in chunks: chunk_manager.release_chunk(chunk) - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + assert chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_0[keep_gathered] for chunk in chunks: - chunk_manager.move_chunk(chunk, torch.device('cpu')) - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered] + chunk_manager.move_chunk(chunk, torch.device("cpu")) + assert chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][True] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_1[keep_gathered] def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_chunk_memory() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_chunk_manager(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_chunk_manager(2) diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index 16764aa6b0b1b1234eb62e74e9c5f2fae4994a22..a31c888e966d8692a96236a9034b2d792d7570fd 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -1,10 +1,10 @@ import pytest import torch import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group import colossalai from colossalai.tensor import ColoParameter -from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device from colossalai.zero.gemini import TensorState @@ -23,7 +23,7 @@ def add_param(param_list, param_cp_list, *args, **kwargs): param_cp_list.append(param.clone()) -def check_euqal(param, param_cp): +def check_equal(param, param_cp): if param.device != param_cp.device: temp = param.data.to(param_cp.device) else: @@ -31,43 +31,45 @@ def check_euqal(param, param_cp): return torch.equal(temp, param_cp.data) -@parameterize('init_device', [None, torch.device('cpu')]) -@parameterize('keep_gathered', [True, False]) -@parameterize('pin_memory', [True, False]) +@parameterize("init_device", [None, torch.device("cpu")]) +@parameterize("keep_gathered", [True, False]) +@parameterize("pin_memory", [True, False]) def exam_chunk_basic(init_device, keep_gathered, pin_memory): world_size = torch.distributed.get_world_size() - pg = ColoProcessGroup() - my_chunk = Chunk(chunk_size=1024, - process_group=pg, - dtype=torch.float32, - init_device=init_device, - cpu_shard_init=True, - keep_gathered=keep_gathered, - pin_memory=pin_memory) + pg = _get_default_group() + my_chunk = Chunk( + chunk_size=1024, + process_group=pg, + dtype=torch.float32, + init_device=init_device, + cpu_shard_init=True, + keep_gathered=keep_gathered, + pin_memory=pin_memory, + ) param_list = [] param_cp_list = [] - add_param(param_list, param_cp_list, 8, 8, 8, device='cuda') + add_param(param_list, param_cp_list, 8, 8, 8, device="cuda") add_param(param_list, param_cp_list, 4, 4) - add_param(param_list, param_cp_list, 4, 8, 2, device='cuda') + add_param(param_list, param_cp_list, 4, 8, 2, device="cuda") add_param(param_list, param_cp_list, 1, 1, 5) for param in param_list: my_chunk.append_tensor(param) assert my_chunk.utilized_size == 597 for param, param_cp in zip(param_list, param_cp_list): - check_euqal(param, param_cp) + check_equal(param, param_cp) my_chunk.close_chunk() if keep_gathered is False: assert my_chunk.cpu_shard.size(0) == 1024 // world_size - assert my_chunk.device_type == 'cpu' + assert my_chunk.device_type == "cpu" assert my_chunk.can_move my_chunk.shard_move(get_current_device()) else: assert my_chunk.cuda_global_chunk.size(0) == 1024 - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" assert not my_chunk.can_move assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size @@ -75,9 +77,9 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): assert not flag, "has_inf_or_nan is {}".format(flag) my_chunk.access_chunk() - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" for param, param_cp in zip(param_list, param_cp_list): - check_euqal(param, param_cp) + check_equal(param, param_cp) assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4 my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE) @@ -97,25 +99,25 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): if keep_gathered is False: assert my_chunk.cuda_shard.size(0) == 1024 // world_size - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" assert my_chunk.can_move else: assert my_chunk.cuda_global_chunk.size(0) == 1024 - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" assert not my_chunk.can_move def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_chunk_basic() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2, 4]) +@pytest.mark.parametrize("world_size", [1, 2, 4]) @rerun_if_address_is_in_use() def test_chunk_function(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_chunk_function(4) diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index f2cbb7fb77d600a8b7918897764afaa46a1f674d..94e70040019c231577c0cf62eb16c3e459c25786 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -1,39 +1,45 @@ import pytest import torch +import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai -from colossalai.amp import convert_to_apex_amp +from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager -from tests.components_to_test import run_fwd, run_fwd_bwd +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration +from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed +PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": "auto"}, +] -def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): + +def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): chunk_manager = model.chunk_manager param_list = [p for p in model.parameters()] chunk_list = chunk_manager.get_chunks(param_list) for chunk in chunk_list: chunk_manager.access_chunk(chunk) - for (p0, p1) in zip(model.parameters(), torch_model.parameters()): + for p0, p1 in zip(model.parameters(), torch_model.parameters()): assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5) -@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) -@parameterize('keep_gather', [False, True]) -@parameterize('model_name', ['gpt2', 'bert', 'albert']) -@parameterize('use_grad_checkpoint', [False, True]) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gather", [False, True]) +@parameterize("model_name", ["gpt2", "bert", "albert"]) +@parameterize("use_grad_checkpoint", [False, True]) def exam_gpt_fwd_bwd( - placement_policy, + placement_config, keep_gather, model_name: str, use_grad_checkpoint: bool = False, @@ -43,8 +49,7 @@ def exam_gpt_fwd_bwd( model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() set_seed(42) - with ColoInitContext(device=init_device): - model = model_builder(use_grad_checkpoint) + model = model_builder(use_grad_checkpoint) set_seed(42) torch_model = model_builder(use_grad_checkpoint).cuda() @@ -52,22 +57,20 @@ def exam_gpt_fwd_bwd( torch_p.data.copy_(p.data) world_size = torch.distributed.get_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gather - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gather + model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) + zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1) - pg = ProcessGroup() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + rank = dist.get_rank() + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) - torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) + torch_model = DDP(torch_model, device_ids=[rank]) - set_seed(pg.dp_local_rank()) + set_seed(rank) for i, (input_ids, label) in enumerate(train_dataloader): # you can only test a single fwd + bwd. # after bwd param is grad for Gemini, due to the chunk reuse optimization. @@ -89,73 +92,18 @@ def exam_gpt_fwd_bwd( check_grad(model, torch_model) -@parameterize('placement_policy', ['cuda', 'cpu']) -@parameterize('keep_gather', [False, True]) -@parameterize('model_name', ['gpt2', 'bert', 'albert']) -@parameterize('scatter_after_inference', [False, True]) -def exam_gpt_inference( - placement_policy, - keep_gather, - model_name: str, - scatter_after_inference: bool = False, -): - init_device = get_current_device() - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - set_seed(42) - with ColoInitContext(device=init_device): - model = model_builder() - - set_seed(42) - torch_model = model_builder().cuda() - for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p.data) - - world_size = torch.distributed.get_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gather - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True, scatter_after_inference=scatter_after_inference) - - pg = ProcessGroup() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) - torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) - torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) - torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) - - set_seed(pg.dp_local_rank()) - model.eval() - torch_model.eval() - for i, (input_ids, label) in enumerate(train_dataloader): - # you can only test a single fwd + bwd. - # after bwd param is grad for Gemini, due to the chunk reuse optimization. - if i > 0: - break - with torch.no_grad(): - input_ids, label = input_ids.cuda(), label.cuda() - - torch_loss = run_fwd(torch_model, input_ids, label, criterion) - loss = run_fwd(model, input_ids, label, criterion) - - assert torch.equal(torch_loss, loss) - - def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_gpt_fwd_bwd() - exam_gpt_inference() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_gpt(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_gpt(4) diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py index dd580976d8eafd36c83601fb397242cf52f60d09..2fa2d50a6caa055a3848eeaad02946eddb102183 100644 --- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py +++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py @@ -1,33 +1,31 @@ import pytest import torch +import torch.distributed as dist import colossalai -from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero import ColoInitContext, ZeroDDP -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.utils import set_seed +from colossalai.zero import GeminiDDP +from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed # run gemini use the runtime memory tracer -@parameterize('placement_policy', ['auto']) -@parameterize('keep_gather', [False]) -@parameterize('model_name', ['repeated_computed_layers', 'bert', 'albert', 'gpt2']) -@parameterize('use_grad_checkpoint', [False, True]) +@parameterize("placement_policy", ["auto"]) +@parameterize("keep_gather", [False]) +@parameterize("model_name", ["repeated_computed_layers", "bert", "albert", "gpt2"]) +@parameterize("use_grad_checkpoint", [False, True]) def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - with ColoInitContext(device='cpu'): - model = model_builder(use_grad_checkpoint) + model = model_builder(use_grad_checkpoint).cuda() - print(f'model_name {model_name}') + print(f"model_name {model_name}") runtime_mem_tracer = RuntimeMemTracer(model) for i, (input_ids, label) in enumerate(train_dataloader): if i > 0: @@ -39,32 +37,31 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) memstats = runtime_mem_tracer.memstats() runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list - print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data)) - print('runtime tracer: ', runtime_tracer_non_model_data) + print("runtime tracer non model data points: ", len(runtime_tracer_non_model_data)) + print("runtime tracer: ", runtime_tracer_non_model_data) print([memstats.param_used_step(p) for p in model.parameters()]) - if model_name == 'repeated_computed_layers': + if model_name == "repeated_computed_layers": for idx, p in enumerate(model.parameters()): step_list = memstats.param_used_step(p) if idx < 4: assert len(step_list) == 4 - if model_name == 'repeated_computed_layers': + if model_name == "repeated_computed_layers": for idx, p in enumerate(model.parameters()): step_list = memstats.param_used_step(p) if idx < 4: assert len(step_list) == 4 world_size = torch.distributed.get_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gather - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - - pg = ProcessGroup() - set_seed(pg.dp_local_rank()) + config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gather + model = GeminiDDP( + model, chunk_config_dict=config_dict, placement_policy=placement_policy, pin_memory=True, memstats=memstats + ) + + set_seed(dist.get_rank()) for i, (input_ids, label) in enumerate(train_dataloader): # you can only test a single fwd + bwd. # after bwd param is grad for Gemini, due to the chunk reuse optimization. @@ -74,28 +71,30 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ input_ids, label = input_ids.cuda(), label.cuda() set_seed(42) - loss = run_fwd_bwd(model, input_ids, label, criterion, model) + run_fwd_bwd(model, input_ids, label, criterion, model) - gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') + gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list("cuda") # print('gemini non model data:', gemini_non_model_data) - assert len(gemini_non_model_data) == len(runtime_tracer_non_model_data), \ - f'model_name {model_name} {len(gemini_non_model_data)} vs {len(runtime_tracer_non_model_data)}' + assert len(gemini_non_model_data) == len( + runtime_tracer_non_model_data + ), f"model_name {model_name} {len(gemini_non_model_data)} vs {len(runtime_tracer_non_model_data)}" def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_gemini_use_rmt() +@pytest.mark.skip("this is not used") @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_gemini_use_rmt(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_gemini_use_rmt(1) diff --git a/tests/test_zero/test_gemini/test_get_torch_model.py b/tests/test_zero/test_gemini/test_get_torch_model.py deleted file mode 100644 index b3e3b2b22fc3709b3cf928384b46f7b2928ef4ff..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_gemini/test_get_torch_model.py +++ /dev/null @@ -1,52 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.tensor import ColoParameter -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext, GeminiDDP -from colossalai.zero.gemini.utils import get_static_torch_model -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize('model_name', ['hanging_param_model', 'resnet18', 'gpt2']) -def run_convert_torch_module(model_name: str): - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, _, _, _, _ = get_components_func() - - with ColoInitContext(device=torch.device("cpu")): - model = model_builder(checkpoint=False) - model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True) - pytorch_model = get_static_torch_model(model, only_rank_0=False) - - for n, p in pytorch_model.named_parameters(): - assert type(p) == torch.nn.Parameter, f"type error: {n} is a {type(p)}" - - # get the static model should not change the original model - for n, p in model.named_parameters(): - assert isinstance(p, ColoParameter) - - for (pn, pm), (cn, cm) in zip(pytorch_model.named_modules(), model.named_modules()): - assert pn == cn - assert id(pm) != id(cm) - for pp, cp in zip(pm.parameters(recurse=False), cm.parameters(recurse=False)): - assert id(pp) != id(cp) - assert pp.shape == cp.shape - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_convert_torch_module() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_convert_torch_module(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_convert_torch_module(2) diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index 38b6e474ea986a74f451b30ddb4b45c94f6b12dc..d8bcc555a15da9fb4c5964124c1f758db660eb9a 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -5,19 +5,39 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai -from colossalai.amp import convert_to_apex_amp +from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.utils import set_seed +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed - -def check_param(model: ZeroDDP, torch_model: torch.nn.Module): +PLACEMENT_CONFIGS = [ + { + "placement_policy": "static", + "shard_param_frac": 0.0, + "offload_optim_frac": 0.0, + "offload_param_frac": 0.0, + }, # zero2 + { + "placement_policy": "static", + "shard_param_frac": 0.0, + "offload_optim_frac": 1.0, + "offload_param_frac": 0.0, + }, # zero2-offload + { + "placement_policy": "static", + "shard_param_frac": 0.0, + "offload_optim_frac": 0.5, + "offload_param_frac": 0.0, + }, # zero2-offload-half + {"placement_policy": "auto"}, +] + + +def check_param(model: GeminiDDP, torch_model: torch.nn.Module): zero_dict = model.state_dict(only_rank_0=False) torch_dict = torch_model.state_dict() @@ -30,40 +50,39 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) -@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) -@parameterize('model_name', ['gpt2']) -def exam_grad_clipping(placement_policy, model_name: str): +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", ["gpt2"]) +def exam_grad_clipping(placement_config, model_name: str): set_seed(1912) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=32) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - init_dev = get_current_device() - with ColoInitContext(device=init_dev): - model = model_builder() + model = model_builder() for torch_p, p in zip(torch_model.parameters(), model.parameters()): p.data.copy_(torch_p.data) world_size = torch.distributed.get_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False - if placement_policy != 'cuda': - init_device = torch.device('cpu') + config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = False + if placement_config["placement_policy"] != "cuda": + init_device = torch.device("cpu") else: init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + + model = GeminiDDP( + model, chunk_config_dict=config_dict, chunk_init_device=init_device, pin_memory=True, **placement_config + ) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0) + zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0) model.train() torch_model.train() @@ -83,6 +102,7 @@ def exam_grad_clipping(placement_policy, model_name: str): assert_close(torch_loss, loss) import apex.amp as apex_amp + torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0) torch_optim.step() zero_optim.step() @@ -92,16 +112,16 @@ def exam_grad_clipping(placement_policy, model_name: str): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_grad_clipping() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_grad_clip(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_grad_clip(2) diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 790a0611c9ddab198b18bf6c5276521efabb399e..2b2b246a9f541ed78e793d3152216807ff1111cb 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -7,19 +7,25 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai -from colossalai.amp import convert_to_apex_amp +from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper -from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import debug_print, set_seed +PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": "auto"}, +] -def check_param(model: ZeroDDP, torch_model: torch.nn.Module): + +def check_param(model: GeminiDDP, torch_model: torch.nn.Module): zero_dict = model.state_dict(only_rank_0=False) torch_dict = torch_model.state_dict() @@ -32,55 +38,42 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) -def multi_chunk_init(model: torch.nn.Module, placement_policy: str): +def multi_chunk_init(model: torch.nn.Module, placement_config: dict): world_size = dist.get_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False - if placement_policy != 'cuda': - init_device = torch.device('cpu') - else: - init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = False + model = GeminiDDP(model, config_dict, pin_memory=True, **placement_config) return model -def single_chunk_init(model: torch.nn.Module, placement_policy: str): - gemini_config = dict( - device=get_current_device(), - placement_policy=placement_policy, - pin_memory=True, - ) - model = zero_model_wrapper(model=model, zero_stage=3, gemini_config=gemini_config) +def single_chunk_init(model: torch.nn.Module, placement_config: dict): + model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config) return model -@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) -@parameterize('model_name', ['gpt2']) -@parameterize('model_init_func', [single_chunk_init, multi_chunk_init]) -def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable): +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", ["gpt2"]) +@parameterize("model_init_func", [single_chunk_init, multi_chunk_init]) +def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable): set_seed(19360226) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - init_dev = get_current_device() - with ColoInitContext(device=init_dev): - model = model_builder() + model = model_builder().to(init_dev) for torch_p, p in zip(torch_model.parameters(), model.parameters()): p.data.copy_(torch_p.data) - model = model_init_func(model, placement_policy) + model = model_init_func(model, placement_config) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) + zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128) model.eval() torch_model.eval() @@ -95,7 +88,7 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call torch_optim.zero_grad() torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss) + assert_close(torch_loss, loss, rtol=1e-5, atol=1e-5) zero_optim.step() torch_optim.step() check_param(model, torch_model) @@ -117,16 +110,16 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_inference() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_inference(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_inference(1) diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 8ce20c16e8f965824c7a3c63e4be6a5ed8f851a7..b7c08392600ff4d5276a03439501ffbd9501b9e5 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -5,75 +5,101 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai -from colossalai.amp import convert_to_apex_amp +from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx -from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import debug_print, set_seed + +PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0}, # zero2-offload + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5}, # zero2-offload-half + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + { + "placement_policy": "static", + "shard_param_frac": 1.0, + "offload_optim_frac": 1.0, + "offload_param_frac": 1.0, + }, # zero3-offload-all + {"placement_policy": "auto"}, +] # this model is large enough to slice to chunks -TEST_MODELS = ['gpt2'] +TEST_MODELS = ["gpt2"] # these models are too small, all parameters in these models are compacted into one chunk -EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers'] +EXAMPLE_MODELS = ["albert", "beit", "bert", "hanging_param_model", "nested_model", "repeated_computed_layers"] + +# bfloat16 cannot represent them exactly +BF16_IGNORED_KEYS = [ + "albert.embeddings.word_embeddings.weight", + "albert.embeddings.position_embeddings.weight", + "masked_bias", +] -def check_param(model: ZeroDDP, torch_model: torch.nn.Module): - zero_dict = model.state_dict(only_rank_0=False) +def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype): + zero_dict = model.state_dict(only_rank_0=False, dtype=dtype) torch_dict = torch_model.state_dict() for key, value in torch_dict.items(): # key is 'module.model.PARAMETER', so we truncate it key = key[7:] assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) - temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + temp_zero_value = zero_dict[key].to(device=value.device) + if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS): + continue + rtol, atol = 1e-3, 4e-3 + if dtype is torch.bfloat16: + rtol, atol = 4e-3, 8e-3 # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) - assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) - - -@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) -@parameterize('model_name', TEST_MODELS) -def exam_model_step(placement_policy, model_name: str): + assert_close( + value.float(), + temp_zero_value.float(), + rtol=rtol, + atol=atol, + msg=lambda s: s + f"\n{key}\n{temp_zero_value.dtype}", + ) + + +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", TEST_MODELS) +@parameterize("mixed_precision", [torch.half, torch.bfloat16]) +def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - init_dev = get_current_device() - with ColoInitContext(device=init_dev): - model = model_builder() + model = model_builder().cuda() for torch_p, p in zip(torch_model.parameters(), model.parameters()): p.data.copy_(torch_p.data) world_size = torch.distributed.get_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False - if placement_policy != 'cuda': - init_device = torch.device('cpu') - else: - init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = False + model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) + zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128) model.eval() torch_model.eval() set_seed(dist.get_rank() * 3 + 128) + rtol, atol = 1e-4, 1e-5 for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break @@ -83,44 +109,51 @@ def exam_model_step(placement_policy, model_name: str): torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss) + assert_close(torch_loss, loss, rtol=rtol, atol=atol) zero_optim.step() torch_optim.step() - check_param(model, torch_model) + check_param(model, torch_model, mixed_precision) -@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) -@parameterize('model_name', EXAMPLE_MODELS) -def exam_tiny_example(placement_policy, model_name: str): +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", EXAMPLE_MODELS) +@parameterize("mixed_precision", [torch.half, torch.bfloat16]) +def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype): set_seed(2008) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=2) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=2) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - init_dev = get_current_device() - with ColoInitContext(device=init_dev): - model = model_builder() + model = model_builder().cuda() for torch_p, p in zip(torch_model.parameters(), model.parameters()): p.data.copy_(torch_p.data) - chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = GeminiDDP( + model, + chunk_init_device=get_current_device(), + search_range_m=1, + pin_memory=True, + mixed_precision=mixed_precision, + **placement_config, + ) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) + zero_optim = GeminiOptimizer(optimizer, model, initial_scale=2) model.eval() torch_model.eval() set_seed(dist.get_rank() * 3 + 128) + rtol, atol = 1.5e-6, 2e-5 + if mixed_precision is torch.bfloat16: + rtol, atol = 2e-3, 2e-3 for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break @@ -133,27 +166,27 @@ def exam_tiny_example(placement_policy, model_name: str): torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=1.5e-6, atol=2e-5) # atol should be 2e-5 for torch lower than 1.12 + assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12 zero_optim.step() torch_optim.step() - check_param(model, torch_model) + check_param(model, torch_model, mixed_precision) def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_model_step() exam_tiny_example() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_optim(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_optim(1) diff --git a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py index 0e6f283aa5d23ed611d75f0209e3ef4298adb0fe..8e0f6ae36c46e2b64b329c9188e476f73594a47a 100644 --- a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py @@ -1,25 +1,25 @@ from copy import deepcopy import numpy as np +import pytest import torch from colossalai.testing import clear_cache_before_run -from colossalai.zero import ColoInitContext from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs +@pytest.mark.skip("this is not used") @clear_cache_before_run() def test_runtime_mem_tracer(): - test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert'] + test_models = ["gpt2", "bert", "simple_net", "repeated_computed_layers", "nested_model", "albert"] for model_name in test_models: get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, _, _, criterion = get_components_func() - with ColoInitContext(device='cpu'): - model = model_builder(checkpoint=False) + model = model_builder(checkpoint=False).cuda() model_bk = deepcopy(model) runtime_mem_tracer = RuntimeMemTracer(model) @@ -35,7 +35,7 @@ def test_runtime_mem_tracer(): for p1, p2 in zip(model_bk.parameters(), model.parameters()): torch.allclose(p1.to(torch.half), p2) - non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list('cuda') + non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list("cuda") cuda_non_model_data_list = np.array(non_model_data_list) / 1024**2 print("cuda_non_model_data_list", len(cuda_non_model_data_list)) print(non_model_data_list) @@ -46,9 +46,9 @@ def test_runtime_mem_tracer(): cnt2 = 0 for p in model.parameters(): cnt2 += 1 - assert cnt2 == cnt1, f'visited param number {cnt1} vs real param number {cnt2}' + assert cnt2 == cnt1, f"visited param number {cnt1} vs real param number {cnt2}" del model -if __name__ == '__main__': +if __name__ == "__main__": test_runtime_mem_tracer() diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index 35b3b93ade0c223100c660c7612a97e8591bd6ff..e22e5ece42a54d2b282b4645c7ba05bb6bbdd7de 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -2,117 +2,65 @@ import pytest import torch import colossalai -from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs -def init_1d_row_spec(model, pg: ProcessGroup): - tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - for n, p in model.named_parameters(): - if 'weight' in n and 'ln' not in n: - p.set_process_group(pg) - p.set_tensor_spec(*tensor_spec) - - def exam_search_chunk_size(): world_size = torch.distributed.get_world_size() - pg_tp = ProcessGroup(tp_degree=world_size) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() # make sure torch_model and model has the same parameter values - with ColoInitContext(device=get_current_device()): - model = model_builder() - init_1d_row_spec(model, pg_tp) - config_dict, *_ = search_chunk_configuration(model, - search_range_mb=1, - search_interval_byte=16, - min_chunk_size_mb=0, - filter_exlarge_params=True) + model = model_builder() + config_dict, *_ = search_chunk_configuration( + model, search_range_m=1, search_interval=16, min_chunk_size_m=0, filter_exlarge_params=True + ) for key in config_dict: - chunk_size = config_dict[key]['chunk_size'] - if world_size == 1: + chunk_size = config_dict[key]["chunk_size"] + if world_size == 1 or True: assert chunk_size == 31616 else: assert chunk_size == 1024 -def exam_search_strict_ddp(): - world_size = torch.distributed.get_world_size() - default_shard_pg = ProcessGroup(tp_degree=world_size) - default_shard_spec = ShardSpec([-1], [world_size]) - - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - # get the chunk configuration over replicated models - with ColoInitContext(device=get_current_device()): - ddp_model = model_builder() - re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model, - search_range_mb=1, - search_interval_byte=16, - min_chunk_size_mb=0, - filter_exlarge_params=True, - strict_ddp_flag=False) - # get the chunk configuration over sharded ddp models - with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg, - default_dist_spec=default_shard_spec): - sharded_ddp_model = model_builder() - sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model, - search_range_mb=1, - search_interval_byte=16, - min_chunk_size_mb=0, - filter_exlarge_params=True, - strict_ddp_flag=True) - assert re_dict == sh_dict - for key in re_dict: - assert re_dict[key] == sh_dict[key] - - assert re_total == sh_total - assert re_wasted == sh_wasted - - def exam_chunk_manager(): world_size = torch.distributed.get_world_size() - default_shard_pg = ProcessGroup(tp_degree=world_size) - default_shard_spec = ShardSpec([-1], [world_size]) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg, - default_dist_spec=default_shard_spec): - sharded_ddp_model = model_builder() - chunk_manager = init_chunk_manager(sharded_ddp_model, - get_current_device(), - hidden_dim=16, - search_range_mb=1, - min_chunk_size_mb=0, - filter_exlarge_params=True, - strict_ddp_flag=True) + sharded_ddp_model = model_builder() + chunk_manager = init_chunk_manager( + sharded_ddp_model, + get_current_device(), + hidden_dim=16, + search_range_m=1, + min_chunk_size_m=0, + filter_exlarge_params=True, + strict_ddp_flag=True, + ) config_dict = chunk_manager.dp_degree_chunk_size_dict assert len(config_dict) == 1 assert config_dict[world_size] == 31616 def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_search_chunk_size() - exam_search_strict_ddp() exam_chunk_manager() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_search(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_search(4) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 66e05f3ed1ecb7b6a0d58ff3c29dad6c6a477ef8..3130440bd92573466dfd170995ee1d93848d6a76 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -4,43 +4,45 @@ from torch.testing import assert_close import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext, ZeroDDP -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.utils import set_seed +from colossalai.zero import GeminiDDP +from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import debug_print, set_seed + +PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": "auto"}, +] def ignore_the_first_parameter(model: torch.nn.Module): for name, param in model.named_parameters(): print(f"parameter `{name}` is set ignored") - ZeroDDP.set_params_to_ignore([param]) + GeminiDDP.set_params_to_ignore([param]) return -@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) -@parameterize('keep_gathered', [True, False]) -@parameterize('model_name', ['gpt2', 'bert']) -def exam_state_dict(placement_policy, keep_gathered, model_name: str): +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [True, False]) +@parameterize("model_name", ["gpt2", "bert"]) +def exam_state_dict(placement_config, keep_gathered, model_name: str): set_seed(431) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - with ColoInitContext(device=get_current_device()): - model = model_builder() + model = model_builder() torch_model = model_builder() for torch_p, p in zip(torch_model.parameters(), model.parameters()): torch_p.data.copy_(p.data) world_size = torch.distributed.get_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) model.train() zero_dict = model.state_dict(only_rank_0=False) @@ -52,32 +54,25 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str): assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) -@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) -@parameterize('keep_gathered', [True, False]) -@parameterize('model_name', ['gpt2', 'bert']) -def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [True, False]) +@parameterize("model_name", ["gpt2", "bert"]) +def exam_load_state_dict(placement_config, keep_gathered, model_name: str): set_seed(431) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - with ColoInitContext(device=get_current_device()): - model = model_builder() + model = model_builder() set_seed(451) - torch_model = model_builder() # get a different model + torch_model = model_builder() # get a different model world_size = torch.distributed.get_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered - - if placement_policy != 'cuda': - init_device = torch.device('cpu') - else: - init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered + + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) torch_dict = torch_model.state_dict() model.load_state_dict(torch_dict, strict=False) @@ -89,19 +84,45 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", ["gpt2", "bert"]) +def exam_state_dict_shard(placement_config, model_name: str): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + model = model_builder() + + model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 + + config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) + model = GeminiDDP(model, config_dict, **placement_config) + model.train() + + zero_dict = model.state_dict(only_rank_0=False) + accumulated_keys = set() + # ensure number of shards > 1 + for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): + for key, value in shard.items(): + assert key not in accumulated_keys, f"key `{key}` is duplicated." + accumulated_keys.add(key) + assert key in zero_dict, f"{key} not in ZeRO dictionary." + assert torch.equal(value, zero_dict[key]), f"{key} not equal." + + def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() exam_load_state_dict() + exam_state_dict_shard() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_zero_ddp(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_ddp(1) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py deleted file mode 100644 index 96c26a1de4df58111c943cd3e11d2a5ce673ab85..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest -import torch -from torch.testing import assert_close - -import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext, ZeroDDP -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize('placement_policy', ['cuda', 'cpu']) -@parameterize('model_name', ['gpt2', 'bert']) -def exam_state_dict(placement_policy, model_name: str): - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - - model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 - - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager) - model.train() - - zero_dict = model.state_dict(only_rank_0=False) - accumulated_keys = set() - # ensure number of shards > 1 - for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): - for key, value in shard.items(): - assert key not in accumulated_keys, f"key `{key}` is duplicated." - accumulated_keys.add(key) - assert key in zero_dict, f"{key} not in ZeRO dictionary." - assert torch.equal(value, zero_dict[key]), f"{key} not equal." - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_state_dict() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_zero_ddp_state_dict_shard(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_ddp_state_dict_shard(1) diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index a8af176c5b3dc9b1880fd6ce5afc8584114b3b50..8aa656b74cf9e64459a5d82b700749c725cbf9f1 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -5,42 +5,39 @@ import torch.distributed as dist import colossalai from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.utils import set_seed +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import debug_print, set_seed +PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0}, # zero2-offload + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5}, # zero2-offload-half + {"placement_policy": "auto"}, +] -@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) -@parameterize('keep_gathered', [True, False]) -def exam_zero_optim_state_dict(placement_policy, keep_gathered): + +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [True, False]) +def exam_zero_optim_state_dict(placement_config, keep_gathered): set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - with ColoInitContext(device=get_current_device()): - model = model_builder() + model = model_builder() set_seed(451) - torch_model = model_builder() # get a different model world_size = torch.distributed.get_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered - - if placement_policy != 'cuda': - init_device = torch.device('cpu') - else: - init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered + + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) optimizer = HybridAdam(model.parameters()) - optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 + optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 set_seed(dist.get_rank() * 3 + 128) model.train() @@ -56,8 +53,8 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered): optim_state_dict = optim.state_dict() optim.load_state_dict(optim_state_dict) - new_state = optim.state_dict()['state'] - org_state = optim_state_dict['state'] + new_state = optim.state_dict()["state"] + org_state = optim_state_dict["state"] for k, v in org_state.items(): w = new_state[k] @@ -71,16 +68,16 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_zero_optim_state_dict() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_zero_optim(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_optim(1) diff --git a/tests/test_zero/test_legacy/common.py b/tests/test_zero/test_legacy/common.py deleted file mode 100644 index 2c3d122c79af9c959d91c4c4dcf0e5f984656ed1..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/common.py +++ /dev/null @@ -1,140 +0,0 @@ -from functools import partial - -import torch -import torch.distributed as dist - -from colossalai.logging import get_dist_logger -from colossalai.utils import checkpoint -from colossalai.zero.legacy.shard_utils import TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 - -LOGGER = get_dist_logger('zero_test') - -MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(size=1), tensor=dict(size=2, mode=None))) - -_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25, - fp32_reduce_scatter=False, - tensor_placement_policy='cuda', - gradient_predivide_factor=1.0, - shard_strategy=TensorShardStrategy(), - reuse_fp16_shard=False) - -_ZERO_OPTIMIZER_CONFIG = dict(initial_scale=2**5, - min_scale=1, - growth_factor=2, - backoff_factor=0.5, - growth_interval=1000, - hysteresis=2, - max_scale=2**32) - -ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), - zero=dict( - model_config=_ZERO_MODEL_CONFIG, - optimizer_config=_ZERO_OPTIMIZER_CONFIG, - ), - parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) - -CONFIG = dict(fp16=dict(mode=None,), - zero=dict(level=3, - verbose=False, - offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False), - offload_param_config=dict(device='cpu', - pin_memory=True, - buffer_count=5, - buffer_size=1e8, - max_in_cpu=1e9)), - parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) - - -def run_fwd_bwd(model, data, label, criterion, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - if isinstance(model, ShardedModelV2): - model.backward(loss) - else: - loss.backward() - - -def checkpoint_wrapper(module, enable=True): - if enable: - module.forward = partial(checkpoint, module.forward) - return module - - -def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: - if loose: - return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3) - return torch.allclose(tensor_a, tensor_b) - - -def check_grads(model, zero_model, loose=False): - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_grad = zero_p.grad.clone().to(p.device) - grad = p.grad.float() - assert grad.dtype == zero_grad.dtype - assert allclose(grad, zero_grad, loose=loose) - - -def check_params(model, zero_model, loose=False): - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_p = zero_p.clone().to(p.device) - # assert p.dtype == zero_p.dtype - assert allclose(p.float(), zero_p.float(), loose=loose), f"diff {p.float() - zero_p.float()}" - - -def check_grads_padding(model, zero_model, loose=False): - rank = dist.get_rank() - for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): - # zero_grad = zero_p.grad.clone().to(p.device) - if zero_p.colo_attr.is_replicated: - zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device) - chunks = torch.flatten(p.grad).chunk(dist.get_world_size()) - if rank >= len(chunks): - continue - grad = chunks[rank].float() - if zero_grad.size(0) > grad.size(0): - zero_grad = zero_grad[:grad.size(0)] - else: - zero_grad = zero_p.colo_attr.grad_payload - grad = p.grad.to(zero_grad.dtype) - - assert grad.dtype == zero_grad.dtype - assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}' - - -def check_params_padding(model, zero_model, loose=False): - rank = dist.get_rank() - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_p = zero_p.clone().to(p.device) - chunks = torch.flatten(p).chunk(dist.get_world_size()) - if rank >= len(chunks): - continue - p = chunks[rank] - if zero_p.size(0) > p.size(0): - zero_p = zero_p[:p.size(0)] - assert p.dtype == zero_p.dtype - assert allclose(p, zero_p, loose=loose) - - -def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False): - rank = dist.get_rank() - for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): - if zero_p.colo_attr.param_is_sharded: - zero_p = zero_p.colo_attr.data_payload.to(p.device).float() - chunks = torch.flatten(p).chunk(dist.get_world_size()) - if rank >= len(chunks): - continue - p = chunks[rank].float() - if zero_p.size(0) > p.size(0): - zero_p = zero_p[:p.size(0)] - else: - zero_p = zero_p.colo_attr.data_payload.to(p.device) - - assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype) - assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}' diff --git a/tests/test_zero/test_legacy/test_found_inf.py b/tests/test_zero/test_legacy/test_found_inf.py deleted file mode 100644 index e90158e0a43b65ae756a44d45e5203d82a5cc292..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_found_inf.py +++ /dev/null @@ -1,67 +0,0 @@ -import pytest -import torch -from common import CONFIG -from test_sharded_optim_v2 import _run_step - -import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("cpu_offload", [True, False]) -@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) -@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) -def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): - test_models = ['repeated_computed_layers'] - shard_strategy = shard_strategy_class() - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', - reuse_fp16_shard=True, - ) - - sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) - - for i, (data, label) in enumerate(train_dataloader): - if i > 1: - break - assert zero_model.overflow_counter == 0 - data, label = data.cuda(), label.cuda() - _run_step(zero_model, sharded_optim, data, label, criterion, False) - for param in zero_model.parameters(): - assert not has_inf_or_nan(param.colo_attr.data_payload) - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_test_found_inf() - - -# use_cpuadam = True can be used with cpu_offload = False -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_found_inf(world_size): - spawn(_run_dist, world_size) - - -if __name__ == '__main__': - test_found_inf(world_size=2) diff --git a/tests/test_zero/test_legacy/test_gemini_manager.py b/tests/test_zero/test_legacy/test_gemini_manager.py deleted file mode 100644 index 0e956f7cc6178790492abfb91bcdcb2eb5946823..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_gemini_manager.py +++ /dev/null @@ -1,75 +0,0 @@ -import pytest -import torch - -from colossalai.testing import clear_cache_before_run -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState - - -@pytest.mark.dist -@clear_cache_before_run() -def test_gemini_manager(): - # reset the manager, in case that there exists memory information left - manager = StatefulTensor.GST_MGR - manager.reset() - - # occupation 8 - st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda')) - # occupation 60 - st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) - - # occupation 28 - t1 = torch.empty(7, device='cuda') - # occupation 12 - t2 = torch.empty(3, device='cpu') - st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) - st4 = StatefulTensor(None, TensorState.FREE) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 60 - assert manager.total_mem['cuda'] == 36 - assert manager.state_mem['cpu'][TensorState.HOLD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 - - st4.payload_reset(t2) - st3.payload_reset(t2) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 84 - assert manager.total_mem['cuda'] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD] == 72 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 - - st1.move_to(torch.device('cpu')) - st2.move_to(torch.device('cpu')) - st3.move_to(torch.device('cuda', 0)) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 80 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - - st1.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.HOLD_AFTER_BWD) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 - assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 - assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 - - -if __name__ == '__main__': - test_gemini_manager() diff --git a/tests/test_zero/test_legacy/test_init_context.py b/tests/test_zero/test_legacy/test_init_context.py deleted file mode 100644 index 84493827193eef6b47db051fe1fe55c1a10098b1..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_init_context.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -from common import CONFIG - -import colossalai -from colossalai.logging import get_dist_logger -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.utils.memory import colo_device_memory_used -from colossalai.zero.gemini.memory_tracer.utils import colo_model_mem_usage -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("init_device_type", ['cpu', 'cuda']) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_model_test(init_device_type, shard_strategy_class): - logger = get_dist_logger("test_zero_init") - - for name, get_components_func in non_distributed_component_funcs._registry.items(): - # because the ZeroInitContext automatically turns parameters to fp16 - # and the beit model use tensor.erfinv_() function to initialize weights - # tensor.erfinv_() doesn't support Half in CPU, we omit the beit model - if name == 'beit': - continue - model_builder, _, _, _, _ = get_components_func() - if init_device_type == 'cuda': - init_device = get_current_device() - elif init_device_type == 'cpu': - init_device = torch.device("cpu") - else: - continue - - model_numel_tensor = torch.zeros(1, dtype=torch.int) - with ZeroInitContext(target_device=init_device, - shard_strategy=shard_strategy_class(), - shard_param=True, - model_numel_tensor=model_numel_tensor): - model = model_builder(checkpoint=True) - - for param in model.parameters(): - assert hasattr(param, 'colo_attr') - assert param.colo_attr.sharded_data_tensor.dtype == torch.half - assert param.colo_attr.sharded_data_tensor.is_sharded - assert param.colo_attr.data_payload.device.type == init_device.type, \ - f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' - - cuda_mem_use, _ = colo_model_mem_usage(model) - model_data_cuda_mem_MB = cuda_mem_use / 1e6 - logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0]) - sys_cuda_mem_MB = colo_device_memory_used(get_current_device()) / 1e6 - logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0]) - logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0]) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_model_test() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 4]) -@rerun_if_address_is_in_use() -def test_zero_init_context(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_init_context(1) diff --git a/tests/test_zero/test_legacy/test_param_op.py b/tests/test_zero/test_legacy/test_param_op.py deleted file mode 100644 index b91371b98922d25c69af15c6ba645966a1a2ac88..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_param_op.py +++ /dev/null @@ -1,82 +0,0 @@ -import copy - -import torch - -from colossalai.testing import clear_cache_before_run -from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr -from tests.components_to_test.registry import non_distributed_component_funcs - - -def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: - if loose: - return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3) - return torch.allclose(tensor_a, tensor_b) - - -def run_model(model, inputs, label, criterion, use_param_hook=False): - if use_param_hook: - - class HooKWrapper: - - def __init__(self) -> None: - self.hook_triggered_times = 0 - - def wrapper_func(self): - - def hook(param, grad) -> torch.Tensor or None: - self.hook_triggered_times += 1 - return grad - - return hook - - hookwrapper = HooKWrapper() - param_list = [p for p in model.parameters()] - hook_mgr = BaseParamHookMgr(param_list) - hook_mgr.register_backward_hooks(hookwrapper.wrapper_func()) - - model.zero_grad(set_to_none=True) - - with torch.cuda.amp.autocast(): - if criterion: - y = model(inputs) - loss = criterion(y, label) - else: - loss = model(inputs, label) - loss = loss.float() - loss.backward() - - if use_param_hook: - hook_mgr.remove_hooks() - return hookwrapper.hook_triggered_times - - -@clear_cache_before_run() -def test_base_param_hook(): - test_models = ['repeated_computed_layers', 'resnet18', 'hanging_param_model', 'inline_op_model'] - # test_models = ['bert'] - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, _, criterion = get_components_func() - - torch.manual_seed(0) - model = model_builder(checkpoint=True).cuda() - model.train() - - for i, (inputs, label) in enumerate(train_dataloader): - if i > 0: - break - model_copy = copy.deepcopy(model) - - run_model(model, inputs.cuda(), label.cuda(), criterion, False) - ret2 = run_model(model_copy, inputs.cuda(), label.cuda(), criterion, True) - - # Make sure param hook has only be fired once in case of parameter sharing - assert ret2 == len(list(model.parameters())) - - for p, p_copy in zip(model.parameters(), model_copy.parameters()): - assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}" - - -if __name__ == '__main__': - test_base_param_hook() diff --git a/tests/test_zero/test_legacy/test_shard_model_v2.py b/tests/test_zero/test_legacy/test_shard_model_v2.py deleted file mode 100644 index 93d624aa2bbd66c17a2636370687358c3c4fe25e..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_shard_model_v2.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -from common import CONFIG, check_grads_padding, run_fwd_bwd -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("enable_autocast", [True]) -@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) -def run_model_test(enable_autocast, shard_strategy_class): - test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] - shard_strategy = shard_strategy_class() - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, _, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2(zero_model, shard_strategy) - - model = model_builder(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda() - - model = DDP(model, device_ids=[torch.cuda.current_device()]) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - - data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() - run_fwd_bwd(model, data, label, criterion, enable_autocast) - run_fwd_bwd(zero_model, data, label, criterion, enable_autocast) - - check_grads_padding(model, zero_model, loose=True) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_model_test() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_shard_model_v2(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_shard_model_v2(world_size=2) diff --git a/tests/test_zero/test_legacy/test_shard_param.py b/tests/test_zero/test_legacy/test_shard_param.py deleted file mode 100644 index 4ba43edceb5d1cd9ad2413ac2c0d771084d6056d..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_shard_param.py +++ /dev/null @@ -1,91 +0,0 @@ -from copy import deepcopy - -import pytest -import torch -from common import CONFIG, allclose - -import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_param import ShardedTensor -from colossalai.zero.legacy.sharded_param.sharded_param import ShardedParamV2 - - -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_shard_tensor_with_strategy(shard_strategy_class, world_size): - t = ShardedTensor(tensor=torch.randn(world_size * 2, 3)) - assert list(t.origin_shape) == [world_size * 2, 3] - assert list(t.shape) == [world_size * 2, 3] - - shard_strategy = shard_strategy_class() - - # test shard strategy - shard_strategy.shard([t]) - assert list(t.shape) == [6], f"{list(t.shape)} vs 6" - shard_strategy.gather([t]) - assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}" - - -def _run_shard_tensor(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_shard_tensor_with_strategy(world_size=world_size) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_shard_tensor(world_size): - spawn(_run_shard_tensor, world_size) - - -def _run_shard_param_v2(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - param = torch.nn.Parameter(torch.randn(2, 3)) - param_ref = deepcopy(param) - sparam = ShardedParamV2(param=param) - - allclose(sparam.data_payload, param_ref.data) - - # Test get memory usage - sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}" - - sparam.set_data_none() - assert (param.data.numel() == 0) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - # 4 is size of dummy tensor of param.data - assert cpu_mem_use == 2 * 3 * 4 * 2 - - sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) - sparam.set_data_none() - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2 - assert cuda_mem_use == 0 - - # append a grad to torch param - param.data = sparam.data_payload - param.grad = torch.randn(2, 3) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}" - assert cuda_mem_use == 0 - - # reuse torch grad for sparam - sparam.saved_grad = StatefulTensor(param.grad) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2 - assert cuda_mem_use == 0 - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_shard_param_v2(world_size): - spawn(_run_shard_param_v2, world_size) - - -if __name__ == '__main__': - # test_shard_tensor(2) - test_shard_param_v2(2) diff --git a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py b/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py deleted file mode 100644 index 1ca144662722df86800431547d9855ef7f52e4b4..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py +++ /dev/null @@ -1,89 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed - - -def init_zero(model_builder, placement_policy): - device = get_current_device() if placement_policy == 'cuda' else torch.device('cpu') - shard_strategy = TensorShardStrategy() - with ZeroInitContext(target_device=device, shard_strategy=shard_strategy, shard_param=True): - model = model_builder() - model = ShardedModelV2( - model, - shard_strategy, - tensor_placement_policy=placement_policy, - reuse_fp16_shard=True, - ) - optim = HybridAdam(model.parameters(), lr=1e-3) - optim = ShardedOptimizerV2(model, optim, initial_scale=32) - return model, optim - - -def run_step(model, optim, criterion, data, label): - optim.zero_grad() - logits = model(data) - loss = criterion(logits, label) - optim.backward(loss) - optim.step() - - -def check_state_dict_eq(state_dict, other): - for p, state in state_dict['state'].items(): - other_state = other['state'][p] - for k, v in state.items(): - if isinstance(v, torch.Tensor): - assert torch.allclose(v, other_state[k], atol=1e-3), f'{v} vs {other_state[k]}' - else: - assert v == other_state[k] - - -@parameterize('placement_policy', ['cuda', 'cpu']) -def run_nested_model(placement_policy): - get_components_func = non_distributed_component_funcs.get_callable('simple_net') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - set_seed(42) - model, optim = init_zero(model_builder, placement_policy) - set_seed(42) - model_copy, optim_copy = init_zero(model_builder, placement_policy) - - model.train() - model_copy.train() - pg = ProcessGroup() - set_seed(pg.dp_local_rank()) - data_iter = iter(train_dataloader) - - data, label = map(lambda x: x.cuda(), next(data_iter)) - run_step(model, optim, criterion, data, label) - optim_copy.load_state_dict(optim.state_dict()) - check_state_dict_eq(optim.state_dict(), optim_copy.state_dict()) - - data, label = map(lambda x: x.cuda(), next(data_iter)) - run_step(model_copy, optim_copy, criterion, data, label) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_nested_model() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@rerun_if_address_is_in_use() -def test_sharded_optim_state_dist(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_sharded_optim_state_dist(2) diff --git a/tests/test_zero/test_legacy/test_sharded_optim_v2.py b/tests/test_zero/test_legacy/test_sharded_optim_v2.py deleted file mode 100644 index c6f77995ebcd7820795d6a93e49195105a79bffd..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_sharded_optim_v2.py +++ /dev/null @@ -1,110 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from common import CONFIG, check_sharded_model_params -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.amp import convert_to_apex_amp -from colossalai.nn.optimizer import CPUAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs - - -def _run_step(model, optimizer, data, label, criterion, enable_autocast=False): - model.train() - optimizer.zero_grad() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - - loss = loss.float() - if isinstance(model, ShardedModelV2): - optimizer.backward(loss) - else: - loss.backward() - optimizer.step() - - -@parameterize("cpu_offload", [True, False]) -@parameterize("use_cpuadam", [True, False]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) -def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio): - test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] - shard_strategy = shard_strategy_class() - - if use_cpuadam and cpu_offload is False: - return - if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam): - return - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'auto', - reuse_fp16_shard=use_cpuadam, - ) - - model = model_builder(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda().float() - - if use_cpuadam: - optimizer_class = CPUAdam - optim = optimizer_class(model.parameters(), lr=1e-3) - sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, - sharded_optim, - initial_scale=2**5, - gpu_margin_mem_ratio=gpu_margin_mem_ratio) - - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) - apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) - if dist.get_world_size() > 1: - apex_model = DDP(apex_model, device_ids=[torch.cuda.current_device()]) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - data, label = data.cuda(), label.cuda() - _run_step(apex_model, apex_optimizer, data, label, criterion, False) - _run_step(zero_model, sharded_optim, data, label, criterion, False) - check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) - for param in model.parameters(): - assert not has_inf_or_nan(param) - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_test_sharded_optim_v2() - - -# use_cpuadam = True can be used with cpu_offload = False -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_sharded_optim_v2(world_size): - spawn(_run_dist, world_size) - - -if __name__ == '__main__': - test_sharded_optim_v2(world_size=2) diff --git a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py deleted file mode 100644 index 61d850d06080fd444e365fe5d4e5b91d284f42d8..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -import torch.distributed as dist -from torchvision.models import resnet50 - -import colossalai -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import TensorShardStrategy - - -def run_dist(rank, world_size, port): - # this test only runs on resnet18 - # as this model has sync batch normalization - # need to configure cudnn deterministic so that - # randomness of convolution layers will be disabled - zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy())) - colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False), - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - - with ZeroInitContext(target_device=torch.cuda.current_device(), - shard_strategy=gpc.config.zero.model_config.shard_strategy, - shard_param=True): - model = resnet50() - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - criterion = torch.nn.CrossEntropyLoss() - - engine, *args = colossalai.initialize(model, optimizer, criterion) - - # train for dummy iterations - engine.train() - for _ in range(2): - data = torch.rand(4, 3, 128, 128).cuda().half() - label = torch.randint(0, 10, size=(4,)).cuda() - engine.zero_grad() - out = engine(data) - loss = engine.criterion(out, label) - engine.backward(loss) - engine.step() - - # test - # need to make sure the batch norm stats are synchronized - # so that given the same input, the model will produce the same - # output on different ranks - engine.eval() - data = torch.rand(4, 3, 128, 128).cuda().half() - dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA)) - - # predict - out = engine(data) - - # test if results are equal - tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)] - tensor_list.insert(rank, out) - dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA)) - - assert torch.all(tensor_list[0] == tensor_list[1]), \ - 'expected the output from different ranks to be the same, but got different values' - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_sharded_optim_with_sync_bn(): - """ - This test is to make sure that buffers are synchronized between ranks - when using ZeRO. An example of module buffer is the running stats of - BatchNormalization layer, i.e. mean and var. - - If the buffers are not synchronized, the model will produce different - output even though the input and parameters are the same. This is not - wanted if we are doing predictions. - - """ - spawn(run_dist, 2) - - -if __name__ == '__main__': - test_sharded_optim_with_sync_bn() diff --git a/tests/test_zero/test_legacy/test_state_dict.py b/tests/test_zero/test_legacy/test_state_dict.py deleted file mode 100644 index 5f76fff3e5c372eeb7b62280c73b04082d289901..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_state_dict.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from functools import partial - -import pytest -import torch -from common import CONFIG - -import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_zero_state_dict(shard_strategy_class): - test_models = ['repeated_computed_layers', 'resnet18'] - shard_strategy = shard_strategy_class() - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2(zero_model, shard_strategy) - - model = model_builder(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda() - - zero_state_dict = zero_model.state_dict() - for key, val in model.state_dict().items(): - assert torch.equal(val, zero_state_dict[key].to(val.device)) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_zero_state_dict() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_zero_state_dict(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_state_dict(2) diff --git a/tests/test_zero/test_legacy/test_tensor_utils.py b/tests/test_zero/test_legacy/test_tensor_utils.py deleted file mode 100644 index 238bc3fe1a98084df12381ae48a4296c9618c08b..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_tensor_utils.py +++ /dev/null @@ -1,94 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor -from colossalai.zero.legacy.gemini.tensor_utils import ( - colo_model_data_move_to_cpu, - colo_model_data_tensor_move, - colo_model_data_tensor_move_inline, - colo_model_tensor_clone, - colo_tensor_mem_usage, -) - - -def _run_colo_tensor_mem_usage(): - for i in range(1): - if i == 1: - t1 = StatefulTensor(torch.randn(2, 2)) - t2 = StatefulTensor(torch.randn(4, 4)) - c1, g1 = colo_tensor_mem_usage(t1) - c2, g2 = colo_tensor_mem_usage(t2) - assert c1 * 4 == c2 - assert g1 * 4 == g2 - else: - t1 = torch.randn(2, 2) - t2 = torch.randn(4, 4) - c1, g1 = colo_tensor_mem_usage(t1) - c2, g2 = colo_tensor_mem_usage(t2) - assert c1 * 4 == c2 - assert g1 * 4 == g2 - - -def _run_colo_model_data_tensor_move_inline(): - for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]: - colo_model_data_tensor_move_inline(t, get_current_device()) - assert t.device == get_current_device() - - -def _run_colo_model_data_tensor_move(): - for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))), - (torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]: - cpu_t, cuda_t = t - colo_model_data_tensor_move(cpu_t, cuda_t) - assert cuda_t.device == get_current_device() - - -def _run_colo_model_data_move_to_cpu(): - for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]: - colo_model_data_move_to_cpu(t) - assert t.device == torch.device("cpu") - - -def _run_colo_model_tensor_clone(): - for t in [ - StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())), - torch.randn(4, 4).cuda(torch.cuda.current_device()) - ]: - if issubclass(type(t), StatefulTensor): - assert t.payload.device == get_current_device() - else: - assert t.device == get_current_device() - p = colo_model_tensor_clone(t, get_current_device()) - assert p.device == get_current_device() - for i in range(2): - for j in range(2): - if issubclass(type(t), StatefulTensor): - assert t.payload.device == p.device - assert t.payload[i][j] == p[i][j] - else: - assert t.device == p.device - assert t[i][j] == p[i][j] - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - _run_colo_tensor_mem_usage() - _run_colo_model_data_tensor_move_inline() - _run_colo_model_data_tensor_move() - _run_colo_model_data_move_to_cpu() - _run_colo_model_tensor_clone() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) -@rerun_if_address_is_in_use() -def test_zero_tensor_utils(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_tensor_utils(world_size=2) diff --git a/tests/test_zero/test_legacy/test_zero_engine.py b/tests/test_zero/test_legacy/test_zero_engine.py deleted file mode 100644 index dc8847ce56ab97c188684c1cf154f0d37628f5d5..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_legacy/test_zero_engine.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -import torch.distributed as dist -from common import MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs - - -def run_dist(rank, world_size, port, parallel_config): - colossalai.launch(config=parallel_config, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - - test_models = ['repeated_computed_layers', 'resnet18', 'bert'] - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - with ZeroInitContext(target_device=torch.cuda.current_device(), - shard_strategy=gpc.config.zero.model_config.shard_strategy, - shard_param=True): - colo_model = model_builder(checkpoint=True) - - colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3) - engine, train_dataloader, _, _ = colossalai.initialize(colo_model, - optimizer=colo_optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - torch_model = model_builder(checkpoint=True).half() - col_model_deepcopy(engine.model, torch_model) - torch_model = torch_model.cuda().float() - - engine.train() - torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3) - - if dist.get_world_size() > 1: - torch_model = DDP(torch_model, device_ids=[torch.cuda.current_device()]) - - i = 0 - for data, label in train_dataloader: - if i > 4: - break - - data, label = data.cuda(), label.cuda() - - engine.zero_grad() - torch_optimizer.zero_grad() - - if criterion: - output = engine(data) - loss = engine.criterion(output, label) - - torch_output = torch_model(data) - torch_loss = engine.criterion(torch_output, label) - else: - loss = engine(data, label) - torch_loss = torch_model(data, label) - - engine.backward(loss) - engine.step() - - torch_loss.backward() - - for param in torch_model.parameters(): - if param.grad is not None: - assert not has_inf_or_nan(param.grad) - - torch_optimizer.step() - i += 1 - - if parallel_config == MP_PARALLEL_CONFIG: - check_params(torch_model, colo_model, loose=True) - elif parallel_config == ZERO_PARALLEL_CONFIG: - check_sharded_model_params(torch_model, colo_model, loose=True) - - -# FIXME: enable this test in next PR -@pytest.mark.skip -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) -@rerun_if_address_is_in_use() -def test_mp_engine(world_size): - spawn(run_dist, world_size, parallel_config=MP_PARALLEL_CONFIG) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_zero_engine(world_size): - spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG) - - -if __name__ == '__main__': - test_zero_engine(world_size=4) diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index 2ae1f3a99d79b3d3657617a98cdafad7a10cb683..3c5baea138e0e3b941d81f4eb34000943812f598 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -9,11 +9,11 @@ from torch.testing import assert_close import colossalai from colossalai.testing import spawn from colossalai.testing.random import seed_all +from colossalai.utils import conditional_context from colossalai.zero import LowLevelZeroOptimizer class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() self.linear1 = nn.Linear(128, 256) @@ -35,41 +35,29 @@ def exam_zero_1_2_grad_acc(): # create optimizer zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) - zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, - overlap_communication=True, - initial_scale=32, - clip_grad_norm=1.0, - verbose=True) - zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, - overlap_communication=True, - partition_grad=True, - initial_scale=32, - clip_grad_norm=1.0) + zero1_optimizer = LowLevelZeroOptimizer( + zero1_optimizer, overlap_communication=True, initial_scale=32, clip_grad_norm=1.0, verbose=True + ) + zero2_optimizer = LowLevelZeroOptimizer( + zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=32, clip_grad_norm=1.0 + ) # create data seed_all(2021 + local_rank) input_data1 = torch.randn(32, 128).cuda() input_data2 = torch.randn(32, 128).cuda() - def fwd_bwd_func(number, cur_data): + def fwd_bwd_func(number, cur_data, check_flag): # zero-dp forward zero1_output = zero1_model(cur_data) zero2_output = zero2_model(cur_data) assert torch.equal(zero1_output, zero2_output) # zero-dp backward - zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False) - zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False) - - for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): - if z2p.grad is not None: - # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) - assert torch.equal(z1p.grad, z2p.grad) - - zero1_optimizer._sync_grad() - zero2_optimizer._sync_grad() + zero1_optimizer.backward(zero1_output.sum().float()) + zero2_optimizer.backward(zero2_output.sum().float()) - fwd_bwd_func(0, input_data1) - fwd_bwd_func(1, input_data2) + fwd_bwd_func(0, input_data1, True) + fwd_bwd_func(1, input_data2, False) # step zero1_optimizer.step() @@ -80,9 +68,8 @@ def exam_zero_1_2_grad_acc(): assert torch.equal(z1p.data, z2p.data) -def exam_zero_1_grad_acc(): +def exam_zero_1_grad_acc(sync): local_rank = torch.distributed.get_rank() - grad_scale = 32 seed_all(2008) # create models @@ -99,11 +86,9 @@ def exam_zero_1_grad_acc(): # we only test stage 1 here # in `check_sharded_param_consistency.py`, we will test whether # level 1 and 2 will produce exactly the same results - zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, - overlap_communication=False, - initial_scale=grad_scale, - reduce_bucket_size=262144, - clip_grad_norm=1.0) + zero_optimizer = LowLevelZeroOptimizer( + zero_optimizer, overlap_communication=False, reduce_bucket_size=262144, clip_grad_norm=1.0 + ) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) @@ -112,30 +97,25 @@ def exam_zero_1_grad_acc(): input_data1 = torch.randn(32, 128).cuda() input_data2 = torch.randn(32, 128).cuda() - def fwd_bwd_func(number, cur_data, check_flag): - # zero-dp forward - zero_output = zero_model(cur_data) - - # torch-ddp forward - torch_output = torch_model(cur_data) - assert torch.equal(zero_output, torch_output) + def fwd_bwd_func(no_sync, cur_data, check_flag): + # zero1 fwd and bwd + with conditional_context(zero_optimizer.no_sync(), no_sync): + zero_output = zero_model(cur_data) + zero_optimizer.backward(zero_output.sum().float()) - # zero-dp backward - zero_optimizer.backward(zero_output.sum().float(), sync_grad=False) - # torch-ddp backward - torch_output.sum().backward() + # torch-ddp fwd and bwd + with conditional_context(torch_model.no_sync(), no_sync): + torch_output = torch_model(cur_data) + assert torch.equal(zero_output, torch_output) + torch_output.sum().backward() if check_flag: # check grad for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - unscale_grad = z1p.grad / grad_scale - # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad))) - assert torch.equal(p.grad, unscale_grad) + assert torch.equal(p.grad, z1p.grad) - zero_optimizer._sync_grad() - - fwd_bwd_func(0, input_data1, True) - fwd_bwd_func(1, input_data2, False) + fwd_bwd_func(sync, input_data1, sync) + fwd_bwd_func(False, input_data2, False) zero_optimizer.step() torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) @@ -148,9 +128,10 @@ def exam_zero_1_grad_acc(): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") - exam_zero_1_grad_acc() + exam_zero_1_grad_acc(sync=True) + exam_zero_1_grad_acc(sync=False) exam_zero_1_2_grad_acc() @@ -159,5 +140,5 @@ def test_grad_accumulation(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_grad_accumulation() diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index 4086af9d896e83ba51953b195fd23f15e301628a..ebda9f6f25c5f689f22920fd0f396ff2d34824b7 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -7,17 +7,17 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() - self.linear1 = nn.Linear(128, 256) - self.linear2 = nn.Linear(256, 512) + self.linear1 = nn.Linear(123, 253) + self.linear_drop = nn.Linear(253, 253) + self.linear2 = nn.Linear(253, 512) def forward(self, x): x = self.linear1(x) @@ -25,19 +25,32 @@ class MlpModel(nn.Module): return x -def half_close(a, b, loose=False): +def loose_close(a, b, dtype: torch.dtype = torch.float32): rtol = None atol = None - if loose: + if dtype is torch.float16: rtol = 5e-2 atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 - a = a.detach().half() - b = b.detach().half() + a = a.detach().to(dtype) + b = b.detach().to(dtype) assert_close(a, b, rtol=rtol, atol=atol) +def split_ddp_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad + + def exam_zero_1_2(): """ In this test, we want to test whether zero stage 1 and 2 @@ -59,33 +72,29 @@ def exam_zero_1_2(): # create optimizer zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) - zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, - overlap_communication=True, - initial_scale=128, - verbose=True) - zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, - overlap_communication=True, - partition_grad=True, - initial_scale=128) + zero1_optimizer = LowLevelZeroOptimizer( + zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True + ) + zero2_optimizer = LowLevelZeroOptimizer( + zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128 + ) # create data seed_all(2001 + local_rank) - input_data = torch.randn(32, 128).cuda() + input_data = torch.randn(32, 123).cuda() zero1_output = zero1_model(input_data) zero2_output = zero2_model(input_data) assert torch.equal(zero1_output, zero2_output) # zero-dp backward - zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False) - zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False) + zero1_optimizer.backward(zero1_output.mean().float()) + zero2_optimizer.backward(zero2_output.mean().float()) - for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): - if z2p.grad is not None: - # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) - assert torch.equal(z1p.grad, z2p.grad) - - zero1_optimizer._sync_grad() - zero2_optimizer._sync_grad() + # check grad + z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0) + z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0) + for z1g, z2g in zip(z1g_list, z2g_list): + assert torch.equal(z1g, z2g) # step zero1_optimizer.step() @@ -96,7 +105,8 @@ def exam_zero_1_2(): assert torch.equal(z1p.data, z2p.data) -def exam_zero_1_torch_ddp(): +@parameterize("dtype", [torch.float16, torch.bfloat16]) +def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -109,15 +119,10 @@ def exam_zero_1_torch_ddp(): seed_all(1453) # create models - zero_model = MlpModel() - torch_model = copy.deepcopy(zero_model) - - zero_model = zero_model.cuda().half() - torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) - torch_model = torch_model.cuda() + torch_model = MlpModel().cuda() + zero_model = copy.deepcopy(torch_model).to(dtype) - # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - # half_close(p.data, z1p.data) + torch_model = DDP(torch_model.cuda(), static_graph=True).cuda() # create optimizer zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) @@ -125,36 +130,38 @@ def exam_zero_1_torch_ddp(): # we only test stage 1 here # in `check_sharded_param_consistency.py`, we will test whether # level 1 and 2 will produce exactly the same results - zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, - overlap_communication=True, - initial_scale=1, - reduce_bucket_size=262144) + zero_optimizer = LowLevelZeroOptimizer( + zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=1024 * 1024 + ) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) seed_all(1453 + local_rank) # create - input_data = torch.rand(32, 128).cuda() + input_data = torch.rand(32, 123).cuda() # zero-dp forward - zero_output = zero_model(input_data.half()) + zero_output = zero_model(input_data.to(dtype)) # torch-ddp forward torch_output = torch_model(input_data) - half_close(zero_output, torch_output, loose=True) + loose_close(zero_output, torch_output, dtype=dtype) # zero-dp backward - zero_optimizer.backward(zero_output.mean().float(), sync_grad=False) + zero_optimizer.backward(zero_output.mean().float()) # torch-ddp backward torch_output.mean().backward() # check grad for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - half_close(p.grad, z1p.grad, loose=True) + if p.grad is not None: + zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p)) + torch_grad_list = split_ddp_grad(p.grad, world_size) + for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): + loose_close(zero_grad, torch_grad, dtype=dtype) # zero-dp step - zero_optimizer._sync_grad() zero_optimizer.step() # torch ddp step @@ -162,14 +169,13 @@ def exam_zero_1_torch_ddp(): # check updated param for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - # print(n, torch.max(torch.abs(p.data - z1p.data))) - half_close(p.data, z1p.data, loose=True) + loose_close(p.data, z1p.data, dtype=dtype) def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") - exam_zero_1_torch_ddp() + exam_zero_1_torch_ddp(world_size=world_size) exam_zero_1_2() @@ -179,5 +185,5 @@ def test_zero_1_2(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_1_2() diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..e9fc8598a62d496de314297e6552e9ff26cffca6 --- /dev/null +++ b/tests/test_zero/test_low_level/test_zero_ckpt.py @@ -0,0 +1,118 @@ +import copy + +import pytest +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all +from colossalai.zero import LowLevelZeroOptimizer + + +class MlpModel(nn.Module): + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(12, 24) + self.linear2 = nn.Linear(24, 12) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def loose_close(a, b, dtype: torch.dtype = torch.float32): + rtol = None + atol = None + if dtype is torch.float16: + rtol = 5e-2 + atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 + + a = a.detach().to(dtype) + b = b.detach().to(dtype).to(a.device) + + assert_close(a, b, rtol=rtol, atol=atol) + + +def exam_zero_1_torch_ddp_ckpt(): + """ + We examine the state_dict of zero and DDP. + Moreover, we examine the zero's loading checkpoint of a torch ckpt. + """ + local_rank = torch.distributed.get_rank() + seed_all(1453) + + # create models + torch_model = MlpModel().cuda() + zero_model = copy.deepcopy(torch_model) + + torch_model = DDP(torch_model.cuda(), static_graph=True).cuda() + + # create optimizer + zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1) + + # we only test stage 1 here + # the state dicts of stage 1 and stage 2 are the same + zero_optimizer = LowLevelZeroOptimizer( + zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144 + ) + + torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) + + seed_all(1453 + local_rank) + # create + input_data = torch.rand(4, 12).cuda() + + # forward + zero_output = zero_model(input_data) + torch_output = torch_model(input_data) + + # backward + zero_optimizer.backward(zero_output.mean().float()) + torch_output.mean().backward() + + # step + zero_optimizer.step() + torch_optimizer.step() + + torch_state_dict = torch_optimizer.state_dict() + zero_state_dict = zero_optimizer.state_dict() + + # examine the original state dict + for torch_state, zero_state in zip(torch_state_dict["state"].values(), zero_state_dict["state"].values()): + for t_v, z_v in zip(torch_state.values(), zero_state.values()): + loose_close(t_v, z_v) + + # empty the optimzer state + zero_optimizer.optim.state = [] + + # zero load a torch checkpoint + zero_optimizer.load_state_dict(copy.deepcopy(torch_state_dict)) + zero_state_dict = zero_optimizer.state_dict() + + # examine the loaded state dict + for torch_state, zero_state in zip(torch_state_dict["state"].values(), zero_state_dict["state"].values()): + for t_v, z_v in zip(torch_state.values(), zero_state.values()): + loose_close(t_v, z_v) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + + exam_zero_1_torch_ddp_ckpt() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_zero_ckpt(): + spawn(run_dist, 2) + + +if __name__ == "__main__": + test_zero_ckpt() diff --git a/tests/test_zero/test_low_level/test_zero_init.py b/tests/test_zero/test_low_level/test_zero_init.py deleted file mode 100644 index aeeaff5b5cb92a08552fc2bd9b65cd046c25a9e5..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_low_level/test_zero_init.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn - -import colossalai -from colossalai.tensor import ProcessGroup -from colossalai.testing import spawn -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer - - -class MlpModel(nn.Module): - - def __init__(self): - super(MlpModel, self).__init__() - self.linear1 = nn.Linear(128, 256) - self.linear2 = nn.Linear(256, 512) - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - return x - - -def exam_zero_init(): - dp_2_tp_2_pg = ProcessGroup(dp_degree=2, tp_degree=2) - model1 = MlpModel().cuda() - with ColoInitContext(device=get_current_device(), default_pg=dp_2_tp_2_pg): - model2 = MlpModel() - optimizer1 = LowLevelZeroOptimizer(torch.optim.Adam(model1.parameters(), lr=1)) - optimizer2 = LowLevelZeroOptimizer(torch.optim.Adam(model2.parameters(), lr=1)) - - assert optimizer1._local_rank == optimizer2._local_rank - assert optimizer1._world_size == optimizer2._world_size - assert optimizer1._dp_global_ranks == optimizer2._dp_global_ranks - - mp_group1 = optimizer1._mp_torch_group - mp_group2 = optimizer2._mp_torch_group - assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2) - assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2) - - -def run_dist(rank, world_size, port): - config_dict = dict(parallel=dict(data=2, tensor=dict(size=2, mode='1d'))) - colossalai.launch(config=config_dict, rank=rank, world_size=world_size, port=port, host='localhost') - exam_zero_init() - - -@pytest.mark.dist -def test_zero_init(): - spawn(run_dist, 4) - - -if __name__ == '__main__': - test_zero_init() diff --git a/tests/test_zero/test_low_level/test_zero_tp.py b/tests/test_zero/test_low_level/test_zero_tp.py deleted file mode 100644 index f0804f4bb5ba51b5990b91d1b43f3dd365c93b7a..0000000000000000000000000000000000000000 --- a/tests/test_zero/test_low_level/test_zero_tp.py +++ /dev/null @@ -1,93 +0,0 @@ -import pytest -import torch -import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.testing import assert_close - -import colossalai -from colossalai.tensor import ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer -from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal - - -def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4): - return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol) - - -class MlpModel(nn.Module): - - def __init__(self): - super(MlpModel, self).__init__() - self.linear1 = nn.Linear(32, 128) - self.act = nn.GELU() - self.linear2 = nn.Linear(128, 32) - - def forward(self, x): - y = self.linear1(x) - y = self.act(y) - y = self.linear2(y) - return x + y - - -@parameterize("overlap_flag", [False, True]) -@parameterize("partition_flag", [False, True]) -def exam_zero_with_tp(overlap_flag, partition_flag): - set_seed(233010) - tp_pg = ProcessGroup(tp_degree=2) - - with ColoInitContext(device=get_current_device(), default_pg=tp_pg): - hybrid_model = MlpModel() - torch_model = MlpModel().cuda() - for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()): - pt.data.copy_(ph.data) - - for name, param in hybrid_model.named_parameters(): - if 'linear1' in name: - split_param_row_tp1d(param, tp_pg) - param.compute_spec.set_output_replicate(False) - if 'linear2.weight' in name: - split_param_col_tp1d(param, tp_pg) - - torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group()) - torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-2) # set to 1e-2 for torch-1.11 - hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2) - hybrid_optim = LowLevelZeroOptimizer(hybrid_optim, - initial_scale=2, - clip_grad_norm=1.0, - overlap_communication=overlap_flag, - partition_grad=partition_flag) - - dp_local_rank = tp_pg.dp_local_rank() - set_seed(255 + dp_local_rank) - - data = torch.randn(8, 32, device=get_current_device()) - torch_loss = torch_model(data).sum() - hybrid_loss = hybrid_model(data).sum() - assert_close(torch_loss, hybrid_loss) - - torch_loss.backward() - torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) - hybrid_optim.backward(hybrid_loss) - - torch_optim.step() - hybrid_optim.step() - - for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()): - assert strict_shard_equal(pt.data, ph.data, tp_pg) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - exam_zero_with_tp() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_zero_with_tp(): - spawn(run_dist, 4) - - -if __name__ == '__main__': - test_zero_with_tp() diff --git a/version.txt b/version.txt index a45be4627678330112f94a9b48a5e821ed846104..1c09c74e221cd58f30240fbcfd9545ed19df54d7 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.2.8 +0.3.3